@@ -3,7 +3,6 @@ use std::cmp::Ordering;
3
3
4
4
use rustc_abi:: { Align , BackendRepr , ExternAbi , Float , HasDataLayout , Primitive , Size } ;
5
5
use rustc_codegen_ssa:: base:: { compare_simd_types, wants_msvc_seh, wants_wasm_eh} ;
6
- use rustc_codegen_ssa:: codegen_attrs:: autodiff_attrs;
7
6
use rustc_codegen_ssa:: common:: { IntPredicate , TypeKind } ;
8
7
use rustc_codegen_ssa:: errors:: { ExpectedPointerMutability , InvalidMonomorphization } ;
9
8
use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
@@ -1167,7 +1166,13 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
1167
1166
Instance :: try_resolve ( tcx, bx. cx . typing_env ( ) , * diff_id, diff_args) . unwrap ( ) . unwrap ( ) ;
1168
1167
let diff_symbol = symbol_name_for_instance_in_crate ( tcx, fn_diff. clone ( ) , LOCAL_CRATE ) ;
1169
1168
1170
- let diff_attrs = autodiff_attrs ( tcx, fn_diff. def_id ( ) ) ;
1169
+ // TODO(Sa4dUs): Store autodiff items in a single pass and just get them here
1170
+ // in a O(1) step
1171
+ let diff_attrs = tcx
1172
+ . collect_and_partition_mono_items ( ( ) )
1173
+ . autodiff_items
1174
+ . iter ( )
1175
+ . find ( |item| item. target == diff_symbol) ;
1171
1176
let Some ( diff_attrs) = diff_attrs else { bug ! ( "could not find autodiff attrs" ) } ;
1172
1177
1173
1178
// Build body
@@ -1178,7 +1183,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
1178
1183
& diff_symbol,
1179
1184
llret_ty,
1180
1185
& val_arr,
1181
- diff_attrs. clone ( ) ,
1186
+ diff_attrs. attrs . clone ( ) ,
1182
1187
result,
1183
1188
) ;
1184
1189
}
@@ -1195,11 +1200,22 @@ fn get_args_from_tuple<'ll, 'tcx>(
1195
1200
for i in 0 ..tuple_place. layout . layout . 0 . fields . count ( ) {
1196
1201
let field_place = tuple_place. project_field ( bx, i) ;
1197
1202
let field_layout = tuple_place. layout . field ( bx, i) ;
1203
+ let field_ty = field_layout. ty ;
1198
1204
let llvm_ty = field_layout. llvm_type ( bx. cx ) ;
1199
1205
1200
1206
let field_val = bx. load ( llvm_ty, field_place. val . llval , field_place. val . align ) ;
1201
1207
1202
- ret_arr. push ( field_val)
1208
+ match field_ty. kind ( ) {
1209
+ ty:: Ref ( _, inner_ty, _) if matches ! ( inner_ty. kind( ) , ty:: Slice ( _) ) => {
1210
+ let ptr = bx. extract_value ( field_val, 0 ) ;
1211
+ let len = bx. extract_value ( field_val, 1 ) ;
1212
+ ret_arr. push ( ptr) ;
1213
+ ret_arr. push ( len) ;
1214
+ }
1215
+ _ => {
1216
+ ret_arr. push ( field_val) ;
1217
+ }
1218
+ }
1203
1219
}
1204
1220
1205
1221
ret_arr
0 commit comments