Skip to content

Commit b7fdb7b

Browse files
committed
FIx generics error when passing fn as param to intrinsic
1 parent 38b27a2 commit b7fdb7b

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ mod llvm_enzyme {
1616
use rustc_ast::tokenstream::*;
1717
use rustc_ast::visit::AssocCtxt::*;
1818
use rustc_ast::{
19-
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
20-
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
19+
self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind,
20+
FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path,
21+
PathSegment, QSelf, TyKind, Visibility,
2122
};
2223
use rustc_expand::base::{Annotatable, ExtCtxt};
2324
use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -337,8 +338,14 @@ mod llvm_enzyme {
337338
&generics,
338339
);
339340

340-
let d_body =
341-
call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
341+
let d_body = call_enzyme_autodiff(
342+
ecx,
343+
primal,
344+
first_ident(&meta_item_vec[0]),
345+
span,
346+
&d_sig,
347+
&generics,
348+
);
342349

343350
// The first element of it is the name of the function to be generated
344351
let asdf = Box::new(ast::Fn {
@@ -505,9 +512,10 @@ mod llvm_enzyme {
505512
diff: Ident,
506513
span: Span,
507514
d_sig: &FnSig,
515+
generics: &Generics,
508516
) -> P<ast::Block> {
509-
let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
510-
let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
517+
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span);
518+
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span);
511519

512520
let tuple_expr = ecx.expr_tuple(
513521
span,
@@ -542,6 +550,37 @@ mod llvm_enzyme {
542550
block
543551
}
544552

553+
// Generate turbofish expression from fn name and generics
554+
// Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
555+
fn gen_turbofish_expr(
556+
ecx: &ExtCtxt<'_>,
557+
ident: Ident,
558+
generics: &Generics,
559+
span: Span,
560+
) -> P<ast::Expr> {
561+
let generic_args = generics
562+
.params
563+
.iter()
564+
.map(|p| {
565+
let path = ast::Path::from_ident(p.ident);
566+
let ty = ecx.ty_path(path);
567+
AngleBracketedArg::Arg(GenericArg::Type(ty))
568+
})
569+
.collect::<ThinVec<_>>();
570+
571+
let args = AngleBracketedArgs { span, args: generic_args };
572+
573+
let segment = PathSegment {
574+
ident,
575+
id: ast::DUMMY_NODE_ID,
576+
args: Some(P(GenericArgs::AngleBracketed(args))),
577+
};
578+
579+
let path = Path { span, segments: thin_vec![segment], tokens: None };
580+
581+
ecx.expr_path(path)
582+
}
583+
545584
// Generate dummy const to prevent primal function
546585
// from being optimized away before applying enzyme
547586
// ```

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ pub(crate) fn check_intrinsic_type(
171171
}
172172
};
173173

174-
let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff);
175-
176174
let bound_vars = tcx.mk_bound_variable_kinds(&[
177175
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
178176
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
@@ -197,6 +195,7 @@ pub(crate) fn check_intrinsic_type(
197195
(Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty)
198196
};
199197

198+
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
200199
let n_lts = 0;
201200
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
202201
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),

0 commit comments

Comments
 (0)