Skip to content

Commit c97d7f9

Browse files
committed
adjust fn signature if width > 1
1 parent b5701b1 commit c97d7f9

File tree

1 file changed

+41
-36
lines changed

1 file changed

+41
-36
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -683,50 +683,55 @@ mod llvm_enzyme {
683683
match activity {
684684
DiffActivity::Active => {
685685
act_ret.push(arg.ty.clone());
686+
// if width =/= 1, then push [arg.ty; width] to act_ret
686687
}
687688
DiffActivity::ActiveOnly => {
688689
// We will add the active scalar to the return type.
689690
// This is handled later.
690691
}
691692
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
692-
let mut shadow_arg = arg.clone();
693-
// We += into the shadow in reverse mode.
694-
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
695-
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
696-
ident.name
697-
} else {
698-
debug!("{:#?}", &shadow_arg.pat);
699-
panic!("not an ident?");
700-
};
701-
let name: String = format!("d{}", old_name);
702-
new_inputs.push(name.clone());
703-
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
704-
shadow_arg.pat = P(ast::Pat {
705-
id: ast::DUMMY_NODE_ID,
706-
kind: PatKind::Ident(BindingMode::NONE, ident, None),
707-
span: shadow_arg.pat.span,
708-
tokens: shadow_arg.pat.tokens.clone(),
709-
});
710-
d_inputs.push(shadow_arg);
693+
for i in 0..x.width {
694+
let mut shadow_arg = arg.clone();
695+
// We += into the shadow in reverse mode.
696+
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
697+
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
698+
ident.name
699+
} else {
700+
debug!("{:#?}", &shadow_arg.pat);
701+
panic!("not an ident?");
702+
};
703+
let name: String = format!("d{}_{}", old_name, i);
704+
new_inputs.push(name.clone());
705+
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
706+
shadow_arg.pat = P(ast::Pat {
707+
id: ast::DUMMY_NODE_ID,
708+
kind: PatKind::Ident(BindingMode::NONE, ident, None),
709+
span: shadow_arg.pat.span,
710+
tokens: shadow_arg.pat.tokens.clone(),
711+
});
712+
d_inputs.push(shadow_arg.clone());
713+
}
711714
}
712715
DiffActivity::Dual | DiffActivity::DualOnly => {
713-
let mut shadow_arg = arg.clone();
714-
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
715-
ident.name
716-
} else {
717-
debug!("{:#?}", &shadow_arg.pat);
718-
panic!("not an ident?");
719-
};
720-
let name: String = format!("b{}", old_name);
721-
new_inputs.push(name.clone());
722-
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
723-
shadow_arg.pat = P(ast::Pat {
724-
id: ast::DUMMY_NODE_ID,
725-
kind: PatKind::Ident(BindingMode::NONE, ident, None),
726-
span: shadow_arg.pat.span,
727-
tokens: shadow_arg.pat.tokens.clone(),
728-
});
729-
d_inputs.push(shadow_arg);
716+
for i in 0..x.width {
717+
let mut shadow_arg = arg.clone();
718+
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
719+
ident.name
720+
} else {
721+
debug!("{:#?}", &shadow_arg.pat);
722+
panic!("not an ident?");
723+
};
724+
let name: String = format!("b{}_{}", old_name, i);
725+
new_inputs.push(name.clone());
726+
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
727+
shadow_arg.pat = P(ast::Pat {
728+
id: ast::DUMMY_NODE_ID,
729+
kind: PatKind::Ident(BindingMode::NONE, ident, None),
730+
span: shadow_arg.pat.span,
731+
tokens: shadow_arg.pat.tokens.clone(),
732+
});
733+
d_inputs.push(shadow_arg.clone());
734+
}
730735
}
731736
DiffActivity::Const => {
732737
// Nothing to do here.

0 commit comments

Comments
 (0)