@@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
25
25
pub use generics:: * ;
26
26
pub use intrinsic:: IntrinsicDef ;
27
27
use rustc_abi:: { Align , FieldIdx , Integer , IntegerType , ReprFlags , ReprOptions , VariantIdx } ;
28
+ use rustc_ast:: expand:: typetree:: { FncTree , Kind , Type , TypeTree } ;
28
29
use rustc_ast:: node_id:: NodeMap ;
29
30
pub use rustc_ast_ir:: { Movability , Mutability , try_visit} ;
30
31
use rustc_data_structures:: fx:: { FxHashMap , FxHashSet , FxIndexMap , FxIndexSet } ;
@@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> {
2216
2217
pub variant : Option < VariantIdx > ,
2217
2218
pub fields : & ' tcx [ ty:: Const < ' tcx > ] ,
2218
2219
}
2220
+
2221
+ /// Generate TypeTree information for autodiff.
2222
+ /// This function creates TypeTree metadata that describes the memory layout
2223
+ /// of function parameters and return types for Enzyme autodiff.
2224
+ pub fn fnc_typetrees < ' tcx > ( tcx : TyCtxt < ' tcx > , fn_ty : Ty < ' tcx > ) -> FncTree {
2225
+ // Check if TypeTrees are disabled via NoTT flag
2226
+ if tcx. sess . opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: NoTT ) {
2227
+ return FncTree { args : vec ! [ ] , ret : TypeTree :: new ( ) } ;
2228
+ }
2229
+
2230
+ // Check if this is actually a function type
2231
+ if !fn_ty. is_fn ( ) {
2232
+ return FncTree { args : vec ! [ ] , ret : TypeTree :: new ( ) } ;
2233
+ }
2234
+
2235
+ // Get the function signature
2236
+ let fn_sig = fn_ty. fn_sig ( tcx) ;
2237
+ let sig = tcx. instantiate_bound_regions_with_erased ( fn_sig) ;
2238
+
2239
+ // Create TypeTrees for each input parameter
2240
+ let mut args = vec ! [ ] ;
2241
+ for ty in sig. inputs ( ) . iter ( ) {
2242
+ let type_tree = typetree_from_ty ( tcx, * ty) ;
2243
+ args. push ( type_tree) ;
2244
+ }
2245
+
2246
+ // Create TypeTree for return type
2247
+ let ret = typetree_from_ty ( tcx, sig. output ( ) ) ;
2248
+
2249
+ FncTree { args, ret }
2250
+ }
2251
+
2252
+ /// Generate TypeTree for a specific type.
2253
+ /// This function analyzes a Rust type and creates appropriate TypeTree metadata.
2254
+ fn typetree_from_ty < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> TypeTree {
2255
+ // Handle basic scalar types
2256
+ if ty. is_scalar ( ) {
2257
+ let ( kind, size) = if ty. is_integral ( ) || ty. is_char ( ) || ty. is_bool ( ) {
2258
+ ( Kind :: Integer , ty. primitive_size ( tcx) . bytes_usize ( ) )
2259
+ } else if ty. is_floating_point ( ) {
2260
+ match ty {
2261
+ x if x == tcx. types . f32 => ( Kind :: Float , 4 ) ,
2262
+ x if x == tcx. types . f64 => ( Kind :: Double , 8 ) ,
2263
+ _ => return TypeTree :: new ( ) , // Unknown float type
2264
+ }
2265
+ } else {
2266
+ // TODO(KMJ-007): Handle other scalar types if needed
2267
+ return TypeTree :: new ( ) ;
2268
+ } ;
2269
+
2270
+ return TypeTree ( vec ! [ Type {
2271
+ offset: -1 ,
2272
+ size,
2273
+ kind,
2274
+ child: TypeTree :: new( )
2275
+ } ] ) ;
2276
+ }
2277
+
2278
+ // Handle references and pointers
2279
+ if ty. is_ref ( ) || ty. is_raw_ptr ( ) || ty. is_box ( ) {
2280
+ let inner_ty = if let Some ( inner) = ty. builtin_deref ( true ) {
2281
+ inner
2282
+ } else {
2283
+ // TODO(KMJ-007): Handle complex pointer types
2284
+ return TypeTree :: new ( ) ;
2285
+ } ;
2286
+
2287
+ let child = typetree_from_ty ( tcx, inner_ty) ;
2288
+ return TypeTree ( vec ! [ Type {
2289
+ offset: -1 ,
2290
+ size: 8 , // TODO(KMJ-007): Get actual pointer size from target
2291
+ kind: Kind :: Pointer ,
2292
+ child,
2293
+ } ] ) ;
2294
+ }
2295
+
2296
+ // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
2297
+ TypeTree :: new ( )
2298
+ }
0 commit comments