@@ -484,7 +484,20 @@ mod llvm_enzyme {
484484
485485 if primal_ret && n_active == 0 && x. mode . is_rev ( ) {
486486 // We only have the primal ret.
487- body. stmts . push ( ecx. stmt_expr ( black_box_primal_call. clone ( ) ) ) ;
487+ dbg ! ( & primal_call) ;
488+ if x. width > 1 {
489+ // We have to return [T; width], thus add `[` and `]` and repeat ret
490+ // width times.
491+ let mut rets = ThinVec :: new ( ) ;
492+ for _ in 0 ..x. width {
493+ rets. push ( primal_call. clone ( ) ) ;
494+ }
495+ let exprs = ecx. expr_array ( span, rets) ;
496+ let ret = ecx. expr_call ( new_decl_span, blackbox_call_expr. clone ( ) , thin_vec ! [ exprs] ) ;
497+ body. stmts . push ( ecx. stmt_expr ( ret) ) ;
498+ } else {
499+ body. stmts . push ( ecx. stmt_expr ( black_box_primal_call. clone ( ) ) ) ;
500+ }
488501 return body;
489502 }
490503
@@ -576,10 +589,23 @@ mod llvm_enzyme {
576589 return body;
577590 }
578591 [ arg] => {
592+ // if width > 1, then we need to return [T; width], thus add `[` and `]` and repeat ret
593+ // width times.
594+ let exprs;
595+ if x. width > 1 {
596+ let mut rets = ThinVec :: new ( ) ;
597+ for _ in 0 ..x. width {
598+ rets. push ( arg. clone ( ) ) ;
599+ }
600+ exprs = ecx. expr_array ( span, rets) ;
601+ } else {
602+ exprs = arg. clone ( ) ;
603+ }
604+ dbg ! ( & exprs) ;
579605 ret = ecx. expr_call (
580606 new_decl_span,
581607 blackbox_call_expr. clone ( ) ,
582- thin_vec ! [ arg . clone ( ) ] ,
608+ thin_vec ! [ exprs ] ,
583609 ) ;
584610 }
585611 args => {
@@ -807,6 +833,20 @@ mod llvm_enzyme {
807833 }
808834 }
809835
836+ // If we have a width > 1, then we don't return -> T, but -> [T; width]
837+ if x. width > 1 && has_ret {
838+ let ty = match d_decl. output {
839+ FnRetTy :: Ty ( ref ty) => ty. clone ( ) ,
840+ FnRetTy :: Default ( span) => {
841+ panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
842+ }
843+ } ;
844+ let anon_const = rustc_ast:: AnonConst { id : ast:: DUMMY_NODE_ID , value : ecx. expr_usize ( span, x. width as usize ) } ;
845+ let kind = TyKind :: Array ( ty. clone ( ) , anon_const) ;
846+ let ty = P ( rustc_ast:: Ty { kind, id : ty. id , span : ty. span , tokens : None } ) ;
847+ d_decl. output = FnRetTy :: Ty ( ty) ;
848+ }
849+
810850 // If we use ActiveOnly, drop the original return value.
811851 d_decl. output =
812852 if active_only_ret { FnRetTy :: Default ( span) } else { d_decl. output . clone ( ) } ;
0 commit comments