Skip to content

Commit e1258e7

Browse files
committed
autodiff: Add basic TypeTree with NoTT flag
Signed-off-by: Karan Janthe <[email protected]>
1 parent 2f4dfc7 commit e1258e7

File tree

13 files changed

+212
-5
lines changed

13 files changed

+212
-5
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9+
use crate::expand::typetree::TypeTree;
910
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1011
use crate::{Ty, TyKind};
1112

@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
8485
/// The name of the function being generated
8586
pub target: String,
8687
pub attrs: AutoDiffAttrs,
88+
pub inputs: Vec<TypeTree>,
89+
pub output: TypeTree,
8790
}
8891

8992
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
275278
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
276279
}
277280

278-
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
279-
AutoDiffItem { source, target, attrs: self }
281+
pub fn into_item(
282+
self,
283+
source: String,
284+
target: String,
285+
inputs: Vec<TypeTree>,
286+
output: TypeTree,
287+
) -> AutoDiffItem {
288+
AutoDiffItem { source, target, inputs, output, attrs: self }
280289
}
281290
}
282291

283292
impl fmt::Display for AutoDiffItem {
284293
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285294
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
286-
write!(f, " with attributes: {:?}", self.attrs)
295+
write!(f, " with attributes: {:?}", self.attrs)?;
296+
write!(f, " with inputs: {:?}", self.inputs)?;
297+
write!(f, " with output: {:?}", self.output)
287298
}
288299
}

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563563
config::AutoDiff::Enable => {}
564564
// We handle this below
565565
config::AutoDiff::NoPostopt => {}
566+
// Disables TypeTree generation
567+
config::AutoDiff::NoTT => {}
566568
}
567569
}
568570
// This helps with handling enums for now.

compiler/rustc_interface/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ fn test_unstable_options_tracking_hash() {
766766
tracked!(always_encode_mir, true);
767767
tracked!(assume_incomplete_release, true);
768768
tracked!(autodiff, vec![AutoDiff::Enable]);
769+
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
769770
tracked!(binary_dep_depinfo, true);
770771
tracked!(box_noalias, false);
771772
tracked!(

compiler/rustc_middle/src/error.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> {
3737
pub sub: TypeMismatchReason,
3838
}
3939

40-
// FIXME(autodiff): I should get used somewhere
4140
#[derive(Diagnostic)]
4241
#[diag(middle_unsupported_union)]
4342
pub struct UnsupportedUnion {

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
2525
pub use generics::*;
2626
pub use intrinsic::IntrinsicDef;
2727
use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx};
28+
use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree};
2829
use rustc_ast::node_id::NodeMap;
2930
pub use rustc_ast_ir::{Movability, Mutability, try_visit};
3031
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
@@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> {
22162217
pub variant: Option<VariantIdx>,
22172218
pub fields: &'tcx [ty::Const<'tcx>],
22182219
}
2220+
2221+
/// Generate TypeTree information for autodiff.
2222+
/// This function creates TypeTree metadata that describes the memory layout
2223+
/// of function parameters and return types for Enzyme autodiff.
2224+
pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
2225+
// Check if TypeTrees are disabled via NoTT flag
2226+
if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) {
2227+
return FncTree { args: vec![], ret: TypeTree::new() };
2228+
}
2229+
2230+
// Check if this is actually a function type
2231+
if !fn_ty.is_fn() {
2232+
return FncTree { args: vec![], ret: TypeTree::new() };
2233+
}
2234+
2235+
// Get the function signature
2236+
let fn_sig = fn_ty.fn_sig(tcx);
2237+
let sig = tcx.instantiate_bound_regions_with_erased(fn_sig);
2238+
2239+
// Create TypeTrees for each input parameter
2240+
let mut args = vec![];
2241+
for ty in sig.inputs().iter() {
2242+
let type_tree = typetree_from_ty(tcx, *ty);
2243+
args.push(type_tree);
2244+
}
2245+
2246+
// Create TypeTree for return type
2247+
let ret = typetree_from_ty(tcx, sig.output());
2248+
2249+
FncTree { args, ret }
2250+
}
2251+
2252+
/// Generate TypeTree for a specific type.
2253+
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
2254+
fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
2255+
// Handle basic scalar types
2256+
if ty.is_scalar() {
2257+
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
2258+
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
2259+
} else if ty.is_floating_point() {
2260+
match ty {
2261+
x if x == tcx.types.f32 => (Kind::Float, 4),
2262+
x if x == tcx.types.f64 => (Kind::Double, 8),
2263+
_ => return TypeTree::new(), // Unknown float type
2264+
}
2265+
} else {
2266+
// TODO(KMJ-007): Handle other scalar types if needed
2267+
return TypeTree::new();
2268+
};
2269+
2270+
return TypeTree(vec![Type {
2271+
offset: -1,
2272+
size,
2273+
kind,
2274+
child: TypeTree::new()
2275+
}]);
2276+
}
2277+
2278+
// Handle references and pointers
2279+
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
2280+
let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
2281+
inner
2282+
} else {
2283+
// TODO(KMJ-007): Handle complex pointer types
2284+
return TypeTree::new();
2285+
};
2286+
2287+
let child = typetree_from_ty(tcx, inner_ty);
2288+
return TypeTree(vec![Type {
2289+
offset: -1,
2290+
size: 8, // TODO(KMJ-007): Get actual pointer size from target
2291+
kind: Kind::Pointer,
2292+
child,
2293+
}]);
2294+
}
2295+
2296+
// TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
2297+
TypeTree::new()
2298+
}

compiler/rustc_session/src/config.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ pub enum AutoDiff {
257257
LooseTypes,
258258
/// Runs Enzyme's aggressive inlining
259259
Inline,
260+
/// Disable Type Tree
261+
NoTT,
260262
}
261263

262264
/// Settings for `-Z instrument-xray` flag.

compiler/rustc_session/src/options.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ mod desc {
792792
pub(crate) const parse_list: &str = "a space-separated list of strings";
793793
pub(crate) const parse_list_with_polarity: &str =
794794
"a comma-separated list of strings, with elements beginning with + or -";
795-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
795+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`";
796796
pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
797797
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
798798
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
@@ -1479,6 +1479,7 @@ pub mod parse {
14791479
"PrintPasses" => AutoDiff::PrintPasses,
14801480
"LooseTypes" => AutoDiff::LooseTypes,
14811481
"Inline" => AutoDiff::Inline,
1482+
"NoTT" => AutoDiff::NoTT,
14821483
_ => {
14831484
// FIXME(ZuseZ4): print an error saying which value is not recognized
14841485
return false;
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// Test that basic autodiff still works with our TypeTree infrastructure
6+
#![feature(autodiff)]
7+
8+
use std::autodiff::autodiff_reverse;
9+
10+
#[autodiff_reverse(d_simple, Duplicated, Active)]
11+
#[no_mangle]
12+
#[inline(never)]
13+
fn simple(x: &f64) -> f64 {
14+
2.0 * x
15+
}
16+
17+
// CHECK-LABEL: @simple
18+
// CHECK: fmul double
19+
20+
// The derivative function should be generated normally
21+
// CHECK-LABEL: diffesimple
22+
// CHECK: fadd fast double
23+
24+
fn main() {
25+
let x = std::hint::black_box(3.0);
26+
let output = simple(&x);
27+
assert_eq!(6.0, output);
28+
29+
let mut df_dx = 0.0;
30+
let output_ = d_simple(&x, &mut df_dx, 1.0);
31+
assert_eq!(output, output_);
32+
assert_eq!(2.0, df_dx);
33+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
// TODO(KMJ-007): Update this test when TypeTree integration is complete
2+
// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{}
3+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
// Test with NoTT flag - should not generate TypeTree metadata
8+
let output_nott = rustc()
9+
.input("test.rs")
10+
.arg("-Zautodiff=Enable,NoTT,PrintTAFn=square")
11+
.arg("-Zautodiff=NoPostopt")
12+
.opt_level("3")
13+
.arg("-Clto=fat")
14+
.arg("-g")
15+
.run();
16+
17+
// Write output for NoTT case
18+
rfs::write("nott.stdout", output_nott.stdout_utf8());
19+
20+
// Test without NoTT flag - should generate TypeTree metadata
21+
let output_with_tt = rustc()
22+
.input("test.rs")
23+
.arg("-Zautodiff=Enable,PrintTAFn=square")
24+
.arg("-Zautodiff=NoPostopt")
25+
.opt_level("3")
26+
.arg("-Clto=fat")
27+
.arg("-g")
28+
.run();
29+
30+
// Write output for TypeTree case
31+
rfs::write("with_tt.stdout", output_with_tt.stdout_utf8());
32+
33+
// Verify NoTT output has minimal TypeTree info
34+
llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run();
35+
36+
// Verify normal output will have TypeTree info (once implemented)
37+
llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run();
38+
}

0 commit comments

Comments
 (0)