Skip to content

Commit c29fb2e

Browse files
authored
Rollup merge of #144197 - KMJ-007:type-tree, r=ZuseZ4
TypeTree support in autodiff # TypeTrees for Autodiff ## What are TypeTrees? Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently. ## Structure ```rust TypeTree(Vec<Type>) Type { offset: isize, // byte offset (-1 = everywhere) size: usize, // size in bytes kind: Kind, // Float, Integer, Pointer, etc. child: TypeTree // nested structure } ``` ## Example: `fn compute(x: &f32, data: &[f32]) -> f32` **Input 0: `x: &f32`** ```rust TypeTree(vec![Type { offset: -1, size: 8, kind: Pointer, child: TypeTree(vec![Type { offset: -1, size: 4, kind: Float, child: TypeTree::new() }]) }]) ``` **Input 1: `data: &[f32]`** ```rust TypeTree(vec![Type { offset: -1, size: 8, kind: Pointer, child: TypeTree(vec![Type { offset: -1, size: 4, kind: Float, // -1 = all elements child: TypeTree::new() }]) }]) ``` **Output: `f32`** ```rust TypeTree(vec![Type { offset: -1, size: 4, kind: Float, child: TypeTree::new() }]) ``` ## Why Needed? - Enzyme can't deduce complex type layouts from LLVM IR - Prevents slow memory pattern analysis - Enables correct derivative computation for nested structures - Tells Enzyme which bytes are differentiable vs metadata ## What Enzyme Does With This Information: Without TypeTrees (current state): ```llvm ; Enzyme sees generic LLVM IR: define float ``@distance(ptr*`` %p1, ptr* %p2) { ; Has to guess what these pointers point to ; Slow analysis of all memory operations ; May miss optimization opportunities } ``` With TypeTrees (our implementation): ```llvm define "enzyme_type"="{[]:Float@float}" float ``@distance(`` ptr "enzyme_type"="{[]:Pointer}" %p1, ptr "enzyme_type"="{[]:Pointer}" %p2 ) { ; Enzyme knows exact type layout ; Can generate efficient derivative code directly } ``` # TypeTrees - Offset and -1 Explained ## Type Structure ```rust Type { offset: isize, // WHERE this type starts size: usize, // HOW BIG this type is kind: Kind, // WHAT KIND of data (Float, Int, Pointer) child: TypeTree // WHAT'S INSIDE (for pointers/containers) } ``` ## Offset Values ### Regular Offset (0, 4, 8, etc.) **Specific byte position within a structure** ```rust struct Point { x: f32, // offset 0, size 4 y: f32, // offset 4, size 4 id: i32, // offset 8, size 4 } ``` TypeTree for `&Point` (internal representation): ```rust TypeTree(vec![ Type { offset: 0, size: 4, kind: Float }, // x at byte 0 Type { offset: 4, size: 4, kind: Float }, // y at byte 4 Type { offset: 8, size: 4, kind: Integer } // id at byte 8 ]) ``` Generates LLVM: ```llvm "enzyme_type"="{[]:Float@float}" ``` ### Offset -1 (Special: "Everywhere") **Means "this pattern repeats for ALL elements"** #### Example 1: Array `[f32; 100]` ```rust TypeTree(vec![Type { offset: -1, // ALL positions size: 4, // each f32 is 4 bytes kind: Float, // every element is float }]) ``` Instead of listing 100 separate Types with offsets `0,4,8,12...396` #### Example 2: Slice `&[i32]` ```rust // Pointer to slice data TypeTree(vec![Type { offset: -1, size: 8, kind: Pointer, child: TypeTree(vec![Type { offset: -1, // ALL slice elements size: 4, // each i32 is 4 bytes kind: Integer }]) }]) ``` #### Example 3: Mixed Structure ```rust struct Container { header: i64, // offset 0 data: [f32; 1000], // offset 8, but elements use -1 } ``` ```rust TypeTree(vec![ Type { offset: 0, size: 8, kind: Integer }, // header Type { offset: 8, size: 4000, kind: Pointer, child: TypeTree(vec![Type { offset: -1, size: 4, kind: Float // ALL array elements }]) } ]) ```
2 parents 6059195 + 3ba5f19 commit c29fb2e

File tree

68 files changed

+1250
-14
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1250
-14
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9+
use crate::expand::typetree::TypeTree;
910
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1011
use crate::{Ty, TyKind};
1112

@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
8485
/// The name of the function being generated
8586
pub target: String,
8687
pub attrs: AutoDiffAttrs,
88+
pub inputs: Vec<TypeTree>,
89+
pub output: TypeTree,
8790
}
8891

8992
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
275278
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
276279
}
277280

278-
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
279-
AutoDiffItem { source, target, attrs: self }
281+
pub fn into_item(
282+
self,
283+
source: String,
284+
target: String,
285+
inputs: Vec<TypeTree>,
286+
output: TypeTree,
287+
) -> AutoDiffItem {
288+
AutoDiffItem { source, target, inputs, output, attrs: self }
280289
}
281290
}
282291

283292
impl fmt::Display for AutoDiffItem {
284293
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285294
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
286-
write!(f, " with attributes: {:?}", self.attrs)
295+
write!(f, " with attributes: {:?}", self.attrs)?;
296+
write!(f, " with inputs: {:?}", self.inputs)?;
297+
write!(f, " with output: {:?}", self.output)
287298
}
288299
}

compiler/rustc_ast/src/expand/typetree.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub enum Kind {
3131
Half,
3232
Float,
3333
Double,
34+
F128,
3435
Unknown,
3536
}
3637

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13831383
_src_align: Align,
13841384
size: RValue<'gcc>,
13851385
flags: MemFlags,
1386+
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
13861387
) {
13871388
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13881389
let size = self.intcast(size, self.type_size_t(), false);

compiler/rustc_codegen_gcc/src/intrinsic/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
770770
scratch_align,
771771
bx.const_usize(self.layout.size.bytes()),
772772
MemFlags::empty(),
773+
None,
773774
);
774775

775776
bx.lifetime_end(scratch, scratch_size);

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
246246
scratch_align,
247247
bx.const_usize(copy_bytes),
248248
MemFlags::empty(),
249+
None,
249250
);
250251
bx.lifetime_end(llscratch, scratch_size);
251252
}

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563563
config::AutoDiff::Enable => {}
564564
// We handle this below
565565
config::AutoDiff::NoPostopt => {}
566+
// Disables TypeTree generation
567+
config::AutoDiff::NoTT => {}
566568
}
567569
}
568570
// This helps with handling enums for now.

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11071108
src_align: Align,
11081109
size: &'ll Value,
11091110
flags: MemFlags,
1111+
tt: Option<FncTree>,
11101112
) {
11111113
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11121114
let size = self.intcast(size, self.type_isize(), false);
11131115
let is_volatile = flags.contains(MemFlags::VOLATILE);
1114-
unsafe {
1116+
let memcpy = unsafe {
11151117
llvm::LLVMRustBuildMemCpy(
11161118
self.llbuilder,
11171119
dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11201122
src_align.bytes() as c_uint,
11211123
size,
11221124
is_volatile,
1123-
);
1125+
)
1126+
};
1127+
1128+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1129+
// a memcpy during autodiff, it needs to know the structure of the data being
1130+
// copied to properly track derivatives. For example, copying an array of floats
1131+
// vs. copying a struct with mixed types requires different derivative handling.
1132+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1133+
if let Some(tt) = tt {
1134+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11241135
}
11251136
}
11261137

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ptr;
22

33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
4+
use rustc_ast::expand::typetree::FncTree;
45
use rustc_codegen_ssa::common::TypeKind;
56
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
67
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
@@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
294295
fn_args: &[&'ll Value],
295296
attrs: AutoDiffAttrs,
296297
dest: PlaceRef<'tcx, &'ll Value>,
298+
fnc_tree: FncTree,
297299
) {
298300
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299301
let mut ad_name: String = match attrs.mode {
@@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
370372
fn_args,
371373
);
372374

375+
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
376+
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
377+
}
378+
373379
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
374380

375381
builder.store_to_place(call, dest.val);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,9 @@ fn codegen_autodiff<'ll, 'tcx>(
12121212
&mut diff_attrs.input_activity,
12131213
);
12141214

1215+
let fnc_tree =
1216+
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
1217+
12151218
// Build body
12161219
generate_enzyme_call(
12171220
bx,
@@ -1222,6 +1225,7 @@ fn codegen_autodiff<'ll, 'tcx>(
12221225
&val_arr,
12231226
diff_attrs.clone(),
12241227
result,
1228+
fnc_tree,
12251229
);
12261230
}
12271231

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ mod llvm_util;
6868
mod mono_item;
6969
mod type_;
7070
mod type_of;
71+
mod typetree;
7172
mod va_arg;
7273
mod value;
7374

0 commit comments

Comments
 (0)