Skip to content

Commit 1bd54a7

Browse files
committed
Better error handling
1 parent 68c6c08 commit 1bd54a7

File tree

3 files changed

+46
-18
lines changed

3 files changed

+46
-18
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,10 @@ mod llvm_enzyme {
7474
}
7575

7676
// Get information about the function the macro is applied to
77-
fn extract_item_info(
78-
iitem: &P<ast::Item>,
79-
) -> Option<(Visibility, FnSig, Ident, Generics, bool)> {
77+
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
8078
match &iitem.kind {
8179
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
82-
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false))
80+
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
8381
}
8482
_ => None,
8583
}
@@ -223,9 +221,13 @@ mod llvm_enzyme {
223221
// parameters.
224222
// these will be used to generate the differentiated version of the function
225223
let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
226-
Annotatable::Item(iitem) => extract_item_info(iitem),
224+
Annotatable::Item(iitem) => {
225+
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
226+
}
227227
Annotatable::Stmt(stmt) => match &stmt.kind {
228-
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
228+
ast::StmtKind::Item(iitem) => {
229+
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
230+
}
229231
_ => None,
230232
},
231233
Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,25 +1155,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11551155
// Get source, diff, and attrs
11561156
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
11571157
ty::FnDef(def_id, source_params) => (def_id, source_params),
1158-
_ => bug!("invalid args"),
1158+
_ => bug!("invalid autodiff intrinsic args"),
1159+
};
1160+
1161+
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
1162+
Ok(Some(instance)) => instance,
1163+
Ok(None) => bug!(
1164+
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
1165+
source_id,
1166+
source_args
1167+
),
1168+
Err(_) => {
1169+
// An error has already been emitted
1170+
return;
1171+
}
11591172
};
1160-
let fn_source =
1161-
Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap();
1173+
11621174
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
1163-
let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol);
1164-
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
1175+
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
1176+
bug!("could not find source function")
1177+
};
11651178

11661179
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
11671180
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
11681181
_ => bug!("invalid args"),
11691182
};
1170-
let fn_diff =
1171-
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
1172-
let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2], fn_diff);
1183+
1184+
let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) {
1185+
Ok(Some(instance)) => instance,
1186+
Ok(None) => bug!(
1187+
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
1188+
diff_id,
1189+
diff_args
1190+
),
1191+
Err(_) => {
1192+
// An error has already been emitted
1193+
return;
1194+
}
1195+
};
1196+
1197+
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
11731198
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
11741199

1175-
let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
1176-
let Some(mut diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
1200+
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
1201+
bug!("could not find autodiff attrs")
1202+
};
11771203

11781204
adjust_activity_to_abi(
11791205
tcx,

compiler/rustc_monomorphize/src/collector/autodiff.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ fn collect_autodiff_fn_from_arg<'tcx>(
3939

4040
(instance, span)
4141
}
42-
_ => bug!("expected function"),
42+
_ => bug!("expected autodiff function"),
4343
},
44-
_ => bug!("expected type"),
44+
_ => bug!("expected type when matching autodiff arg"),
4545
};
4646

4747
output.push(create_fn_mono_item(tcx, instance, span));

0 commit comments

Comments
 (0)