Skip to content

Commit d6acc7d

Browse files
committed
Better error handling
1 parent 4e9fdeb commit d6acc7d

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ mod llvm_enzyme {
7676
fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
7777
match &iitem.kind {
7878
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79-
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false))
79+
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
8080
}
8181
_ => None,
8282
}
@@ -220,9 +220,13 @@ mod llvm_enzyme {
220220
// parameters.
221221
// these will be used to generate the differentiated version of the function
222222
let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
223-
Annotatable::Item(iitem) => extract_item_info(iitem),
223+
Annotatable::Item(iitem) => {
224+
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
225+
}
224226
Annotatable::Stmt(stmt) => match &stmt.kind {
225-
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
227+
ast::StmtKind::Item(iitem) => {
228+
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
229+
}
226230
_ => None,
227231
},
228232
Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,25 +1145,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11451145
// Get source, diff, and attrs
11461146
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
11471147
ty::FnDef(def_id, source_params) => (def_id, source_params),
1148-
_ => bug!("invalid args"),
1148+
_ => bug!("invalid autodiff intrinsic args"),
1149+
};
1150+
1151+
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
1152+
Ok(Some(instance)) => instance,
1153+
Ok(None) => bug!(
1154+
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
1155+
source_id,
1156+
source_args
1157+
),
1158+
Err(_) => {
1159+
// An error has already been emitted
1160+
return;
1161+
}
11491162
};
1150-
let fn_source =
1151-
Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap();
1163+
11521164
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
1153-
let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol);
1154-
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
1165+
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
1166+
bug!("could not find source function")
1167+
};
11551168

11561169
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
11571170
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
11581171
_ => bug!("invalid args"),
11591172
};
1160-
let fn_diff =
1161-
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
1162-
let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2], fn_diff);
1173+
1174+
let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) {
1175+
Ok(Some(instance)) => instance,
1176+
Ok(None) => bug!(
1177+
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
1178+
diff_id,
1179+
diff_args
1180+
),
1181+
Err(_) => {
1182+
// An error has already been emitted
1183+
return;
1184+
}
1185+
};
1186+
1187+
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
11631188
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
11641189

1165-
let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
1166-
let Some(mut diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
1190+
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
1191+
bug!("could not find autodiff attrs")
1192+
};
11671193

11681194
adjust_activity_to_abi(
11691195
tcx,

compiler/rustc_monomorphize/src/collector/autodiff.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ fn collect_autodiff_fn_from_arg<'tcx>(
3939

4040
(instance, span)
4141
}
42-
_ => bug!("expected function"),
42+
_ => bug!("expected autodiff function"),
4343
},
44-
_ => bug!("expected type"),
44+
_ => bug!("expected type when matching autodiff arg"),
4545
};
4646

4747
output.push(create_fn_mono_item(tcx, instance, span));

0 commit comments

Comments
 (0)