Skip to content

Commit 065172e

Browse files
committed
Lower autodiff functions using instrinsics
1 parent 6de3a73 commit 065172e

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
197197
Some(instance),
198198
)
199199
}
200+
_ if tcx.has_attr(def_id, sym::rustc_autodiff) => {
201+
return Err(ty::Instance::new_raw(def_id, instance.args));
202+
}
200203
sym::is_val_statically_known => {
201204
let intrinsic_type = args[0].layout.immediate_llvm_type(self.cx);
202205
let kind = self.type_kind(intrinsic_type);

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ pub(crate) fn check_intrinsic_type(
174174
};
175175
let name_str = intrinsic_name.as_str();
176176

177+
let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff);
178+
177179
let bound_vars = tcx.mk_bound_variable_kinds(&[
178180
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
179181
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
@@ -229,6 +231,17 @@ pub(crate) fn check_intrinsic_type(
229231
//
230232
// so: two type params, 0 lifetime param, 0 const params, two inputs, no return
231233
(2, 0, 0, vec![param(0), param(1)], param(1), hir::Safety::Safe)
234+
} else if has_autodiff {
235+
let sig = tcx.fn_sig(intrinsic_id.to_def_id());
236+
let sig = sig.skip_binder();
237+
let n_tps = generics.own_counts().types;
238+
let n_lts = generics.own_counts().lifetimes;
239+
let n_cts = generics.own_counts().consts;
240+
241+
let inputs = sig.skip_binder().inputs().to_vec();
242+
let output = sig.skip_binder().output();
243+
244+
(n_tps, n_lts, n_cts, inputs, output, hir::Safety::Safe)
232245
} else {
233246
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
234247
let (n_tps, n_cts, inputs, output) = match intrinsic_name {

0 commit comments

Comments
 (0)