@@ -3,14 +3,14 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6- use rustc_errors:: FatalError ;
6+ use rustc_errors:: { DiagCtxt , FatalError } ;
77use tracing:: { debug, trace} ;
88
99use crate :: back:: write:: llvm_err;
1010use crate :: builder:: SBuilder ;
1111use crate :: context:: SimpleCx ;
1212use crate :: declare:: declare_simple_fn;
13- use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
13+ use crate :: errors:: { AutoDiffUnusedArgs , AutoDiffWithoutEnable , LlvmError } ;
1414use crate :: llvm:: AttributePlace :: Function ;
1515use crate :: llvm:: { Metadata , True } ;
1616use crate :: value:: Value ;
@@ -27,6 +27,69 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2727 }
2828}
2929
30+ // A helper object to make sure, that we pass all of the input and output arguments of the outer
31+ // wrapper into the inner enzyme call.
32+ struct FunctionArgs {
33+ input_args : Vec < bool > ,
34+ return_arg : bool ,
35+ has_sret : bool ,
36+ }
37+
38+
39+ impl FunctionArgs {
40+ fn fully_used ( & self ) -> bool {
41+ self . input_args . iter ( ) . all ( |x| * x) && self . return_arg
42+ }
43+
44+ fn use_input ( & mut self , idx : usize ) {
45+ assert ! ( !self . input_args[ idx] ) ;
46+ self . input_args [ idx] = true ;
47+ }
48+
49+ fn use_output ( & mut self ) {
50+ assert ! ( !self . return_arg) ;
51+ self . return_arg = true ;
52+ }
53+
54+ fn has_sret ( & self ) -> bool {
55+ self . has_sret
56+ }
57+ }
58+
59+ impl < ' ll > From < ( & ' ll llvm:: Context , & ' ll Value ) > for FunctionArgs {
60+ fn from ( wrapper : ( & ' ll llvm:: Context , & ' ll Value ) ) -> Self {
61+ let ( llcx, fnc) = wrapper;
62+ let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
63+ let input_args = vec ! [ false ; num_args] ;
64+
65+ let fn_ty = unsafe { llvm:: LLVMGlobalGetValueType ( fnc) } ;
66+ let ret_ty = unsafe { llvm:: LLVMGetReturnType ( fn_ty) } ;
67+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( llcx) } ;
68+ let return_arg = ret_ty != void_ty;
69+
70+
71+ let has_sret = if num_args == 0 {
72+ false
73+ } else {
74+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
75+ } ;
76+ if has_sret {
77+ dbg ! ( "has sret" ) ;
78+ } else {
79+ dbg ! ( "no sret" ) ;
80+ }
81+
82+ FunctionArgs { input_args, return_arg, has_sret }
83+ }
84+ }
85+
86+ // The drop implementation makes sure that when we're done, all input and output args are matched.
87+ //impl Drop for FunctionArgs {
88+ // fn drop(&mut self) {
89+ // assert!(self.fully_used());
90+ // }
91+ //}
92+
3093/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
3194/// function with expected naming and calling conventions[^1] which will be
3295/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -37,14 +100,18 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
37100// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
38101// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
39102fn generate_enzyme_call < ' ll > (
103+ _dcx : & DiagCtxt ,
40104 cx : & SimpleCx < ' ll > ,
41105 fn_to_diff : & ' ll Value ,
42106 outer_fn : & ' ll Value ,
43107 attrs : AutoDiffAttrs ,
44- ) {
108+ ) -> Result < ( ) , FatalError >
109+ {
45110 let inputs = attrs. input_activity ;
46111 let output = attrs. ret_activity ;
47112
113+ let fa = FunctionArgs :: from ( ( cx. llcx , outer_fn) ) ;
114+
48115 // We have to pick the name depending on whether we want forward or reverse mode autodiff.
49116 let mut ad_name: String = match attrs. mode {
50117 DiffMode :: Forward => "__enzyme_fwddiff" ,
@@ -99,6 +166,9 @@ fn generate_enzyme_call<'ll>(
99166 let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
100167 let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
101168
169+ dbg ! ( & outer_fn) ;
170+ dbg ! ( & fn_ty) ;
171+
102172 // LLVM can figure out the input types on it's own, so we take a shortcut here.
103173 let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
104174
@@ -163,6 +233,15 @@ fn generate_enzyme_call<'ll>(
163233 // using iterators and peek()?
164234 let mut outer_pos: usize = 0 ;
165235 let mut activity_pos = 0 ;
236+
237+ if fa. has_sret ( ) {
238+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
239+ // inner function will still return something. We increase our outer_pos by one,
240+ // and once we're done with all other args we will take the return of the inner call and
241+ // update the sret pointer with it
242+ outer_pos = 1 ;
243+ }
244+
166245 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
167246 while activity_pos < inputs. len ( ) {
168247 let diff_activity = inputs[ activity_pos as usize ] ;
@@ -293,6 +372,11 @@ fn generate_enzyme_call<'ll>(
293372 llvm:: LLVMRustVerifierFailureAction :: LLVMAbortProcessAction ,
294373 ) ;
295374 }
375+
376+ //if !fa.fully_used() {
377+ // return Err(dcx.handle().emit_almost_fatal(AutoDiffUnusedArgs));
378+ //}
379+ Ok ( ( ) )
296380}
297381
298382pub ( crate ) fn differentiate < ' ll > (
@@ -312,8 +396,7 @@ pub(crate) fn differentiate<'ll>(
312396 if !diff_items. is_empty ( )
313397 && !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
314398 {
315- let dcx = cgcx. create_dcx ( ) ;
316- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
399+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
317400 }
318401
319402 // Before dumping the module, we want all the TypeTrees to become part of the module.
@@ -343,7 +426,7 @@ pub(crate) fn differentiate<'ll>(
343426 ) ) ;
344427 } ;
345428
346- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
429+ generate_enzyme_call ( & diag_handler , & cx, fn_def, fn_target, item. attrs . clone ( ) ) ? ;
347430 }
348431
349432 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments