Skip to content

Commit cfa89c8

Browse files
committed
FIx generics error when passing fn as param to intrinsic
1 parent 880b852 commit cfa89c8

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ mod llvm_enzyme {
1515
use rustc_ast::tokenstream::*;
1616
use rustc_ast::visit::AssocCtxt::*;
1717
use rustc_ast::{
18-
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
19-
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
18+
self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind,
19+
FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path,
20+
PathSegment, QSelf, TyKind, Visibility,
2021
};
2122
use rustc_expand::base::{Annotatable, ExtCtxt};
2223
use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -336,8 +337,14 @@ mod llvm_enzyme {
336337
&generics,
337338
);
338339

339-
let d_body =
340-
call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
340+
let d_body = call_enzyme_autodiff(
341+
ecx,
342+
primal,
343+
first_ident(&meta_item_vec[0]),
344+
span,
345+
&d_sig,
346+
&generics,
347+
);
341348

342349
// The first element of it is the name of the function to be generated
343350
let asdf = Box::new(ast::Fn {
@@ -504,9 +511,10 @@ mod llvm_enzyme {
504511
diff: Ident,
505512
span: Span,
506513
d_sig: &FnSig,
514+
generics: &Generics,
507515
) -> P<ast::Block> {
508-
let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
509-
let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
516+
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span);
517+
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span);
510518

511519
let tuple_expr = ecx.expr_tuple(
512520
span,
@@ -541,6 +549,37 @@ mod llvm_enzyme {
541549
block
542550
}
543551

552+
// Generate turbofish expression from fn name and generics
553+
// Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
554+
fn gen_turbofish_expr(
555+
ecx: &ExtCtxt<'_>,
556+
ident: Ident,
557+
generics: &Generics,
558+
span: Span,
559+
) -> P<ast::Expr> {
560+
let generic_args = generics
561+
.params
562+
.iter()
563+
.map(|p| {
564+
let path = ast::Path::from_ident(p.ident);
565+
let ty = ecx.ty_path(path);
566+
AngleBracketedArg::Arg(GenericArg::Type(ty))
567+
})
568+
.collect::<ThinVec<_>>();
569+
570+
let args = AngleBracketedArgs { span, args: generic_args };
571+
572+
let segment = PathSegment {
573+
ident,
574+
id: ast::DUMMY_NODE_ID,
575+
args: Some(P(GenericArgs::AngleBracketed(args))),
576+
};
577+
578+
let path = Path { span, segments: thin_vec![segment], tokens: None };
579+
580+
ecx.expr_path(path)
581+
}
582+
544583
// Generate dummy const to prevent primal function
545584
// from being optimized away before applying enzyme
546585
// ```

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ pub(crate) fn check_intrinsic_type(
172172
}
173173
};
174174

175-
let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff);
176-
177175
let bound_vars = tcx.mk_bound_variable_kinds(&[
178176
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
179177
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
@@ -198,6 +196,7 @@ pub(crate) fn check_intrinsic_type(
198196
(Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty)
199197
};
200198

199+
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
201200
let n_lts = 0;
202201
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
203202
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),

0 commit comments

Comments
 (0)