From 23f1767c808d759ff683812581ca0e954d72c463 Mon Sep 17 00:00:00 2001 From: Kivooeo Date: Sun, 31 Aug 2025 17:03:46 +0000 Subject: [PATCH 01/22] allow taking address to union field --- .../rustc_mir_build/src/check_unsafety.rs | 15 ++++ .../miri/tests/pass/both_borrows/smallvec.rs | 2 +- tests/ui/union/union-unsafe.rs | 73 +++++++++++++++++++ tests/ui/union/union-unsafe.stderr | 54 +++++++++++--- 4 files changed, 132 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs index b5e165c75170a..195d45c2c4c49 100644 --- a/compiler/rustc_mir_build/src/check_unsafety.rs +++ b/compiler/rustc_mir_build/src/check_unsafety.rs @@ -554,6 +554,21 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { visit::walk_expr(self, &self.thir[arg]); return; } + + // Secondly, we allow raw borrows of union field accesses. Peel + // any of those off, and recurse normally on the LHS, which should + // reject any unsafe operations within. + let mut peeled = arg; + while let ExprKind::Scope { value: arg, .. } = self.thir[peeled].kind + && let ExprKind::Field { lhs, name: _, variant_index: _ } = self.thir[arg].kind + && let ty::Adt(def, _) = &self.thir[lhs].ty.kind() + && def.is_union() + { + peeled = lhs; + } + visit::walk_expr(self, &self.thir[peeled]); + // And return so we don't recurse directly onto the union field access(es). + return; } ExprKind::Deref { arg } => { if let ExprKind::StaticRef { def_id, .. } | ExprKind::ThreadLocalRef(def_id) = diff --git a/src/tools/miri/tests/pass/both_borrows/smallvec.rs b/src/tools/miri/tests/pass/both_borrows/smallvec.rs index f48815e37be34..fa5cfb03de2a9 100644 --- a/src/tools/miri/tests/pass/both_borrows/smallvec.rs +++ b/src/tools/miri/tests/pass/both_borrows/smallvec.rs @@ -25,7 +25,7 @@ impl RawSmallVec { } const fn as_mut_ptr_inline(&mut self) -> *mut T { - (unsafe { &raw mut self.inline }) as *mut T + &raw mut self.inline as *mut T } const unsafe fn as_mut_ptr_heap(&mut self) -> *mut T { diff --git a/tests/ui/union/union-unsafe.rs b/tests/ui/union/union-unsafe.rs index bd3946686be36..beb074f4e8ebc 100644 --- a/tests/ui/union/union-unsafe.rs +++ b/tests/ui/union/union-unsafe.rs @@ -1,5 +1,6 @@ use std::cell::RefCell; use std::mem::ManuallyDrop; +use std::ops::Deref; union U1 { a: u8, @@ -17,6 +18,10 @@ union U4 { a: T, } +union U5 { + a: usize, +} + union URef { p: &'static mut i32, } @@ -31,6 +36,20 @@ fn deref_union_field(mut u: URef) { *(u.p) = 13; //~ ERROR access to union field is unsafe } +union A { + a: usize, + b: &'static &'static B, +} + +union B { + c: usize, +} + +fn raw_deref_union_field(mut u: URef) { + // This is unsafe because we first dereference u.p (reading uninitialized memory) + let _p = &raw const *(u.p); //~ ERROR access to union field is unsafe +} + fn assign_noncopy_union_field(mut u: URefCell) { u.a = (ManuallyDrop::new(RefCell::new(0)), 1); // OK (assignment does not drop) u.a.0 = ManuallyDrop::new(RefCell::new(0)); // OK (assignment does not drop) @@ -57,6 +76,20 @@ fn main() { let a = u1.a; //~ ERROR access to union field is unsafe u1.a = 11; // OK + let mut u2 = U1 { a: 10 }; + let a = &raw mut u2.a; // OK + unsafe { *a = 3 }; + + let mut u3 = U1 { a: 10 }; + let a = std::ptr::addr_of_mut!(u3.a); // OK + unsafe { *a = 14 }; + + let u4 = U5 { a: 2 }; + let vec = vec![1, 2, 3]; + // This is unsafe because we read u4.a (potentially uninitialized memory) + // to use as an array index + let _a = &raw const vec[u4.a]; //~ ERROR access to union field is unsafe + let U1 { a } = u1; //~ ERROR access to union field is unsafe if let U1 { a: 12 } = u1 {} //~ ERROR access to union field is unsafe if let Some(U1 { a: 13 }) = Some(u1) {} //~ ERROR access to union field is unsafe @@ -73,4 +106,44 @@ fn main() { let mut u3 = U3 { a: ManuallyDrop::new(String::from("old")) }; // OK u3.a = ManuallyDrop::new(String::from("new")); // OK (assignment does not drop) *u3.a = String::from("new"); //~ ERROR access to union field is unsafe + + let mut unions = [U1 { a: 1 }, U1 { a: 2 }]; + + // Array indexing + union field raw borrow - should be OK + let ptr = &raw mut unions[0].a; // OK + let ptr2 = &raw const unions[1].a; // OK + + let a = A { a: 0 }; + let _p = &raw const (**a.b).c; //~ ERROR access to union field is unsafe + + arbitrary_deref(); +} + +// regression test for https://github.com/rust-lang/rust/pull/141469#discussion_r2312546218 +fn arbitrary_deref() { + use std::ops::Deref; + + union A { + a: usize, + b: B, + } + + #[derive(Copy, Clone)] + struct B(&'static str); + + impl Deref for B { + type Target = C; + + fn deref(&self) -> &C { + println!("{:?}", self.0); + &C { c: 0 } + } + } + + union C { + c: usize, + } + + let a = A { a: 0 }; + let _p = &raw const (*a.b).c; //~ ERROR access to union field is unsafe } diff --git a/tests/ui/union/union-unsafe.stderr b/tests/ui/union/union-unsafe.stderr index 82b3f897167c7..01f4d95eb649d 100644 --- a/tests/ui/union/union-unsafe.stderr +++ b/tests/ui/union/union-unsafe.stderr @@ -1,5 +1,5 @@ error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:31:6 + --> $DIR/union-unsafe.rs:36:6 | LL | *(u.p) = 13; | ^^^^^ access to union field @@ -7,7 +7,15 @@ LL | *(u.p) = 13; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:43:6 + --> $DIR/union-unsafe.rs:50:26 + | +LL | let _p = &raw const *(u.p); + | ^^^^^ access to union field + | + = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior + +error[E0133]: access to union field is unsafe and requires unsafe function or block + --> $DIR/union-unsafe.rs:62:6 | LL | *u3.a = T::default(); | ^^^^ access to union field @@ -15,7 +23,7 @@ LL | *u3.a = T::default(); = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:49:6 + --> $DIR/union-unsafe.rs:68:6 | LL | *u3.a = T::default(); | ^^^^ access to union field @@ -23,7 +31,7 @@ LL | *u3.a = T::default(); = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:57:13 + --> $DIR/union-unsafe.rs:76:13 | LL | let a = u1.a; | ^^^^ access to union field @@ -31,7 +39,15 @@ LL | let a = u1.a; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:60:14 + --> $DIR/union-unsafe.rs:91:29 + | +LL | let _a = &raw const vec[u4.a]; + | ^^^^ access to union field + | + = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior + +error[E0133]: access to union field is unsafe and requires unsafe function or block + --> $DIR/union-unsafe.rs:93:14 | LL | let U1 { a } = u1; | ^ access to union field @@ -39,7 +55,7 @@ LL | let U1 { a } = u1; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:61:20 + --> $DIR/union-unsafe.rs:94:20 | LL | if let U1 { a: 12 } = u1 {} | ^^ access to union field @@ -47,7 +63,7 @@ LL | if let U1 { a: 12 } = u1 {} = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:62:25 + --> $DIR/union-unsafe.rs:95:25 | LL | if let Some(U1 { a: 13 }) = Some(u1) {} | ^^ access to union field @@ -55,7 +71,7 @@ LL | if let Some(U1 { a: 13 }) = Some(u1) {} = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:67:6 + --> $DIR/union-unsafe.rs:100:6 | LL | *u2.a = String::from("new"); | ^^^^ access to union field @@ -63,7 +79,7 @@ LL | *u2.a = String::from("new"); = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:71:6 + --> $DIR/union-unsafe.rs:104:6 | LL | *u3.a = 1; | ^^^^ access to union field @@ -71,13 +87,29 @@ LL | *u3.a = 1; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:75:6 + --> $DIR/union-unsafe.rs:108:6 | LL | *u3.a = String::from("new"); | ^^^^ access to union field | = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior -error: aborting due to 10 previous errors +error[E0133]: access to union field is unsafe and requires unsafe function or block + --> $DIR/union-unsafe.rs:117:28 + | +LL | let _p = &raw const (**a.b).c; + | ^^^ access to union field + | + = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior + +error[E0133]: access to union field is unsafe and requires unsafe function or block + --> $DIR/union-unsafe.rs:148:27 + | +LL | let _p = &raw const (*a.b).c; + | ^^^ access to union field + | + = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior + +error: aborting due to 14 previous errors For more information about this error, try `rustc --explain E0133`. From e1258e79d6cb709b26ded97d32de6c55f355e2aa Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 20:17:32 +0000 Subject: [PATCH 02/22] autodiff: Add basic TypeTree with NoTT flag Signed-off-by: Karan Janthe --- .../rustc_ast/src/expand/autodiff_attrs.rs | 17 +++- compiler/rustc_codegen_llvm/src/back/lto.rs | 2 + compiler/rustc_interface/src/tests.rs | 1 + compiler/rustc_middle/src/error.rs | 1 - compiler/rustc_middle/src/ty/mod.rs | 80 +++++++++++++++++++ compiler/rustc_session/src/config.rs | 2 + compiler/rustc_session/src/options.rs | 3 +- tests/codegen-llvm/autodiff/typetree.rs | 33 ++++++++ .../autodiff/type-trees/nott-flag/nott.check | 3 + .../autodiff/type-trees/nott-flag/rmake.rs | 38 +++++++++ .../autodiff/type-trees/nott-flag/test.rs | 15 ++++ .../type-trees/nott-flag/with_tt.check | 3 + tests/ui/autodiff/flag_nott.rs | 19 +++++ 13 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 tests/codegen-llvm/autodiff/typetree.rs create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/nott.check create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/test.rs create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/with_tt.check create mode 100644 tests/ui/autodiff/flag_nott.rs diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 33451f9974835..90f15753e99c9 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,6 +6,7 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; +use crate::expand::typetree::TypeTree; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::{Ty, TyKind}; @@ -84,6 +85,8 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] @@ -275,14 +278,22 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { - AutoDiffItem { source, target, attrs: self } + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } } } impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs) + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) } } diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 78107d95e5a70..5ac3a87c158e6 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { config::AutoDiff::Enable => {} // We handle this below config::AutoDiff::NoPostopt => {} + // Disables TypeTree generation + config::AutoDiff::NoTT => {} } } // This helps with handling enums for now. diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 7730bddc0f12d..837acdadd579d 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -766,6 +766,7 @@ fn test_unstable_options_tracking_hash() { tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); tracked!(autodiff, vec![AutoDiff::Enable]); + tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); tracked!( diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index e3e1393b5f9ae..ef014af7beb14 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -37,7 +37,6 @@ pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> { pub sub: TypeMismatchReason, } -// FIXME(autodiff): I should get used somewhere #[derive(Diagnostic)] #[diag(middle_unsupported_union)] pub struct UnsupportedUnion { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 0ffef393a33bd..df42c40031752 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *}; pub use generics::*; pub use intrinsic::IntrinsicDef; use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx}; +use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree}; use rustc_ast::node_id::NodeMap; pub use rustc_ast_ir::{Movability, Mutability, try_visit}; use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet}; @@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> { pub variant: Option, pub fields: &'tcx [ty::Const<'tcx>], } + +/// Generate TypeTree information for autodiff. +/// This function creates TypeTree metadata that describes the memory layout +/// of function parameters and return types for Enzyme autodiff. +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { + // Check if TypeTrees are disabled via NoTT flag + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Check if this is actually a function type + if !fn_ty.is_fn() { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Get the function signature + let fn_sig = fn_ty.fn_sig(tcx); + let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); + + // Create TypeTrees for each input parameter + let mut args = vec![]; + for ty in sig.inputs().iter() { + let type_tree = typetree_from_ty(tcx, *ty); + args.push(type_tree); + } + + // Create TypeTree for return type + let ret = typetree_from_ty(tcx, sig.output()); + + FncTree { args, ret } +} + +/// Generate TypeTree for a specific type. +/// This function analyzes a Rust type and creates appropriate TypeTree metadata. +fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + // Handle basic scalar types + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => return TypeTree::new(), // Unknown float type + } + } else { + // TODO(KMJ-007): Handle other scalar types if needed + return TypeTree::new(); + }; + + return TypeTree(vec![Type { + offset: -1, + size, + kind, + child: TypeTree::new() + }]); + } + + // Handle references and pointers + if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { + let inner_ty = if let Some(inner) = ty.builtin_deref(true) { + inner + } else { + // TODO(KMJ-007): Handle complex pointer types + return TypeTree::new(); + }; + + let child = typetree_from_ty(tcx, inner_ty); + return TypeTree(vec![Type { + offset: -1, + size: 8, // TODO(KMJ-007): Get actual pointer size from target + kind: Kind::Pointer, + child, + }]); + } + + // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types + TypeTree::new() +} diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 297df7c2c9765..3a035fcf9e896 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -257,6 +257,8 @@ pub enum AutoDiff { LooseTypes, /// Runs Enzyme's aggressive inlining Inline, + /// Disable Type Tree + NoTT, } /// Settings for `-Z instrument-xray` flag. diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 69facde693689..ee53dd39bc887 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -792,7 +792,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; - pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`"; pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; @@ -1479,6 +1479,7 @@ pub mod parse { "PrintPasses" => AutoDiff::PrintPasses, "LooseTypes" => AutoDiff::LooseTypes, "Inline" => AutoDiff::Inline, + "NoTT" => AutoDiff::NoTT, _ => { // FIXME(ZuseZ4): print an error saying which value is not recognized return false; diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs new file mode 100644 index 0000000000000..3ad38d581b9a1 --- /dev/null +++ b/tests/codegen-llvm/autodiff/typetree.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Test that basic autodiff still works with our TypeTree infrastructure +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_simple, Duplicated, Active)] +#[no_mangle] +#[inline(never)] +fn simple(x: &f64) -> f64 { + 2.0 * x +} + +// CHECK-LABEL: @simple +// CHECK: fmul double + +// The derivative function should be generated normally +// CHECK-LABEL: diffesimple +// CHECK: fadd fast double + +fn main() { + let x = std::hint::black_box(3.0); + let output = simple(&x); + assert_eq!(6.0, output); + + let mut df_dx = 0.0; + let output_ = d_simple(&x, &mut df_dx, 1.0); + assert_eq!(output, output_); + assert_eq!(2.0, df_dx); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/nott.check b/tests/run-make/autodiff/type-trees/nott-flag/nott.check new file mode 100644 index 0000000000000..56ef2f0bdf380 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/nott.check @@ -0,0 +1,3 @@ +// TODO(KMJ-007): Update this test when TypeTree integration is complete +// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} +// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs new file mode 100644 index 0000000000000..536164192dc06 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -0,0 +1,38 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Test with NoTT flag - should not generate TypeTree metadata + let output_nott = rustc() + .input("test.rs") + .arg("-Zautodiff=Enable,NoTT,PrintTAFn=square") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + // Write output for NoTT case + rfs::write("nott.stdout", output_nott.stdout_utf8()); + + // Test without NoTT flag - should generate TypeTree metadata + let output_with_tt = rustc() + .input("test.rs") + .arg("-Zautodiff=Enable,PrintTAFn=square") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + // Write output for TypeTree case + rfs::write("with_tt.stdout", output_with_tt.stdout_utf8()); + + // Verify NoTT output has minimal TypeTree info + llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run(); + + // Verify normal output will have TypeTree info (once implemented) + llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run(); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs new file mode 100644 index 0000000000000..5c634eea035d9 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_square, Duplicated, Active)] +#[no_mangle] +fn square(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0; + let mut dx = 0.0; + let _result = d_square(&x, &mut dx, 1.0); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check new file mode 100644 index 0000000000000..56ef2f0bdf380 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -0,0 +1,3 @@ +// TODO(KMJ-007): Update this test when TypeTree integration is complete +// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} +// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs new file mode 100644 index 0000000000000..7a97d892cd881 --- /dev/null +++ b/tests/ui/autodiff/flag_nott.rs @@ -0,0 +1,19 @@ +//@ compile-flags: -Zautodiff=Enable,NoTT +//@ needs-enzyme +//@ check-pass + +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +// Test that NoTT flag is accepted and doesn't cause compilation errors +#[autodiff_reverse(d_square, Duplicated, Active)] +fn square(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0; + let mut dx = 0.0; + let result = d_square(&x, &mut dx, 1.0); +} \ No newline at end of file From 375e14ef491ac7bfa701e269b2815625abf2fca6 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 21:56:56 +0000 Subject: [PATCH 03/22] Add TypeTree metadata attachment for autodiff - Add F128 support to TypeTree Kind enum - Implement TypeTree FFI bindings and conversion functions - Add typetree.rs module for metadata attachment to LLVM functions - Integrate TypeTree generation with autodiff intrinsic pipeline - Support scalar types: f32, f64, integers, f16, f128 - Attach enzyme_type attributes as LLVM string metadata for Enzyme Signed-off-by: Karan Janthe --- compiler/rustc_ast/src/expand/typetree.rs | 1 + .../src/builder/autodiff.rs | 6 + compiler/rustc_codegen_llvm/src/intrinsic.rs | 4 + compiler/rustc_codegen_llvm/src/lib.rs | 1 + .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 182 +++++++++++++++++- compiler/rustc_codegen_llvm/src/typetree.rs | 144 ++++++++++++++ compiler/rustc_middle/src/ty/mod.rs | 19 +- 7 files changed, 343 insertions(+), 14 deletions(-) create mode 100644 compiler/rustc_codegen_llvm/src/typetree.rs diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 9a2dd2e85e0d6..e7b4f3aff413a 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -31,6 +31,7 @@ pub enum Kind { Half, Float, Double, + F128, Unknown, } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index b66e3dfdeec29..c3485f563916f 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,6 +1,7 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv}; @@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( fn_args: &[&'ll Value], attrs: AutoDiffAttrs, dest: PlaceRef<'tcx, &'ll Value>, + fnc_tree: FncTree, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( fn_args, ); + if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() { + crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree); + } + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); builder.store_to_place(call, dest.val); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index e7f4a3570488e..3254b31865197 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1213,6 +1213,9 @@ fn codegen_autodiff<'ll, 'tcx>( &mut diff_attrs.input_activity, ); + let fnc_tree = + rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized())); + // Build body generate_enzyme_call( bx, @@ -1223,6 +1226,7 @@ fn codegen_autodiff<'ll, 'tcx>( &val_arr, diff_attrs.clone(), result, + fnc_tree, ); } diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 6fb23d0984335..a5ac789285d47 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -67,6 +67,7 @@ mod llvm_util; mod mono_item; mod type_; mod type_of; +mod typetree; mod va_arg; mod value; diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 695435eb6daa7..12b6ffa3c0b5f 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -3,9 +3,35 @@ use libc::{c_char, c_uint}; use super::MetadataKindId; -use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value}; +use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value}; use crate::llvm::{Bool, Builder}; +// TypeTree types +pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub(crate) struct EnzymeTypeTree { + _unused: [u8; 0], +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub(crate) enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} + +pub(crate) struct TypeTree { + pub(crate) inner: CTypeTreeRef, +} + #[link(name = "llvm-wrapper", kind = "static")] unsafe extern "C" { // Enzyme @@ -68,10 +94,33 @@ pub(crate) mod Enzyme_AD { use libc::c_void; + use super::{CConcreteType, CTypeTreeRef, Context}; + unsafe extern "C" { pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); } + + // TypeTree functions + unsafe extern "C" { + pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef; + pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + pub(crate) fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + } + unsafe extern "C" { static mut EnzymePrintPerf: c_void; static mut EnzymePrintActivity: c_void; @@ -141,6 +190,57 @@ pub(crate) use self::Fallback_AD::*; pub(crate) mod Fallback_AD { #![allow(unused_variables)] + use libc::c_char; + + use super::{CConcreteType, CTypeTreeRef, Context}; + + // TypeTree function fallbacks + pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ) { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { + unimplemented!() + } + + pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { + unimplemented!() + } + pub(crate) fn set_inline(val: bool) { unimplemented!() } @@ -169,3 +269,83 @@ pub(crate) mod Fallback_AD { unimplemented!() } } + +impl TypeTree { + pub(crate) fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + TypeTree { inner } + } + + pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + TypeTree { inner } + } + + pub(crate) fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + self + } + + #[must_use] + pub(crate) fn shift( + self, + layout: &str, + offset: isize, + max_size: isize, + add_offset: usize, + ) -> Self { + let layout = std::ffi::CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ); + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl std::fmt::Display for TypeTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { + EnzymeTypeTreeToStringFree(ptr); + } + + Ok(()) + } +} + +impl std::fmt::Debug for TypeTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..434316464e626 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,144 @@ +use std::ffi::{CString, c_char, c_uint}; + +use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; + +use crate::attributes; +use crate::llvm::{self, Value}; + +/// Converts a Rust TypeTree to Enzyme's internal TypeTree format +/// +/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) +/// and converts it to Enzyme's internal C++ TypeTree representation that +/// Enzyme can understand during differentiation analysis. +#[cfg(llvm_enzyme)] +fn to_enzyme_typetree( + rust_typetree: RustTypeTree, + data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + // Start with an empty TypeTree + let mut enzyme_tt = llvm::TypeTree::new(); + + // Convert each Type in the Rust TypeTree to Enzyme format + for rust_type in rust_typetree.0 { + let concrete_type = match rust_type.kind { + rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, + rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, + rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, + rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, + rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, + rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, + rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_Unknown, + rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, + }; + + // Create a TypeTree for this specific type + let type_tt = llvm::TypeTree::from_type(concrete_type, llcx); + + // Apply offset if specified + let type_tt = if rust_type.offset == -1 { + type_tt // -1 means everywhere/no specific offset + } else { + // Apply specific offset positioning + type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) + }; + + // Merge this type into the main TypeTree + enzyme_tt = enzyme_tt.merge(type_tt); + } + + enzyme_tt +} + +#[cfg(not(llvm_enzyme))] +fn to_enzyme_typetree( + _rust_typetree: RustTypeTree, + _data_layout: &str, + _llcx: &llvm::Context, +) -> ! { + unimplemented!("TypeTree conversion not available without llvm_enzyme support") +} + +// Attaches TypeTree information to LLVM function as enzyme_type attributes. +#[cfg(llvm_enzyme)] +pub(crate) fn add_tt<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + fn_def: &'ll Value, + tt: FncTree, +) { + let inputs = tt.args; + let ret_tt: RustTypeTree = tt.ret; + + // Get LLVM data layout string for TypeTree conversion + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + // Attribute name that Enzyme recognizes for TypeTree information + let attr_name = "enzyme_type"; + let c_attr_name = CString::new(attr_name).unwrap(); + + // Attach TypeTree attributes to each input parameter + // Enzyme uses these to understand parameter memory layouts during differentiation + for (i, input) in inputs.iter().enumerate() { + unsafe { + // Convert Rust TypeTree to Enzyme's internal format + let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + + // Serialize TypeTree to string format that Enzyme can parse + let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + + // Create LLVM string attribute with TypeTree information + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + + // Attach attribute to the specific function parameter + // Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + + // Free the C string to prevent memory leaks + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + } + } + + // Attach TypeTree attribute to the return type + // Enzyme needs this to understand how to handle return value derivatives + unsafe { + let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + + let ret_attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + + // Attach to function return type + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + + // Free the C string + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + } +} + +// Fallback implementation when Enzyme is not available +#[cfg(not(llvm_enzyme))] +pub(crate) fn add_tt<'ll>( + _llmod: &'ll llvm::Module, + _llcx: &'ll llvm::Context, + _fn_def: &'ll Value, + _tt: FncTree, +) { + unimplemented!() +} diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index df42c40031752..e7130757f1db6 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2251,36 +2251,29 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { /// Generate TypeTree for a specific type. /// This function analyzes a Rust type and creates appropriate TypeTree metadata. -fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { - // Handle basic scalar types +pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { if ty.is_scalar() { let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) } else if ty.is_floating_point() { match ty { + x if x == tcx.types.f16 => (Kind::Half, 2), x if x == tcx.types.f32 => (Kind::Float, 4), x if x == tcx.types.f64 => (Kind::Double, 8), - _ => return TypeTree::new(), // Unknown float type + x if x == tcx.types.f128 => (Kind::F128, 16), + _ => return TypeTree::new(), } } else { - // TODO(KMJ-007): Handle other scalar types if needed return TypeTree::new(); }; - - return TypeTree(vec![Type { - offset: -1, - size, - kind, - child: TypeTree::new() - }]); + + return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]); } - // Handle references and pointers if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { let inner_ty = if let Some(inner) = ty.builtin_deref(true) { inner } else { - // TODO(KMJ-007): Handle complex pointer types return TypeTree::new(); }; From beba6b9b8692d0db14c56292d4fb793fcd8ceb1e Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 21:57:37 +0000 Subject: [PATCH 04/22] Update TypeTree tests to verify metadata attachment - Fix nott-flag test to emit LLVM IR and check enzyme_type attributes - Replace TODO comments with actual TypeTree metadata verification - Test that NoTT flag properly disables TypeTree generation - Test that TypeTree enabled generates proper enzyme_type attributes Signed-off-by: Karan Janthe --- .../autodiff/type-trees/nott-flag/nott.check | 8 ++-- .../autodiff/type-trees/nott-flag/rmake.rs | 42 +++++++++---------- .../type-trees/nott-flag/with_tt.check | 7 ++-- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/tests/run-make/autodiff/type-trees/nott-flag/nott.check b/tests/run-make/autodiff/type-trees/nott-flag/nott.check index 56ef2f0bdf380..8d23e2ee31957 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/nott.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/nott.check @@ -1,3 +1,5 @@ -// TODO(KMJ-007): Update this test when TypeTree integration is complete -// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} -// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file +// Check that enzyme_type attributes are NOT present when NoTT flag is used +// This verifies the NoTT flag correctly disables TypeTree metadata + +CHECK: define{{.*}}@square +CHECK-NOT: "enzyme_type" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs index 536164192dc06..bab863ca9ff29 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -5,34 +5,32 @@ use run_make_support::{llvm_filecheck, rfs, rustc}; fn main() { // Test with NoTT flag - should not generate TypeTree metadata - let output_nott = rustc() + rustc() .input("test.rs") - .arg("-Zautodiff=Enable,NoTT,PrintTAFn=square") - .arg("-Zautodiff=NoPostopt") - .opt_level("3") - .arg("-Clto=fat") - .arg("-g") + .arg("-Zautodiff=Enable,NoTT") + .emit("llvm-ir") + .arg("-o") + .arg("nott.ll") .run(); - // Write output for NoTT case - rfs::write("nott.stdout", output_nott.stdout_utf8()); - // Test without NoTT flag - should generate TypeTree metadata - let output_with_tt = rustc() + rustc() .input("test.rs") - .arg("-Zautodiff=Enable,PrintTAFn=square") - .arg("-Zautodiff=NoPostopt") - .opt_level("3") - .arg("-Clto=fat") - .arg("-g") + .arg("-Zautodiff=Enable") + .emit("llvm-ir") + .arg("-o") + .arg("with_tt.ll") .run(); - // Write output for TypeTree case - rfs::write("with_tt.stdout", output_with_tt.stdout_utf8()); - - // Verify NoTT output has minimal TypeTree info - llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run(); + // Verify NoTT version does NOT have enzyme_type attributes + llvm_filecheck() + .patterns("nott.check") + .stdin_buf(rfs::read("nott.ll")) + .run(); - // Verify normal output will have TypeTree info (once implemented) - llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run(); + // Verify TypeTree version DOES have enzyme_type attributes + llvm_filecheck() + .patterns("with_tt.check") + .stdin_buf(rfs::read("with_tt.ll")) + .run(); } \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check index 56ef2f0bdf380..53eafc10a65da 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -1,3 +1,4 @@ -// TODO(KMJ-007): Update this test when TypeTree integration is complete -// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} -// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file +// Check that enzyme_type attributes are present when TypeTree is enabled +// This verifies our TypeTree metadata attachment is working + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file From 5d3ebc3804299503387b7ea1427f1619d410c2b2 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 21:57:54 +0000 Subject: [PATCH 05/22] Add TypeTree tests for scalar types - Add specific tests for f32, f64, i32, f16, f128 TypeTree generation - Verify correct enzyme_type metadata for each scalar type - Ensure TypeTree metadata matches expected Enzyme format Signed-off-by: Karan Janthe --- .../scalar-types/f128-typetree/f128.check | 4 ++++ .../scalar-types/f128-typetree/rmake.rs | 12 ++++++++++++ .../type-trees/scalar-types/f128-typetree/test.rs | 15 +++++++++++++++ .../scalar-types/f16-typetree/f16.check | 4 ++++ .../type-trees/scalar-types/f16-typetree/rmake.rs | 12 ++++++++++++ .../type-trees/scalar-types/f16-typetree/test.rs | 15 +++++++++++++++ .../scalar-types/f32-typetree/f32.check | 4 ++++ .../type-trees/scalar-types/f32-typetree/rmake.rs | 12 ++++++++++++ .../type-trees/scalar-types/f32-typetree/test.rs | 15 +++++++++++++++ .../scalar-types/f64-typetree/f64.check | 4 ++++ .../type-trees/scalar-types/f64-typetree/rmake.rs | 12 ++++++++++++ .../type-trees/scalar-types/f64-typetree/test.rs | 15 +++++++++++++++ .../scalar-types/i32-typetree/i32.check | 4 ++++ .../type-trees/scalar-types/i32-typetree/rmake.rs | 12 ++++++++++++ .../type-trees/scalar-types/i32-typetree/test.rs | 14 ++++++++++++++ 15 files changed, 154 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check new file mode 100644 index 0000000000000..31cef113420db --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check @@ -0,0 +1,4 @@ +; Check that f128 TypeTree metadata is correctly generated +; f128 maps to Unknown in our current implementation since CConcreteType doesn't have DT_F128 + +CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f128{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs new file mode 100644 index 0000000000000..44320ecdd5714 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f128 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f128.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs new file mode 100644 index 0000000000000..5c71baa3e6999 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff, f128)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f128(x: &f128) -> f128 { + *x * *x +} + +fn main() { + let x = 2.0_f128; + let mut dx = 0.0_f128; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check new file mode 100644 index 0000000000000..97a3964a569b1 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check @@ -0,0 +1,4 @@ +; Check that f16 TypeTree metadata is correctly generated +; Should show Half for f16 values and Pointer for references + +CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f16{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs new file mode 100644 index 0000000000000..0aebdbf55209b --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f16 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f16.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs new file mode 100644 index 0000000000000..6b68e8252f4c0 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff, f16)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f16(x: &f16) -> f16 { + *x * *x +} + +fn main() { + let x = 2.0_f16; + let mut dx = 0.0_f16; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check new file mode 100644 index 0000000000000..b32d5086d07d1 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check @@ -0,0 +1,4 @@ +; Check that f32 TypeTree metadata is correctly generated +; Should show Float@float for f32 values and Pointer for references + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs new file mode 100644 index 0000000000000..ee3ab753bf505 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f32 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f32.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs new file mode 100644 index 0000000000000..56c118399ee4b --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f32(x: &f32) -> f32 { + x * x +} + +fn main() { + let x = 2.0_f32; + let mut dx = 0.0_f32; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check new file mode 100644 index 0000000000000..0e9d9c03a0f86 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check @@ -0,0 +1,4 @@ +; Check that f64 TypeTree metadata is correctly generated +; Should show Float@double for f64 values and Pointer for references + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs new file mode 100644 index 0000000000000..5fac9b23bc80f --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f64 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f64.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs new file mode 100644 index 0000000000000..235360b76b23d --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f64(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0_f64; + let mut dx = 0.0_f64; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check new file mode 100644 index 0000000000000..aa8e5223e8ed5 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check @@ -0,0 +1,4 @@ +; Check that i32 TypeTree metadata is correctly generated +; Should show Integer for i32 values (integers are typically Const in autodiff) + +CHECK: define{{.*}}"enzyme_type"="{[]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[]:Integer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs new file mode 100644 index 0000000000000..a40fd55d88adf --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that i32 TypeTree metadata is correctly generated + llvm_filecheck().patterns("i32.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs new file mode 100644 index 0000000000000..530c137f62b9c --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs @@ -0,0 +1,14 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Const, Active)] +#[no_mangle] +fn test_i32(x: i32) -> i32 { + x * x +} + +fn main() { + let x = 5_i32; + let _result = d_test(x, 1); +} From 664e83b3e76b51fb5192e74a64eef3bc5bbd4e32 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 23:10:48 +0000 Subject: [PATCH 06/22] added typetree support for memcpy --- compiler/rustc_codegen_gcc/src/builder.rs | 1 + .../rustc_codegen_gcc/src/intrinsic/mod.rs | 1 + compiler/rustc_codegen_llvm/src/abi.rs | 1 + compiler/rustc_codegen_llvm/src/builder.rs | 15 ++++++- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 20 ++++------ compiler/rustc_codegen_llvm/src/va_arg.rs | 1 + compiler/rustc_codegen_ssa/src/mir/block.rs | 1 + .../rustc_codegen_ssa/src/mir/intrinsic.rs | 2 +- .../rustc_codegen_ssa/src/mir/statement.rs | 2 +- .../rustc_codegen_ssa/src/traits/builder.rs | 3 +- compiler/rustc_interface/src/tests.rs | 1 - compiler/rustc_middle/src/ty/mod.rs | 4 +- tests/codegen-llvm/autodiff/typetree.rs | 2 +- .../memcpy-typetree/memcpy-ir.check | 8 ++++ .../type-trees/memcpy-typetree/memcpy.check | 13 +++++++ .../type-trees/memcpy-typetree/memcpy.rs | 36 +++++++++++++++++ .../type-trees/memcpy-typetree/rmake.rs | 39 +++++++++++++++++++ .../autodiff/type-trees/nott-flag/rmake.rs | 14 ++----- .../autodiff/type-trees/nott-flag/test.rs | 2 +- tests/ui/autodiff/flag_nott.rs | 2 +- 21 files changed, 135 insertions(+), 34 deletions(-) create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index f7a7a3f8c7e35..5657620879ca1 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { _src_align: Align, size: RValue<'gcc>, flags: MemFlags, + _tt: Option, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_size_t(), false); diff --git a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs index 84fa56cf90308..3b897a1983503 100644 --- a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs +++ b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(self.layout.size.bytes()), MemFlags::empty(), + None, ); bx.lifetime_end(scratch, scratch_size); diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 11be704116790..61fadad606613 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0f17cc9063a87..83a9cf620f109 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::ops::Deref; use std::{iter, ptr}; +use rustc_ast::expand::typetree::FncTree; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memcpy = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + // TypeTree metadata for memcpy is especially important: when Enzyme encounters + // a memcpy during autodiff, it needs to know the structure of the data being + // copied to properly track derivatives. For example, copying an array of floats + // vs. copying a struct with mixed types requires different derivative handling. + // The TypeTree tells Enzyme exactly what memory layout to expect. + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 12b6ffa3c0b5f..9dd84ad9a4dc3 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -25,6 +25,7 @@ pub(crate) enum CConcreteType { DT_Half = 3, DT_Float = 4, DT_Double = 5, + // FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600) DT_Unknown = 6, } diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 434316464e626..8e6272a186b88 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,8 +1,11 @@ -use std::ffi::{CString, c_char, c_uint}; - -use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; +use rustc_ast::expand::typetree::FncTree; +#[cfg(llvm_enzyme)] +use { + crate::attributes, + rustc_ast::expand::typetree::TypeTree as RustTypeTree, + std::ffi::{CString, c_char, c_uint}, +}; -use crate::attributes; use crate::llvm::{self, Value}; /// Converts a Rust TypeTree to Enzyme's internal TypeTree format @@ -50,15 +53,6 @@ fn to_enzyme_typetree( enzyme_tt } -#[cfg(not(llvm_enzyme))] -fn to_enzyme_typetree( - _rust_typetree: RustTypeTree, - _data_layout: &str, - _llcx: &llvm::Context, -) -> ! { - unimplemented!("TypeTree conversion not available without llvm_enzyme support") -} - // Attaches TypeTree information to LLVM function as enzyme_type attributes. #[cfg(llvm_enzyme)] pub(crate) fn add_tt<'ll>( diff --git a/compiler/rustc_codegen_llvm/src/va_arg.rs b/compiler/rustc_codegen_llvm/src/va_arg.rs index ab08125217ff5..d48c7cf874a0a 100644 --- a/compiler/rustc_codegen_llvm/src/va_arg.rs +++ b/compiler/rustc_codegen_llvm/src/va_arg.rs @@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>( src_align, bx.const_u32(layout.layout.size().bytes() as u32), MemFlags::empty(), + None, ); tmp } else { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 1b218a0d33956..b2dc4fe32b0fe 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); // ...and then load it with the ABI type. llval = load_cast(bx, cast, llscratch, scratch_align); diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index 3c667b8e88203..befa00c6861ed 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if allow_overlap { bx.memmove(dst, align, src, align, size, flags); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, None); } } diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index f164e0f912373..0a50d7f18dbef 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 4a5694e97fa20..60296e36e0c1a 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memmove( &mut self, @@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>: temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 837acdadd579d..0dc5b5af3aca0 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() { tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); - tracked!(autodiff, vec![AutoDiff::Enable]); tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index e7130757f1db6..741b5d7fd4e06 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let child = typetree_from_ty(tcx, inner_ty); return TypeTree(vec![Type { offset: -1, - size: 8, // TODO(KMJ-007): Get actual pointer size from target + size: tcx.data_layout.pointer_size().bytes_usize(), kind: Kind::Pointer, child, }]); } - // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types + // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types TypeTree::new() } diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs index 3ad38d581b9a1..1cb0c2fb68be3 100644 --- a/tests/codegen-llvm/autodiff/typetree.rs +++ b/tests/codegen-llvm/autodiff/typetree.rs @@ -30,4 +30,4 @@ fn main() { let output_ = d_simple(&x, &mut df_dx, 1.0); assert_eq!(output, output_); assert_eq!(2.0, df_dx); -} \ No newline at end of file +} diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check new file mode 100644 index 0000000000000..3a59f06dda1e4 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check @@ -0,0 +1,8 @@ +; Check that enzyme_type attributes are present in the LLVM IR function definition +; This verifies our TypeTree system correctly attaches metadata for Enzyme + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}" + +; Check that llvm.memcpy exists (either call or declare) +CHECK: {{(call|declare).*}}@llvm.memcpy + diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check new file mode 100644 index 0000000000000..ae70830297a78 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check @@ -0,0 +1,13 @@ +CHECK: force_memcpy + +CHECK: @llvm.memcpy.p0.p0.i64 + +CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{} + +CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double} + +CHECK-DAG: load double{{.*}}: {[-1]:Float@double} + +CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double} + +CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs new file mode 100644 index 0000000000000..3c1029190c88a --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs @@ -0,0 +1,36 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; +use std::ptr; + +#[inline(never)] +fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) { + unsafe { + ptr::copy_nonoverlapping(src, dst, count); + } +} + +#[autodiff_reverse(d_test_memcpy, Duplicated, Active)] +#[no_mangle] +fn test_memcpy(input: &[f64; 128]) -> f64 { + let mut local_data = [0.0f64; 128]; + + // Use a separate function to prevent inlining and optimization + force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128); + + // Sum only first few elements to keep the computation simple + local_data[0] * local_data[0] + + local_data[1] * local_data[1] + + local_data[2] * local_data[2] + + local_data[3] * local_data[3] +} + +fn main() { + let input = [1.0; 128]; + let mut d_input = [0.0; 128]; + let result = test_memcpy(&input); + let result_d = d_test_memcpy(&input, &mut d_input, 1.0); + + assert_eq!(result, result_d); + println!("Memcpy test passed: result = {}", result); +} diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs new file mode 100644 index 0000000000000..b4c650330fe94 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs @@ -0,0 +1,39 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // First, compile to LLVM IR to check for enzyme_type attributes + let _ir_output = rustc() + .input("memcpy.rs") + .arg("-Zautodiff=Enable") + .arg("-Zautodiff=NoPostopt") + .opt_level("0") + .arg("--emit=llvm-ir") + .arg("-o") + .arg("main.ll") + .run(); + + // Then compile with TypeTree analysis output for the existing checks + let output = rustc() + .input("memcpy.rs") + .arg("-Zautodiff=Enable,PrintTAFn=test_memcpy") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + let stdout = output.stdout_utf8(); + let stderr = output.stderr_utf8(); + let ir_content = rfs::read_to_string("main.ll"); + + rfs::write("memcpy.stdout", &stdout); + rfs::write("memcpy.stderr", &stderr); + rfs::write("main.ir", &ir_content); + + llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run(); + + llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run(); +} diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs index bab863ca9ff29..de540b990cabd 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -23,14 +23,8 @@ fn main() { .run(); // Verify NoTT version does NOT have enzyme_type attributes - llvm_filecheck() - .patterns("nott.check") - .stdin_buf(rfs::read("nott.ll")) - .run(); - + llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run(); + // Verify TypeTree version DOES have enzyme_type attributes - llvm_filecheck() - .patterns("with_tt.check") - .stdin_buf(rfs::read("with_tt.ll")) - .run(); -} \ No newline at end of file + llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs index 5c634eea035d9..de3549c37c679 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/test.rs +++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs @@ -12,4 +12,4 @@ fn main() { let x = 2.0; let mut dx = 0.0; let _result = d_square(&x, &mut dx, 1.0); -} \ No newline at end of file +} diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs index 7a97d892cd881..faa9949fe8167 100644 --- a/tests/ui/autodiff/flag_nott.rs +++ b/tests/ui/autodiff/flag_nott.rs @@ -16,4 +16,4 @@ fn main() { let x = 2.0; let mut dx = 0.0; let result = d_square(&x, &mut dx, 1.0); -} \ No newline at end of file +} From b8bb2e8bfec298aecec54f67e8ea3deda471eafa Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Wed, 27 Aug 2025 05:58:42 +0000 Subject: [PATCH 07/22] typo: allow EnzymeTypeTreeShiftIndiciesEq --- typos.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/typos.toml b/typos.toml index b0ff48f8fa28b..c7d5b0000b9d7 100644 --- a/typos.toml +++ b/typos.toml @@ -53,6 +53,7 @@ ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC = "ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC" ERROR_MCA_OCCURED = "ERROR_MCA_OCCURED" ERRNO_ACCES = "ERRNO_ACCES" tolen = "tolen" +EnzymeTypeTreeShiftIndiciesEq = "EnzymeTypeTreeShiftIndiciesEq" [default] extend-ignore-words-re = [ From d89ee858ab80f2c922c117c95b20264be8b84e3c Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Wed, 27 Aug 2025 18:19:55 +0000 Subject: [PATCH 08/22] enzyme submodule updated --- src/llvm-project | 2 +- src/tools/cargo | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llvm-project b/src/llvm-project index 333793696b0a0..2a22c8014390b 160000 --- a/src/llvm-project +++ b/src/llvm-project @@ -1 +1 @@ -Subproject commit 333793696b0a08002acac6af983adc7a5233fc56 +Subproject commit 2a22c8014390bdf387d8ed26005fee3187fb98bd diff --git a/src/tools/cargo b/src/tools/cargo index 966f94733bbc9..a4bd03c92d3b9 160000 --- a/src/tools/cargo +++ b/src/tools/cargo @@ -1 +1 @@ -Subproject commit 966f94733bbc94ca51ff9f1e4c49ad250ebbdc50 +Subproject commit a4bd03c92d3b977909953d06cc45e263bd7ca7d8 From 54f9376660707d4ca9fce51fd423658f75128ac4 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Wed, 27 Aug 2025 18:23:54 +0000 Subject: [PATCH 09/22] autodiff: f128 support added for typetree --- compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 2 +- compiler/rustc_codegen_llvm/src/typetree.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 9dd84ad9a4dc3..1596dc4837978 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -25,8 +25,8 @@ pub(crate) enum CConcreteType { DT_Half = 3, DT_Float = 4, DT_Double = 5, - // FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600) DT_Unknown = 6, + DT_FP128 = 9, } pub(crate) struct TypeTree { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 8e6272a186b88..8c0d255bba84e 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -31,7 +31,7 @@ fn to_enzyme_typetree( rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, - rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_Unknown, + rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_FP128, rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, }; From 31541feb6f7e46c23141bdeb3e35ecd305bf8762 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 05:05:49 +0000 Subject: [PATCH 10/22] autodiff: add TypeTree support for arrays --- compiler/rustc_middle/src/ty/mod.rs | 42 ++++++++++++++++++- .../type-trees/array-typetree/array.check | 4 ++ .../type-trees/array-typetree/rmake.rs | 9 ++++ .../type-trees/array-typetree/test.rs | 15 +++++++ 4 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/run-make/autodiff/type-trees/array-typetree/array.check create mode 100644 tests/run-make/autodiff/type-trees/array-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/array-typetree/test.rs diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 741b5d7fd4e06..02a4e4e2b1592 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2286,6 +2286,46 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { }]); } - // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types + if ty.is_array() { + if let ty::Array(element_ty, len_const) = ty.kind() { + let len = len_const.try_to_target_usize(tcx).unwrap_or(0); + if len == 0 { + return TypeTree::new(); + } + + let element_tree = typetree_from_ty(tcx, *element_ty); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + if element_layout == 0 { + return TypeTree::new(); + } + + let mut types = Vec::new(); + for i in 0..len { + let base_offset = (i as usize * element_layout) as isize; + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + base_offset + } else { + base_offset + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/array-typetree/array.check b/tests/run-make/autodiff/type-trees/array-typetree/array.check new file mode 100644 index 0000000000000..7513458b8ab81 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/array.check @@ -0,0 +1,4 @@ +; Check that array TypeTree metadata is correctly generated +; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes) + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs new file mode 100644 index 0000000000000..20b6a06690625 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/array-typetree/test.rs b/tests/run-make/autodiff/type-trees/array-typetree/test.rs new file mode 100644 index 0000000000000..f54ebf5a4c7b5 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_array(arr: &[f64; 5]) -> f64 { + arr[0] + arr[1] + arr[2] + arr[3] + arr[4] +} + +fn main() { + let arr = [1.0, 2.0, 3.0, 4.0, 5.0]; + let mut d_arr = [0.0; 5]; + let _result = d_test(&arr, &mut d_arr, 1.0); +} From be3617b04031c7d4f227ce54c4ce6ceeae00981c Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 06:17:34 +0000 Subject: [PATCH 11/22] autodiff: slice support in typetree --- compiler/rustc_middle/src/ty/mod.rs | 7 +++++++ .../autodiff/type-trees/slice-typetree/rmake.rs | 9 +++++++++ .../type-trees/slice-typetree/slice.check | 4 ++++ .../autodiff/type-trees/slice-typetree/test.rs | 16 ++++++++++++++++ 4 files changed, 36 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/slice-typetree/slice.check create mode 100644 tests/run-make/autodiff/type-trees/slice-typetree/test.rs diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 02a4e4e2b1592..82a41f403f8b4 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2327,5 +2327,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { } } + if ty.is_slice() { + if let ty::Slice(element_ty) = ty.kind() { + let element_tree = typetree_from_ty(tcx, *element_ty); + return element_tree; + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs new file mode 100644 index 0000000000000..b81fb50bf1a75 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("slice.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/slice.check b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check new file mode 100644 index 0000000000000..8fc9e9c4f1b17 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check @@ -0,0 +1,4 @@ +; Check that slice TypeTree metadata is correctly generated +; Should show Float@double for slice elements + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/test.rs b/tests/run-make/autodiff/type-trees/slice-typetree/test.rs new file mode 100644 index 0000000000000..7117fa3844f59 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/slice-typetree/test.rs @@ -0,0 +1,16 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_slice(slice: &[f64]) -> f64 { + slice.iter().sum() +} + +fn main() { + let arr = [1.0, 2.0, 3.0, 4.0, 5.0]; + let slice = &arr[..]; + let mut d_slice = [0.0; 5]; + let _result = d_test(slice, &mut d_slice[..], 1.0); +} From 7c5fbfbdbbb389462e0ffb936ba9b16cffbce6ed Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 16:14:31 +0000 Subject: [PATCH 12/22] autodiff: tuple support in typetree --- compiler/rustc_middle/src/ty/mod.rs | 36 +++++++++++++++++++ .../type-trees/tuple-typetree/rmake.rs | 9 +++++ .../type-trees/tuple-typetree/test.rs | 15 ++++++++ .../type-trees/tuple-typetree/tuple.check | 4 +++ 4 files changed, 64 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/tuple-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 82a41f403f8b4..c0a091551c9eb 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2334,5 +2334,41 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { } } + if let ty::Tuple(tuple_types) = ty.kind() { + if tuple_types.is_empty() { + return TypeTree::new(); + } + + let mut types = Vec::new(); + let mut current_offset = 0; + + for tuple_ty in tuple_types.iter() { + let element_tree = typetree_from_ty(tcx, tuple_ty); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + current_offset as isize + } else { + current_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + + current_offset += element_layout; + } + + return TypeTree(types); + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs new file mode 100644 index 0000000000000..76913828901c6 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("tuple.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs b/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs new file mode 100644 index 0000000000000..32187b587a383 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_tuple(tuple: &(f64, f64, f64)) -> f64 { + tuple.0 + tuple.1 * 2.0 + tuple.2 * 3.0 +} + +fn main() { + let tuple = (1.0, 2.0, 3.0); + let mut d_tuple = (0.0, 0.0, 0.0); + let _result = d_test(&tuple, &mut d_tuple, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check new file mode 100644 index 0000000000000..50aa25e96aedd --- /dev/null +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check @@ -0,0 +1,4 @@ +; Check that tuple TypeTree metadata is correctly generated +; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64) + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file From 574f0b97d6f30cd6cedb165fde13cdec176611b8 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 16:28:14 +0000 Subject: [PATCH 13/22] autodiff: struct support in typetree --- compiler/rustc_middle/src/ty/mod.rs | 32 +++++++++++++++++++ .../type-trees/struct-typetree/rmake.rs | 9 ++++++ .../type-trees/struct-typetree/struct.check | 4 +++ .../type-trees/struct-typetree/test.rs | 22 +++++++++++++ 4 files changed, 67 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/struct-typetree/struct.check create mode 100644 tests/run-make/autodiff/type-trees/struct-typetree/test.rs diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index c0a091551c9eb..703d417f96db9 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { return TypeTree(types); } + if let ty::Adt(adt_def, args) = ty.kind() { + if adt_def.is_struct() { + let struct_layout = + tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); + if let Ok(layout) = struct_layout { + let mut types = Vec::new(); + + for (field_idx, field_def) in adt_def.all_fields().enumerate() { + let field_ty = field_def.ty(tcx, args); + let field_tree = typetree_from_ty(tcx, field_ty); + + let field_offset = layout.fields.offset(field_idx).bytes_usize(); + + for elem_type in &field_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + field_offset as isize + } else { + field_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs new file mode 100644 index 0000000000000..0af1b65ee181b --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check new file mode 100644 index 0000000000000..2f763f18c1cee --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check @@ -0,0 +1,4 @@ +; Check that struct TypeTree metadata is correctly generated +; Should show Float@double at offsets 0, 8, 16 for Point struct fields + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/test.rs b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs new file mode 100644 index 0000000000000..cbe7b10e40974 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs @@ -0,0 +1,22 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[repr(C)] +struct Point { + x: f64, + y: f64, + z: f64, +} + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_struct(point: &Point) -> f64 { + point.x + point.y * 2.0 + point.z * 3.0 +} + +fn main() { + let point = Point { x: 1.0, y: 2.0, z: 3.0 }; + let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 }; + let _result = d_test(&point, &mut d_point, 1.0); +} From 4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Thu, 4 Sep 2025 11:17:34 +0000 Subject: [PATCH 14/22] autodiff: fixed test to be more precise for type tree checking --- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 23 ++++++ compiler/rustc_codegen_llvm/src/typetree.rs | 70 ++++++++----------- compiler/rustc_middle/src/ty/mod.rs | 36 +++------- .../type-trees/array-typetree/array.check | 2 +- .../memcpy-typetree/memcpy-ir.check | 2 +- .../mixed-struct-typetree/mixed.check | 2 + .../type-trees/mixed-struct-typetree/rmake.rs | 16 +++++ .../type-trees/mixed-struct-typetree/test.rs | 23 ++++++ .../type-trees/nott-flag/with_tt.check | 2 +- .../scalar-types/f128-typetree/f128.check | 4 +- .../scalar-types/f16-typetree/f16.check | 4 +- .../scalar-types/f32-typetree/f32.check | 2 +- .../scalar-types/f64-typetree/f64.check | 2 +- .../scalar-types/i32-typetree/i32.check | 4 +- .../scalar-types/i32-typetree/test.rs | 7 +- .../type-trees/slice-typetree/slice.check | 2 +- .../type-trees/struct-typetree/struct.check | 2 +- .../type-trees/tuple-typetree/tuple.check | 2 +- .../type-trees/type-analysis/vec/vec.check | 2 +- 19 files changed, 120 insertions(+), 87 deletions(-) create mode 100644 tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check create mode 100644 tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 1596dc4837978..e63043b21227f 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -118,6 +118,13 @@ pub(crate) mod Enzyme_AD { max_size: i64, add_offset: u64, ); + pub(crate) fn EnzymeTypeTreeInsertEq( + CTT: CTypeTreeRef, + indices: *const i64, + len: usize, + ct: CConcreteType, + ctx: &Context, + ); pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); } @@ -234,6 +241,16 @@ pub(crate) mod Fallback_AD { unimplemented!() } + pub(crate) unsafe fn EnzymeTypeTreeInsertEq( + CTT: CTypeTreeRef, + indices: *const i64, + len: usize, + ct: CConcreteType, + ctx: &Context, + ) { + unimplemented!() + } + pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { unimplemented!() } @@ -312,6 +329,12 @@ impl TypeTree { self } + + pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) { + unsafe { + EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx); + } + } } impl Clone for TypeTree { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 8c0d255bba84e..ae6a2da62b5dd 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -8,22 +8,24 @@ use { use crate::llvm::{self, Value}; -/// Converts a Rust TypeTree to Enzyme's internal TypeTree format -/// -/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) -/// and converts it to Enzyme's internal C++ TypeTree representation that -/// Enzyme can understand during differentiation analysis. #[cfg(llvm_enzyme)] fn to_enzyme_typetree( rust_typetree: RustTypeTree, - data_layout: &str, + _data_layout: &str, llcx: &llvm::Context, ) -> llvm::TypeTree { - // Start with an empty TypeTree let mut enzyme_tt = llvm::TypeTree::new(); - - // Convert each Type in the Rust TypeTree to Enzyme format - for rust_type in rust_typetree.0 { + process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); + enzyme_tt +} +#[cfg(llvm_enzyme)] +fn process_typetree_recursive( + enzyme_tt: &mut llvm::TypeTree, + rust_typetree: &RustTypeTree, + parent_indices: &[i64], + llcx: &llvm::Context, +) { + for rust_type in &rust_typetree.0 { let concrete_type = match rust_type.kind { rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, @@ -35,25 +37,27 @@ fn to_enzyme_typetree( rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, }; - // Create a TypeTree for this specific type - let type_tt = llvm::TypeTree::from_type(concrete_type, llcx); - - // Apply offset if specified - let type_tt = if rust_type.offset == -1 { - type_tt // -1 means everywhere/no specific offset + let mut indices = parent_indices.to_vec(); + if !parent_indices.is_empty() { + if rust_type.offset == -1 { + indices.push(-1); + } else { + indices.push(rust_type.offset as i64); + } + } else if rust_type.offset == -1 { + indices.push(-1); } else { - // Apply specific offset positioning - type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) - }; + indices.push(rust_type.offset as i64); + } - // Merge this type into the main TypeTree - enzyme_tt = enzyme_tt.merge(type_tt); - } + enzyme_tt.insert(&indices, concrete_type, llcx); - enzyme_tt + if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() { + process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx); + } + } } -// Attaches TypeTree information to LLVM function as enzyme_type attributes. #[cfg(llvm_enzyme)] pub(crate) fn add_tt<'ll>( llmod: &'ll llvm::Module, @@ -64,28 +68,20 @@ pub(crate) fn add_tt<'ll>( let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; - // Get LLVM data layout string for TypeTree conversion let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; let llvm_data_layout = std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) .expect("got a non-UTF8 data-layout from LLVM"); - // Attribute name that Enzyme recognizes for TypeTree information let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); - // Attach TypeTree attributes to each input parameter - // Enzyme uses these to understand parameter memory layouts during differentiation for (i, input) in inputs.iter().enumerate() { unsafe { - // Convert Rust TypeTree to Enzyme's internal format let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); - - // Serialize TypeTree to string format that Enzyme can parse let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); let c_str = std::ffi::CStr::from_ptr(c_str); - // Create LLVM string attribute with TypeTree information let attr = llvm::LLVMCreateStringAttribute( llcx, c_attr_name.as_ptr(), @@ -94,17 +90,11 @@ pub(crate) fn add_tt<'ll>( c_str.to_bytes().len() as c_uint, ); - // Attach attribute to the specific function parameter - // Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); - - // Free the C string to prevent memory leaks llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); } } - // Attach TypeTree attribute to the return type - // Enzyme needs this to understand how to handle return value derivatives unsafe { let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); @@ -118,15 +108,11 @@ pub(crate) fn add_tt<'ll>( c_str.to_bytes().len() as c_uint, ); - // Attach to function return type attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); - - // Free the C string llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); } } -// Fallback implementation when Enzyme is not available #[cfg(not(llvm_enzyme))] pub(crate) fn add_tt<'ll>( _llmod: &'ll llvm::Module, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 703d417f96db9..581f20e8492d9 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2261,10 +2261,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { x if x == tcx.types.f32 => (Kind::Float, 4), x if x == tcx.types.f64 => (Kind::Double, 8), x if x == tcx.types.f128 => (Kind::F128, 16), - _ => return TypeTree::new(), + _ => (Kind::Integer, 0), } } else { - return TypeTree::new(); + (Kind::Integer, 0) }; return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]); @@ -2295,32 +2295,14 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let element_tree = typetree_from_ty(tcx, *element_ty); - let element_layout = tcx - .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty)) - .ok() - .map(|layout| layout.size.bytes_usize()) - .unwrap_or(0); - - if element_layout == 0 { - return TypeTree::new(); - } - let mut types = Vec::new(); - for i in 0..len { - let base_offset = (i as usize * element_layout) as isize; - - for elem_type in &element_tree.0 { - types.push(Type { - offset: if elem_type.offset == -1 { - base_offset - } else { - base_offset + elem_type.offset - }, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } + for elem_type in &element_tree.0 { + types.push(Type { + offset: -1, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); } return TypeTree(types); diff --git a/tests/run-make/autodiff/type-trees/array-typetree/array.check b/tests/run-make/autodiff/type-trees/array-typetree/array.check index 7513458b8ab81..0d38bdec17e84 100644 --- a/tests/run-make/autodiff/type-trees/array-typetree/array.check +++ b/tests/run-make/autodiff/type-trees/array-typetree/array.check @@ -1,4 +1,4 @@ ; Check that array TypeTree metadata is correctly generated ; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes) -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check index 3a59f06dda1e4..0e6351ac4d39d 100644 --- a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check @@ -1,7 +1,7 @@ ; Check that enzyme_type attributes are present in the LLVM IR function definition ; This verifies our TypeTree system correctly attaches metadata for Enzyme -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}" +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" ; Check that llvm.memcpy exists (either call or declare) CHECK: {{(call|declare).*}}@llvm.memcpy diff --git a/tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check new file mode 100644 index 0000000000000..584f584084358 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check @@ -0,0 +1,2 @@ +; Check that mixed struct with large array generates correct detailed type tree +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Float@float}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs new file mode 100644 index 0000000000000..1c19963bc3612 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs @@ -0,0 +1,16 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc() + .input("test.rs") + .arg("-Zautodiff=Enable") + .arg("-Zautodiff=NoPostopt") + .opt_level("0") + .emit("llvm-ir") + .run(); + + llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs new file mode 100644 index 0000000000000..7a734980e617c --- /dev/null +++ b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs @@ -0,0 +1,23 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[repr(C)] +struct Container { + header: i64, + data: [f32; 1000], +} + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +#[inline(never)] +fn test_mixed_struct(container: &Container) -> f32 { + container.data[0] + container.data[999] +} + +fn main() { + let container = Container { header: 42, data: [1.0; 1000] }; + let mut d_container = Container { header: 0, data: [0.0; 1000] }; + let result = d_test(&container, &mut d_container, 1.0); + std::hint::black_box(result); +} diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check index 53eafc10a65da..3c02003c88287 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -1,4 +1,4 @@ // Check that enzyme_type attributes are present when TypeTree is enabled // This verifies our TypeTree metadata attachment is working -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check index 31cef113420db..733e46aa45a57 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check @@ -1,4 +1,4 @@ ; Check that f128 TypeTree metadata is correctly generated -; f128 maps to Unknown in our current implementation since CConcreteType doesn't have DT_F128 +; Should show Float@fp128 for f128 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f128{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check index 97a3964a569b1..9caca26b5cf51 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check @@ -1,4 +1,4 @@ ; Check that f16 TypeTree metadata is correctly generated -; Should show Half for f16 values and Pointer for references +; Should show Float@half for f16 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f16{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check index b32d5086d07d1..ec12ba6b2348d 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check @@ -1,4 +1,4 @@ ; Check that f32 TypeTree metadata is correctly generated ; Should show Float@float for f32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check index 0e9d9c03a0f86..f1af270824d57 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check @@ -1,4 +1,4 @@ ; Check that f64 TypeTree metadata is correctly generated ; Should show Float@double for f64 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check index aa8e5223e8ed5..fd9a94be8100b 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check @@ -1,4 +1,4 @@ ; Check that i32 TypeTree metadata is correctly generated -; Should show Integer for i32 values (integers are typically Const in autodiff) +; Should show Integer for i32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[]:Integer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs index 530c137f62b9c..249803c5d9f7a 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs @@ -2,13 +2,14 @@ use std::autodiff::autodiff_reverse; -#[autodiff_reverse(d_test, Const, Active)] +#[autodiff_reverse(d_test, Duplicated, Active)] #[no_mangle] -fn test_i32(x: i32) -> i32 { +fn test_i32(x: &i32) -> i32 { x * x } fn main() { let x = 5_i32; - let _result = d_test(x, 1); + let mut dx = 0_i32; + let _result = d_test(&x, &mut dx, 1); } diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/slice.check b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check index 8fc9e9c4f1b17..6543b61611538 100644 --- a/tests/run-make/autodiff/type-trees/slice-typetree/slice.check +++ b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check @@ -1,4 +1,4 @@ ; Check that slice TypeTree metadata is correctly generated ; Should show Float@double for slice elements -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check index 2f763f18c1cee..54956317e1e95 100644 --- a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check +++ b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check @@ -1,4 +1,4 @@ ; Check that struct TypeTree metadata is correctly generated ; Should show Float@double at offsets 0, 8, 16 for Point struct fields -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check index 50aa25e96aedd..47647e78cc35c 100644 --- a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check @@ -1,4 +1,4 @@ ; Check that tuple TypeTree metadata is correctly generated ; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64) -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check b/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check index dcf9508b69d6f..cdb70eb83fc19 100644 --- a/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check +++ b/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check @@ -1,7 +1,7 @@ // CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} // CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer} // CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer} -// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !102, !noundef !{{[0-9]+}}: {} +// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !noundef !{{[0-9]+}}: {} // CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 16, !dbg !{{[0-9]+}}: {[-1]:Pointer} // CHECK-DAG: %{{[0-9]+}} = load i64, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {} // CHECK-DAG: %{{[0-9]+}} = icmp eq i64 %{{[0-9]+}}, 0, !dbg !{{[0-9]+}}: {[-1]:Integer} From 4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Thu, 11 Sep 2025 07:30:35 +0000 Subject: [PATCH 15/22] autodiff: recurion added for typetree --- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 10 +- compiler/rustc_middle/src/ty/mod.rs | 76 +++++++++++-- .../type-trees/nott-flag/with_tt.check | 2 +- .../recursion-typetree/recursion.check | 3 + .../type-trees/recursion-typetree/rmake.rs | 9 ++ .../type-trees/recursion-typetree/test.rs | 100 ++++++++++++++++++ .../scalar-types/f128-typetree/f128.check | 2 +- .../scalar-types/f16-typetree/f16.check | 2 +- .../scalar-types/f32-typetree/f32.check | 2 +- .../scalar-types/f64-typetree/f64.check | 2 +- .../scalar-types/i32-typetree/i32.check | 2 +- 12 files changed, 191 insertions(+), 20 deletions(-) create mode 100644 tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check create mode 100644 tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/recursion-typetree/test.rs diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index e63043b21227f..b604f5139c8c1 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD { ); pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint; } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index ae6a2da62b5dd..1a54884f6c5a4 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -39,11 +39,7 @@ fn process_typetree_recursive( let mut indices = parent_indices.to_vec(); if !parent_indices.is_empty() { - if rust_type.offset == -1 { - indices.push(-1); - } else { - indices.push(rust_type.offset as i64); - } + indices.push(rust_type.offset as i64); } else if rust_type.offset == -1 { indices.push(-1); } else { @@ -52,7 +48,9 @@ fn process_typetree_recursive( enzyme_tt.insert(&indices, concrete_type, llcx); - if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() { + if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer + && !rust_type.child.0.is_empty() + { process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx); } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 581f20e8492d9..7ca2355947ad4 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { /// Generate TypeTree for a specific type. /// This function analyzes a Rust type and creates appropriate TypeTree metadata. pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + let mut visited = Vec::new(); + typetree_from_ty_inner(tcx, ty, 0, &mut visited) +} + +/// Internal recursive function for TypeTree generation with cycle detection and depth limiting. +fn typetree_from_ty_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + #[cfg(llvm_enzyme)] + { + unsafe extern "C" { + fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint; + } + let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize; + if depth > max_depth { + return TypeTree::new(); + } + } + + #[cfg(not(llvm_enzyme))] + if depth > 6 { + return TypeTree::new(); + } + + if visited.contains(&ty) { + return TypeTree::new(); + } + + visited.push(ty); + let result = typetree_from_ty_impl(tcx, ty, depth, visited); + visited.pop(); + result +} + +/// Implementation of TypeTree generation logic. +fn typetree_from_ty_impl<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + typetree_from_ty_impl_inner(tcx, ty, depth, visited, false) +} + +/// Internal implementation with context about whether this is for a reference target. +fn typetree_from_ty_impl_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, + is_reference_target: bool, +) -> TypeTree { if ty.is_scalar() { let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) @@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { (Kind::Integer, 0) }; - return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]); + // Use offset 0 for scalars that are direct targets of references (like &f64) + // Use offset -1 for scalars used directly (like function return types) + let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; + return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); } if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { @@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { return TypeTree::new(); }; - let child = typetree_from_ty(tcx, inner_ty); + let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); return TypeTree(vec![Type { offset: -1, size: tcx.data_layout.pointer_size().bytes_usize(), @@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { if len == 0 { return TypeTree::new(); } - - let element_tree = typetree_from_ty(tcx, *element_ty); - + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); let mut types = Vec::new(); for elem_type in &element_tree.0 { types.push(Type { @@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { if ty.is_slice() { if let ty::Slice(element_ty) = ty.kind() { - let element_tree = typetree_from_ty(tcx, *element_ty); + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); return element_tree; } } @@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let mut current_offset = 0; for tuple_ty in tuple_types.iter() { - let element_tree = typetree_from_ty(tcx, tuple_ty); + let element_tree = + typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false); let element_layout = tcx .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) @@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { for (field_idx, field_def) in adt_def.all_fields().enumerate() { let field_ty = field_def.ty(tcx, args); - let field_tree = typetree_from_ty(tcx, field_ty); + let field_tree = + typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false); let field_offset = layout.fields.offset(field_idx).bytes_usize(); diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check index 3c02003c88287..0b4c91191798b 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -1,4 +1,4 @@ // Check that enzyme_type attributes are present when TypeTree is enabled // This verifies our TypeTree metadata attachment is working -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check b/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check new file mode 100644 index 0000000000000..1960e7b816c5f --- /dev/null +++ b/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check @@ -0,0 +1,3 @@ +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}" +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs new file mode 100644 index 0000000000000..78718f3a21595 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs b/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs new file mode 100644 index 0000000000000..9d40bec1bf1d1 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs @@ -0,0 +1,100 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +// Self-referential struct to test recursion detection +#[derive(Clone)] +struct Node { + value: f64, + next: Option>, +} + +// Mutually recursive structs to test cycle detection +#[derive(Clone)] +struct GraphNodeA { + value: f64, + connections: Vec, +} + +#[derive(Clone)] +struct GraphNodeB { + weight: f64, + target: Option>, +} + +#[autodiff_reverse(d_test_node, Duplicated, Active)] +#[no_mangle] +fn test_node(node: &Node) -> f64 { + node.value * 2.0 +} + +#[autodiff_reverse(d_test_graph, Duplicated, Active)] +#[no_mangle] +fn test_graph(a: &GraphNodeA) -> f64 { + a.value * 3.0 +} + +// Simple depth test - deeply nested but not circular +#[derive(Clone)] +struct Level1 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level2 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level3 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level4 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level5 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level6 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level7 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level8 { + val: f64, +} + +#[autodiff_reverse(d_test_deep, Duplicated, Active)] +#[no_mangle] +fn test_deep(deep: &Level1) -> f64 { + deep.val * 4.0 +} + +fn main() { + let node = Node { value: 1.0, next: None }; + + let graph = GraphNodeA { value: 2.0, connections: vec![] }; + + let deep = Level1 { val: 5.0, next: None }; + + let mut d_node = Node { value: 0.0, next: None }; + + let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] }; + + let mut d_deep = Level1 { val: 0.0, next: None }; + + let _result1 = d_test_node(&node, &mut d_node, 1.0); + let _result2 = d_test_graph(&graph, &mut d_graph, 1.0); + let _result3 = d_test_deep(&deep, &mut d_deep, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check index 733e46aa45a57..23db64eea52aa 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check @@ -1,4 +1,4 @@ ; Check that f128 TypeTree metadata is correctly generated ; Should show Float@fp128 for f128 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check index 9caca26b5cf51..9adff68d36f34 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check @@ -1,4 +1,4 @@ ; Check that f16 TypeTree metadata is correctly generated ; Should show Float@half for f16 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check index ec12ba6b2348d..176630f57e8f7 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check @@ -1,4 +1,4 @@ ; Check that f32 TypeTree metadata is correctly generated ; Should show Float@float for f32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check index f1af270824d57..929cd379694ac 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check @@ -1,4 +1,4 @@ ; Check that f64 TypeTree metadata is correctly generated ; Should show Float@double for f64 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check index fd9a94be8100b..dee4aa5bbb6b2 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check @@ -1,4 +1,4 @@ ; Check that i32 TypeTree metadata is correctly generated ; Should show Integer for i32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}" \ No newline at end of file From 3ba5f19182bf7144c54cbbd0b7af3d4fe76b5317 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Fri, 12 Sep 2025 06:11:18 +0000 Subject: [PATCH 16/22] autodiff: typetree recursive depth query from enzyme with fallback Signed-off-by: Karan Janthe --- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 - compiler/rustc_codegen_llvm/src/typetree.rs | 10 ++++----- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 12 +++++++++++ compiler/rustc_middle/src/ty/mod.rs | 21 +++++++------------ src/llvm-project | 2 +- src/tools/cargo | 2 +- 6 files changed, 26 insertions(+), 22 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index b604f5139c8c1..e63043b21227f 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD { ); pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); - pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint; } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 1a54884f6c5a4..7e2635037008e 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,5 +1,5 @@ use rustc_ast::expand::typetree::FncTree; -#[cfg(llvm_enzyme)] +#[cfg(feature = "llvm_enzyme")] use { crate::attributes, rustc_ast::expand::typetree::TypeTree as RustTypeTree, @@ -8,7 +8,7 @@ use { use crate::llvm::{self, Value}; -#[cfg(llvm_enzyme)] +#[cfg(feature = "llvm_enzyme")] fn to_enzyme_typetree( rust_typetree: RustTypeTree, _data_layout: &str, @@ -18,7 +18,7 @@ fn to_enzyme_typetree( process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); enzyme_tt } -#[cfg(llvm_enzyme)] +#[cfg(feature = "llvm_enzyme")] fn process_typetree_recursive( enzyme_tt: &mut llvm::TypeTree, rust_typetree: &RustTypeTree, @@ -56,7 +56,7 @@ fn process_typetree_recursive( } } -#[cfg(llvm_enzyme)] +#[cfg(feature = "llvm_enzyme")] pub(crate) fn add_tt<'ll>( llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, @@ -111,7 +111,7 @@ pub(crate) fn add_tt<'ll>( } } -#[cfg(not(llvm_enzyme))] +#[cfg(not(feature = "llvm_enzyme"))] pub(crate) fn add_tt<'ll>( _llmod: &'ll llvm::Module, _llcx: &'ll llvm::Context, diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 64151962321fd..c1a924a87e414 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) { MD.NoHWAddress = true; GV.setSanitizerMetadata(MD); } + +#ifdef ENZYME +extern "C" { +extern llvm::cl::opt EnzymeMaxTypeDepth; +} + +extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; } +#else +extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { + return 6; // Default fallback depth +} +#endif diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 7ca2355947ad4..ce4de6b95e0bb 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -63,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind; pub use rustc_type_ir::*; #[allow(hidden_glob_reexports, unused_imports)] use rustc_type_ir::{InferCtxtLike, Interner}; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, trace}; pub use vtable::*; use {rustc_ast as ast, rustc_hir as hir}; @@ -2256,6 +2256,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { typetree_from_ty_inner(tcx, ty, 0, &mut visited) } +/// Maximum recursion depth for TypeTree generation to prevent stack overflow +/// from pathological deeply nested types. Combined with cycle detection. +const MAX_TYPETREE_DEPTH: usize = 6; + /// Internal recursive function for TypeTree generation with cycle detection and depth limiting. fn typetree_from_ty_inner<'tcx>( tcx: TyCtxt<'tcx>, @@ -2263,19 +2267,8 @@ fn typetree_from_ty_inner<'tcx>( depth: usize, visited: &mut Vec>, ) -> TypeTree { - #[cfg(llvm_enzyme)] - { - unsafe extern "C" { - fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint; - } - let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize; - if depth > max_depth { - return TypeTree::new(); - } - } - - #[cfg(not(llvm_enzyme))] - if depth > 6 { + if depth >= MAX_TYPETREE_DEPTH { + trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty); return TypeTree::new(); } diff --git a/src/llvm-project b/src/llvm-project index 2a22c8014390b..333793696b0a0 160000 --- a/src/llvm-project +++ b/src/llvm-project @@ -1 +1 @@ -Subproject commit 2a22c8014390bdf387d8ed26005fee3187fb98bd +Subproject commit 333793696b0a08002acac6af983adc7a5233fc56 diff --git a/src/tools/cargo b/src/tools/cargo index a4bd03c92d3b9..966f94733bbc9 160000 --- a/src/tools/cargo +++ b/src/tools/cargo @@ -1 +1 @@ -Subproject commit a4bd03c92d3b977909953d06cc45e263bd7ca7d8 +Subproject commit 966f94733bbc94ca51ff9f1e4c49ad250ebbdc50 From 97b2292b065f825d482899149001a74b6a5450c4 Mon Sep 17 00:00:00 2001 From: Jules Bertholet Date: Wed, 17 Sep 2025 10:26:52 -0400 Subject: [PATCH 17/22] Allow shared access to `Exclusive` when `T: Sync` --- library/core/src/sync/exclusive.rs | 114 ++++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/library/core/src/sync/exclusive.rs b/library/core/src/sync/exclusive.rs index cf086bf4f5080..f181c5514f256 100644 --- a/library/core/src/sync/exclusive.rs +++ b/library/core/src/sync/exclusive.rs @@ -1,28 +1,32 @@ //! Defines [`Exclusive`]. +use core::cmp::Ordering; use core::fmt; use core::future::Future; -use core::marker::Tuple; +use core::hash::{Hash, Hasher}; +use core::marker::{StructuralPartialEq, Tuple}; use core::ops::{Coroutine, CoroutineState}; use core::pin::Pin; use core::task::{Context, Poll}; -/// `Exclusive` provides only _mutable_ access, also referred to as _exclusive_ -/// access to the underlying value. It provides no _immutable_, or _shared_ -/// access to the underlying value. +/// `Exclusive` provides _mutable_ access, also referred to as _exclusive_ +/// access to the underlying value. However, it only permits _immutable_, or _shared_ +/// access to the underlying value when that value is [`Sync`]. /// /// While this may seem not very useful, it allows `Exclusive` to _unconditionally_ -/// implement [`Sync`]. Indeed, the safety requirements of `Sync` state that for `Exclusive` +/// implement `Sync`. Indeed, the safety requirements of `Sync` state that for `Exclusive` /// to be `Sync`, it must be sound to _share_ across threads, that is, it must be sound -/// for `&Exclusive` to cross thread boundaries. By design, a `&Exclusive` has no API -/// whatsoever, making it useless, thus harmless, thus memory safe. +/// for `&Exclusive` to cross thread boundaries. By design, a `&Exclusive` for non-`Sync` T +/// has no API whatsoever, making it useless, thus harmless, thus memory safe. /// /// Certain constructs like [`Future`]s can only be used with _exclusive_ access, /// and are often `Send` but not `Sync`, so `Exclusive` can be used as hint to the /// Rust compiler that something is `Sync` in practice. /// /// ## Examples -/// Using a non-`Sync` future prevents the wrapping struct from being `Sync` +/// +/// Using a non-`Sync` future prevents the wrapping struct from being `Sync`: +/// /// ```compile_fail /// use core::cell::Cell; /// @@ -43,7 +47,8 @@ use core::task::{Context, Poll}; /// ``` /// /// `Exclusive` ensures the struct is `Sync` without stripping the future of its -/// functionality. +/// functionality: +/// /// ``` /// #![feature(exclusive_wrapper)] /// use core::cell::Cell; @@ -66,6 +71,7 @@ use core::task::{Context, Poll}; /// ``` /// /// ## Parallels with a mutex +/// /// In some sense, `Exclusive` can be thought of as a _compile-time_ version of /// a mutex, as the borrow-checker guarantees that only one `&mut` can exist /// for any value. This is a parallel with the fact that @@ -75,7 +81,7 @@ use core::task::{Context, Poll}; #[doc(alias = "SyncWrapper")] #[doc(alias = "SyncCell")] #[doc(alias = "Unique")] -// `Exclusive` can't have `PartialOrd`, `Clone`, etc. impls as they would +// `Exclusive` can't have derived `PartialOrd`, `Clone`, etc. impls as they would // use `&` access to the inner value, violating the `Sync` impl's safety // requirements. #[derive(Default)] @@ -195,6 +201,17 @@ where } } +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl Fn for Exclusive +where + F: Sync + Fn, + Args: Tuple, +{ + extern "rust-call" fn call(&self, args: Args) -> Self::Output { + self.as_ref().call(args) + } +} + #[unstable(feature = "exclusive_wrapper", issue = "98407")] impl Future for Exclusive where @@ -221,3 +238,80 @@ where G::resume(self.get_pin_mut(), arg) } } + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl AsRef for Exclusive +where + T: Sync + ?Sized, +{ + #[inline] + fn as_ref(&self) -> &T { + &self.inner + } +} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl Clone for Exclusive +where + T: Sync + Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self { inner: self.inner.clone() } + } +} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl Copy for Exclusive where T: Sync + Copy {} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl PartialEq> for Exclusive +where + T: Sync + PartialEq + ?Sized, + U: Sync + ?Sized, +{ + #[inline] + fn eq(&self, other: &Exclusive) -> bool { + self.inner == other.inner + } +} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl StructuralPartialEq for Exclusive where T: Sync + StructuralPartialEq + ?Sized {} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl Eq for Exclusive where T: Sync + Eq + ?Sized {} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl Hash for Exclusive +where + T: Sync + Hash + ?Sized, +{ + #[inline] + fn hash(&self, state: &mut H) { + Hash::hash(&self.inner, state) + } +} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl PartialOrd> for Exclusive +where + T: Sync + PartialOrd + ?Sized, + U: Sync + ?Sized, +{ + #[inline] + fn partial_cmp(&self, other: &Exclusive) -> Option { + self.inner.partial_cmp(&other.inner) + } +} + +#[unstable(feature = "exclusive_wrapper", issue = "98407")] +impl Ord for Exclusive +where + T: Sync + Ord + ?Sized, +{ + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + self.inner.cmp(&other.inner) + } +} From f1d6e000b500bd23a9f75716a161bc015ab49273 Mon Sep 17 00:00:00 2001 From: Jules Bertholet Date: Sun, 21 Sep 2025 17:26:39 -0400 Subject: [PATCH 18/22] Bless UI tests --- tests/ui/explicit-tail-calls/callee_is_weird.stderr | 2 +- tests/ui/impl-trait/where-allowed.stderr | 2 ++ tests/ui/traits/next-solver/well-formed-in-relate.stderr | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/ui/explicit-tail-calls/callee_is_weird.stderr b/tests/ui/explicit-tail-calls/callee_is_weird.stderr index a4e5a38ce3320..9a5da28b559e3 100644 --- a/tests/ui/explicit-tail-calls/callee_is_weird.stderr +++ b/tests/ui/explicit-tail-calls/callee_is_weird.stderr @@ -12,7 +12,7 @@ error: tail calls can only be performed with function definitions or pointers LL | become (&mut &std::sync::Exclusive::new(f))() | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | - = note: callee has type `Exclusive` + = note: callee has type `&Exclusive` error: tail calls can only be performed with function definitions or pointers --> $DIR/callee_is_weird.rs:22:12 diff --git a/tests/ui/impl-trait/where-allowed.stderr b/tests/ui/impl-trait/where-allowed.stderr index 08caff326c470..4d8f23bf7ca64 100644 --- a/tests/ui/impl-trait/where-allowed.stderr +++ b/tests/ui/impl-trait/where-allowed.stderr @@ -387,6 +387,8 @@ LL | fn in_impl_Fn_return_in_return() -> &'static impl Fn() -> impl Debug { pani where A: Tuple, F: Fn, F: ?Sized; - impl Fn for Box where Args: Tuple, F: Fn, A: Allocator, F: ?Sized; + - impl Fn for Exclusive + where F: Sync, F: Fn, Args: Tuple; error[E0118]: no nominal type found for inherent implementation --> $DIR/where-allowed.rs:240:1 diff --git a/tests/ui/traits/next-solver/well-formed-in-relate.stderr b/tests/ui/traits/next-solver/well-formed-in-relate.stderr index 5294a072d3123..d79e465b3e387 100644 --- a/tests/ui/traits/next-solver/well-formed-in-relate.stderr +++ b/tests/ui/traits/next-solver/well-formed-in-relate.stderr @@ -12,6 +12,8 @@ LL | x = unconstrained_map(); where A: Tuple, F: Fn, F: ?Sized; - impl Fn for Box where Args: Tuple, F: Fn, A: Allocator, F: ?Sized; + - impl Fn for Exclusive + where F: Sync, F: Fn, Args: Tuple; note: required by a bound in `unconstrained_map` --> $DIR/well-formed-in-relate.rs:21:25 | From c0de794949f652b368fa107c3b52d0515f1f3859 Mon Sep 17 00:00:00 2001 From: David Carlier Date: Tue, 29 Apr 2025 19:01:08 +0100 Subject: [PATCH 19/22] std::net: update tcp deferaccept delay type to Duration. --- library/std/src/os/net/linux_ext/tcp.rs | 23 +++++++++++-------- library/std/src/os/net/linux_ext/tests.rs | 11 +++++---- .../std/src/sys/net/connection/socket/unix.rs | 9 ++++---- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/library/std/src/os/net/linux_ext/tcp.rs b/library/std/src/os/net/linux_ext/tcp.rs index fde53ec4257bd..dbefc91a979a5 100644 --- a/library/std/src/os/net/linux_ext/tcp.rs +++ b/library/std/src/os/net/linux_ext/tcp.rs @@ -4,6 +4,7 @@ use crate::sealed::Sealed; use crate::sys_common::AsInner; +use crate::time::Duration; use crate::{io, net}; /// Os-specific extensions for [`TcpStream`] @@ -59,11 +60,13 @@ pub trait TcpStreamExt: Sealed { /// A socket listener will be awakened solely when data arrives. /// - /// The `accept` argument set the delay in seconds until the + /// The `accept` argument set the maximum delay until the /// data is available to read, reducing the number of short lived /// connections without data to process. /// Contrary to other platforms `SO_ACCEPTFILTER` feature equivalent, there is /// no necessity to set it after the `listen` call. + /// Note that the delay is expressed as Duration from user's perspective + /// the call rounds it down to the nearest second expressible as a `c_int`. /// /// See [`man 7 tcp`](https://man7.org/linux/man-pages/man7/tcp.7.html) /// @@ -73,16 +76,17 @@ pub trait TcpStreamExt: Sealed { /// #![feature(tcp_deferaccept)] /// use std::net::TcpStream; /// use std::os::linux::net::TcpStreamExt; + /// use std::time::Duration; /// /// let stream = TcpStream::connect("127.0.0.1:8080") /// .expect("Couldn't connect to the server..."); - /// stream.set_deferaccept(1).expect("set_deferaccept call failed"); + /// stream.set_deferaccept(Duration::from_secs(1u64)).expect("set_deferaccept call failed"); /// ``` #[unstable(feature = "tcp_deferaccept", issue = "119639")] #[cfg(target_os = "linux")] - fn set_deferaccept(&self, accept: u32) -> io::Result<()>; + fn set_deferaccept(&self, accept: Duration) -> io::Result<()>; - /// Gets the accept delay value (in seconds) of the `TCP_DEFER_ACCEPT` option. + /// Gets the accept delay value of the `TCP_DEFER_ACCEPT` option. /// /// For more information about this option, see [`TcpStreamExt::set_deferaccept`]. /// @@ -92,15 +96,16 @@ pub trait TcpStreamExt: Sealed { /// #![feature(tcp_deferaccept)] /// use std::net::TcpStream; /// use std::os::linux::net::TcpStreamExt; + /// use std::time::Duration; /// /// let stream = TcpStream::connect("127.0.0.1:8080") /// .expect("Couldn't connect to the server..."); - /// stream.set_deferaccept(1).expect("set_deferaccept call failed"); - /// assert_eq!(stream.deferaccept().unwrap_or(0), 1); + /// stream.set_deferaccept(Duration::from_secs(1u64)).expect("set_deferaccept call failed"); + /// assert_eq!(stream.deferaccept().unwrap(), Duration::from_secs(1u64)); /// ``` #[unstable(feature = "tcp_deferaccept", issue = "119639")] #[cfg(target_os = "linux")] - fn deferaccept(&self) -> io::Result; + fn deferaccept(&self) -> io::Result; } #[stable(feature = "tcp_quickack", since = "1.89.0")] @@ -117,12 +122,12 @@ impl TcpStreamExt for net::TcpStream { } #[cfg(target_os = "linux")] - fn set_deferaccept(&self, accept: u32) -> io::Result<()> { + fn set_deferaccept(&self, accept: Duration) -> io::Result<()> { self.as_inner().as_inner().set_deferaccept(accept) } #[cfg(target_os = "linux")] - fn deferaccept(&self) -> io::Result { + fn deferaccept(&self) -> io::Result { self.as_inner().as_inner().deferaccept() } } diff --git a/library/std/src/os/net/linux_ext/tests.rs b/library/std/src/os/net/linux_ext/tests.rs index 12f35696abc5c..0758b426ccc59 100644 --- a/library/std/src/os/net/linux_ext/tests.rs +++ b/library/std/src/os/net/linux_ext/tests.rs @@ -32,6 +32,7 @@ fn deferaccept() { use crate::net::test::next_test_ip4; use crate::net::{TcpListener, TcpStream}; use crate::os::net::linux_ext::tcp::TcpStreamExt; + use crate::time::Duration; macro_rules! t { ($e:expr) => { @@ -43,10 +44,12 @@ fn deferaccept() { } let addr = next_test_ip4(); + let one = Duration::from_secs(1u64); + let zero = Duration::from_secs(0u64); let _listener = t!(TcpListener::bind(&addr)); let stream = t!(TcpStream::connect(&("localhost", addr.port()))); - stream.set_deferaccept(1).expect("set_deferaccept failed"); - assert_eq!(stream.deferaccept().unwrap(), 1); - stream.set_deferaccept(0).expect("set_deferaccept failed"); - assert_eq!(stream.deferaccept().unwrap(), 0); + stream.set_deferaccept(one).expect("set_deferaccept failed"); + assert_eq!(stream.deferaccept().unwrap(), one); + stream.set_deferaccept(zero).expect("set_deferaccept failed"); + assert_eq!(stream.deferaccept().unwrap(), zero); } diff --git a/library/std/src/sys/net/connection/socket/unix.rs b/library/std/src/sys/net/connection/socket/unix.rs index 8b5970d1494e5..86b966242bf38 100644 --- a/library/std/src/sys/net/connection/socket/unix.rs +++ b/library/std/src/sys/net/connection/socket/unix.rs @@ -485,14 +485,15 @@ impl Socket { // bionic libc makes no use of this flag #[cfg(target_os = "linux")] - pub fn set_deferaccept(&self, accept: u32) -> io::Result<()> { - setsockopt(self, libc::IPPROTO_TCP, libc::TCP_DEFER_ACCEPT, accept as c_int) + pub fn set_deferaccept(&self, accept: Duration) -> io::Result<()> { + let val = cmp::min(accept.as_secs(), c_int::MAX as u64) as c_int; + setsockopt(self, libc::IPPROTO_TCP, libc::TCP_DEFER_ACCEPT, val) } #[cfg(target_os = "linux")] - pub fn deferaccept(&self) -> io::Result { + pub fn deferaccept(&self) -> io::Result { let raw: c_int = getsockopt(self, libc::IPPROTO_TCP, libc::TCP_DEFER_ACCEPT)?; - Ok(raw as u32) + Ok(Duration::from_secs(raw as _)) } #[cfg(any(target_os = "freebsd", target_os = "netbsd"))] From 19d0e728496afa1ba6a5f9817201b7d57337c14c Mon Sep 17 00:00:00 2001 From: David Carlier Date: Sat, 27 Sep 2025 20:38:45 +0100 Subject: [PATCH 20/22] fix build for android --- library/std/src/os/net/linux_ext/tcp.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/library/std/src/os/net/linux_ext/tcp.rs b/library/std/src/os/net/linux_ext/tcp.rs index dbefc91a979a5..3f9b2bd3f4b43 100644 --- a/library/std/src/os/net/linux_ext/tcp.rs +++ b/library/std/src/os/net/linux_ext/tcp.rs @@ -4,6 +4,7 @@ use crate::sealed::Sealed; use crate::sys_common::AsInner; +#[cfg(target_os = "linux")] use crate::time::Duration; use crate::{io, net}; From c1259aa26fead9a9d365c4436d5ceb00cad88bbe Mon Sep 17 00:00:00 2001 From: WANG Rui Date: Thu, 28 Aug 2025 23:31:46 +0800 Subject: [PATCH 21/22] Add LSX accelerated implementation for source file analysis This patch introduces an LSX-optimized version of `analyze_source_file` for the `loongarch64` target. Similar to existing SSE2 implementation for x86, this version: - Processes 16-byte chunks at a time using LSX vector intrinsics. - Quickly identifies newlines in ASCII-only chunks. - Falls back to the generic implementation when multi-byte UTF-8 characters are detected or in the tail portion. --- .../rustc_span/src/analyze_source_file.rs | 109 +++++++++++++++++- compiler/rustc_span/src/lib.rs | 1 + 2 files changed, 107 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_span/src/analyze_source_file.rs b/compiler/rustc_span/src/analyze_source_file.rs index c32593a6d95ab..bb2cda77dffff 100644 --- a/compiler/rustc_span/src/analyze_source_file.rs +++ b/compiler/rustc_span/src/analyze_source_file.rs @@ -81,8 +81,8 @@ cfg_select! { // use `loadu`, which supports unaligned loading. let chunk = unsafe { _mm_loadu_si128(chunk.as_ptr() as *const __m128i) }; - // For character in the chunk, see if its byte value is < 0, which - // indicates that it's part of a UTF-8 char. + // For each character in the chunk, see if its byte value is < 0, + // which indicates that it's part of a UTF-8 char. let multibyte_test = _mm_cmplt_epi8(chunk, _mm_set1_epi8(0)); // Create a bit mask from the comparison results. let multibyte_mask = _mm_movemask_epi8(multibyte_test); @@ -132,8 +132,111 @@ cfg_select! { } } } + target_arch = "loongarch64" => { + fn analyze_source_file_dispatch( + src: &str, + lines: &mut Vec, + multi_byte_chars: &mut Vec, + ) { + use std::arch::is_loongarch_feature_detected; + + if is_loongarch_feature_detected!("lsx") { + unsafe { + analyze_source_file_lsx(src, lines, multi_byte_chars); + } + } else { + analyze_source_file_generic( + src, + src.len(), + RelativeBytePos::from_u32(0), + lines, + multi_byte_chars, + ); + } + } + + /// Checks 16 byte chunks of text at a time. If the chunk contains + /// something other than printable ASCII characters and newlines, the + /// function falls back to the generic implementation. Otherwise it uses + /// LSX intrinsics to quickly find all newlines. + #[target_feature(enable = "lsx")] + unsafe fn analyze_source_file_lsx( + src: &str, + lines: &mut Vec, + multi_byte_chars: &mut Vec, + ) { + use std::arch::loongarch64::*; + + const CHUNK_SIZE: usize = 16; + + let (chunks, tail) = src.as_bytes().as_chunks::(); + + // This variable keeps track of where we should start decoding a + // chunk. If a multi-byte character spans across chunk boundaries, + // we need to skip that part in the next chunk because we already + // handled it. + let mut intra_chunk_offset = 0; + + for (chunk_index, chunk) in chunks.iter().enumerate() { + // All LSX memory instructions support unaligned access, so using + // vld is fine. + let chunk = unsafe { lsx_vld::<0>(chunk.as_ptr() as *const i8) }; + + // For each character in the chunk, see if its byte value is < 0, + // which indicates that it's part of a UTF-8 char. + let multibyte_mask = lsx_vmskltz_b(chunk); + // Create a bit mask from the comparison results. + let multibyte_mask = lsx_vpickve2gr_w::<0>(multibyte_mask); + + // If the bit mask is all zero, we only have ASCII chars here: + if multibyte_mask == 0 { + assert!(intra_chunk_offset == 0); + + // Check for newlines in the chunk + let newlines_test = lsx_vseqi_b::<{b'\n' as i32}>(chunk); + let newlines_mask = lsx_vmskltz_b(newlines_test); + let mut newlines_mask = lsx_vpickve2gr_w::<0>(newlines_mask); + + let output_offset = RelativeBytePos::from_usize(chunk_index * CHUNK_SIZE + 1); + + while newlines_mask != 0 { + let index = newlines_mask.trailing_zeros(); + + lines.push(RelativeBytePos(index) + output_offset); + + // Clear the bit, so we can find the next one. + newlines_mask &= newlines_mask - 1; + } + } else { + // The slow path. + // There are multibyte chars in here, fallback to generic decoding. + let scan_start = chunk_index * CHUNK_SIZE + intra_chunk_offset; + intra_chunk_offset = analyze_source_file_generic( + &src[scan_start..], + CHUNK_SIZE - intra_chunk_offset, + RelativeBytePos::from_usize(scan_start), + lines, + multi_byte_chars, + ); + } + } + + // There might still be a tail left to analyze + let tail_start = src.len() - tail.len() + intra_chunk_offset; + if tail_start < src.len() { + analyze_source_file_generic( + &src[tail_start..], + src.len() - tail_start, + RelativeBytePos::from_usize(tail_start), + lines, + multi_byte_chars, + ); + } + } + } _ => { - // The target (or compiler version) does not support SSE2 ... + // The target (or compiler version) does not support vector instructions + // our specialized implementations need (x86 SSE2, loongarch64 LSX)... fn analyze_source_file_dispatch( src: &str, lines: &mut Vec, diff --git a/compiler/rustc_span/src/lib.rs b/compiler/rustc_span/src/lib.rs index 35dbbe58db9a4..ededbea57e966 100644 --- a/compiler/rustc_span/src/lib.rs +++ b/compiler/rustc_span/src/lib.rs @@ -17,6 +17,7 @@ // tidy-alphabetical-start #![allow(internal_features)] +#![cfg_attr(target_arch = "loongarch64", feature(stdarch_loongarch))] #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")] #![doc(rust_logo)] #![feature(array_windows)] From 5e9cab39213b892fdb5eb24d2d27f1ab0918ce46 Mon Sep 17 00:00:00 2001 From: Shunpoco Date: Sun, 28 Sep 2025 09:46:20 +0100 Subject: [PATCH 22/22] modify ensure_version_or_cargo_install to check existing binary Current implementation uses bin_name to check if it exists, but it should use tool_root_dir/tool_bin_dir/bin_name instead. Otherwise the check fails every time, hence the function falls back to install the binary. --- src/tools/tidy/src/lib.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/tools/tidy/src/lib.rs b/src/tools/tidy/src/lib.rs index 0bfee93796be1..874a758bd9b38 100644 --- a/src/tools/tidy/src/lib.rs +++ b/src/tools/tidy/src/lib.rs @@ -167,12 +167,16 @@ pub fn ensure_version_or_cargo_install( bin_name: &str, version: &str, ) -> io::Result { + let tool_root_dir = build_dir.join("misc-tools"); + let tool_bin_dir = tool_root_dir.join("bin"); + let bin_path = tool_bin_dir.join(bin_name).with_extension(env::consts::EXE_EXTENSION); + // ignore the process exit code here and instead just let the version number check fail. // we also importantly don't return if the program wasn't installed, // instead we want to continue to the fallback. 'ck: { // FIXME: rewrite as if-let chain once this crate is 2024 edition. - let Ok(output) = Command::new(bin_name).arg("--version").output() else { + let Ok(output) = Command::new(&bin_path).arg("--version").output() else { break 'ck; }; let Ok(s) = str::from_utf8(&output.stdout) else { @@ -182,12 +186,10 @@ pub fn ensure_version_or_cargo_install( break 'ck; }; if v == version { - return Ok(PathBuf::from(bin_name)); + return Ok(bin_path); } } - let tool_root_dir = build_dir.join("misc-tools"); - let tool_bin_dir = tool_root_dir.join("bin"); eprintln!("building external tool {bin_name} from package {pkg_name}@{version}"); // use --force to ensure that if the required version is bumped, we update it. // use --target-dir to ensure we have a build cache so repeated invocations aren't slow. @@ -213,7 +215,6 @@ pub fn ensure_version_or_cargo_install( if !cargo_exit_code.success() { return Err(io::Error::other("cargo install failed")); } - let bin_path = tool_bin_dir.join(bin_name).with_extension(env::consts::EXE_EXTENSION); assert!( matches!(bin_path.try_exists(), Ok(true)), "cargo install did not produce the expected binary"