@@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
2525pub use generics:: * ;
2626pub use intrinsic:: IntrinsicDef ;
2727use rustc_abi:: { Align , FieldIdx , Integer , IntegerType , ReprFlags , ReprOptions , VariantIdx } ;
28+ use rustc_ast:: expand:: typetree:: { FncTree , Kind , Type , TypeTree } ;
2829use rustc_ast:: node_id:: NodeMap ;
2930pub use rustc_ast_ir:: { Movability , Mutability , try_visit} ;
3031use rustc_data_structures:: fx:: { FxHashMap , FxHashSet , FxIndexMap , FxIndexSet } ;
@@ -2222,3 +2223,82 @@ pub struct DestructuredConst<'tcx> {
22222223 pub variant : Option < VariantIdx > ,
22232224 pub fields : & ' tcx [ ty:: Const < ' tcx > ] ,
22242225}
2226+
2227+ /// Generate TypeTree information for autodiff.
2228+ /// This function creates TypeTree metadata that describes the memory layout
2229+ /// of function parameters and return types for Enzyme autodiff.
2230+ pub fn fnc_typetrees < ' tcx > ( tcx : TyCtxt < ' tcx > , fn_ty : Ty < ' tcx > ) -> FncTree {
2231+ // Check if TypeTrees are disabled via NoTT flag
2232+ if tcx. sess . opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: NoTT ) {
2233+ return FncTree { args : vec ! [ ] , ret : TypeTree :: new ( ) } ;
2234+ }
2235+
2236+ // Check if this is actually a function type
2237+ if !fn_ty. is_fn ( ) {
2238+ return FncTree { args : vec ! [ ] , ret : TypeTree :: new ( ) } ;
2239+ }
2240+
2241+ // Get the function signature
2242+ let fn_sig = fn_ty. fn_sig ( tcx) ;
2243+ let sig = tcx. instantiate_bound_regions_with_erased ( fn_sig) ;
2244+
2245+ // Create TypeTrees for each input parameter
2246+ let mut args = vec ! [ ] ;
2247+ for ty in sig. inputs ( ) . iter ( ) {
2248+ let type_tree = typetree_from_ty ( tcx, * ty) ;
2249+ args. push ( type_tree) ;
2250+ }
2251+
2252+ // Create TypeTree for return type
2253+ let ret = typetree_from_ty ( tcx, sig. output ( ) ) ;
2254+
2255+ FncTree { args, ret }
2256+ }
2257+
2258+ /// Generate TypeTree for a specific type.
2259+ /// This function analyzes a Rust type and creates appropriate TypeTree metadata.
2260+ fn typetree_from_ty < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> TypeTree {
2261+ // Handle basic scalar types
2262+ if ty. is_scalar ( ) {
2263+ let ( kind, size) = if ty. is_integral ( ) || ty. is_char ( ) || ty. is_bool ( ) {
2264+ ( Kind :: Integer , ty. primitive_size ( tcx) . bytes_usize ( ) )
2265+ } else if ty. is_floating_point ( ) {
2266+ match ty {
2267+ x if x == tcx. types . f32 => ( Kind :: Float , 4 ) ,
2268+ x if x == tcx. types . f64 => ( Kind :: Double , 8 ) ,
2269+ _ => return TypeTree :: new ( ) , // Unknown float type
2270+ }
2271+ } else {
2272+ // TODO(KMJ-007): Handle other scalar types if needed
2273+ return TypeTree :: new ( ) ;
2274+ } ;
2275+
2276+ return TypeTree ( vec ! [ Type {
2277+ offset: -1 ,
2278+ size,
2279+ kind,
2280+ child: TypeTree :: new( )
2281+ } ] ) ;
2282+ }
2283+
2284+ // Handle references and pointers
2285+ if ty. is_ref ( ) || ty. is_raw_ptr ( ) || ty. is_box ( ) {
2286+ let inner_ty = if let Some ( inner) = ty. builtin_deref ( true ) {
2287+ inner
2288+ } else {
2289+ // TODO(KMJ-007): Handle complex pointer types
2290+ return TypeTree :: new ( ) ;
2291+ } ;
2292+
2293+ let child = typetree_from_ty ( tcx, inner_ty) ;
2294+ return TypeTree ( vec ! [ Type {
2295+ offset: -1 ,
2296+ size: 8 , // TODO(KMJ-007): Get actual pointer size from target
2297+ kind: Kind :: Pointer ,
2298+ child,
2299+ } ] ) ;
2300+ }
2301+
2302+ // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
2303+ TypeTree :: new ( )
2304+ }
0 commit comments