Skip to content

Commit 8f2fa11

Browse files
committed
update inner autodiff call
1 parent 527229a commit 8f2fa11

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,20 @@ mod llvm_enzyme {
484484

485485
if primal_ret && n_active == 0 && x.mode.is_rev() {
486486
// We only have the primal ret.
487-
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
487+
dbg!(&primal_call);
488+
if x.width > 1 {
489+
// We have to return [T; width], thus add `[` and `]` and repeat ret
490+
// width times.
491+
let mut rets = ThinVec::new();
492+
for _ in 0..x.width {
493+
rets.push(primal_call.clone());
494+
}
495+
let exprs = ecx.expr_array(span, rets);
496+
let ret = ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![exprs]);
497+
body.stmts.push(ecx.stmt_expr(ret));
498+
} else {
499+
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
500+
}
488501
return body;
489502
}
490503

@@ -576,10 +589,23 @@ mod llvm_enzyme {
576589
return body;
577590
}
578591
[arg] => {
592+
// if width > 1, then we need to return [T; width], thus add `[` and `]` and repeat ret
593+
// width times.
594+
let exprs;
595+
if x.width > 1 {
596+
let mut rets = ThinVec::new();
597+
for _ in 0..x.width {
598+
rets.push(arg.clone());
599+
}
600+
exprs = ecx.expr_array(span, rets);
601+
} else {
602+
exprs = arg.clone();
603+
}
604+
dbg!(&exprs);
579605
ret = ecx.expr_call(
580606
new_decl_span,
581607
blackbox_call_expr.clone(),
582-
thin_vec![arg.clone()],
608+
thin_vec![exprs],
583609
);
584610
}
585611
args => {
@@ -807,6 +833,20 @@ mod llvm_enzyme {
807833
}
808834
}
809835

836+
// If we have a width > 1, then we don't return -> T, but -> [T; width]
837+
if x.width > 1 && has_ret {
838+
let ty = match d_decl.output {
839+
FnRetTy::Ty(ref ty) => ty.clone(),
840+
FnRetTy::Default(span) => {
841+
panic!("Did not expect Default ret ty: {:?}", span);
842+
}
843+
};
844+
let anon_const = rustc_ast::AnonConst { id: ast::DUMMY_NODE_ID, value: ecx.expr_usize(span, x.width as usize) };
845+
let kind = TyKind::Array(ty.clone(), anon_const);
846+
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
847+
d_decl.output = FnRetTy::Ty(ty);
848+
}
849+
810850
// If we use ActiveOnly, drop the original return value.
811851
d_decl.output =
812852
if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };

tests/codegen/autodiffv.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@ fn main() {
4141
let mut df_dx2 = 0.0;
4242
let mut df_dx3 = 0.0;
4343
let mut df_dx4 = 0.0;
44-
let output_ = d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4, 1.0);
45-
assert_eq!(output, output_);
44+
let [o1;o2;o3;o4] = d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4, 1.0);
45+
assert_eq!(output, o1);
46+
assert_eq!(output, o2);
47+
assert_eq!(output, o3);
48+
assert_eq!(output, o4);
4649
assert_eq!(6.0, df_dx1);
4750
assert_eq!(6.0, df_dx2);
4851
assert_eq!(6.0, df_dx3);

0 commit comments

Comments
 (0)