Skip to content

Commit 7c5fbfb

Browse files
committed
autodiff: tuple support in typetree
1 parent be3617b commit 7c5fbfb

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,5 +2334,41 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23342334
}
23352335
}
23362336

2337+
if let ty::Tuple(tuple_types) = ty.kind() {
2338+
if tuple_types.is_empty() {
2339+
return TypeTree::new();
2340+
}
2341+
2342+
let mut types = Vec::new();
2343+
let mut current_offset = 0;
2344+
2345+
for tuple_ty in tuple_types.iter() {
2346+
let element_tree = typetree_from_ty(tcx, tuple_ty);
2347+
2348+
let element_layout = tcx
2349+
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
2350+
.ok()
2351+
.map(|layout| layout.size.bytes_usize())
2352+
.unwrap_or(0);
2353+
2354+
for elem_type in &element_tree.0 {
2355+
types.push(Type {
2356+
offset: if elem_type.offset == -1 {
2357+
current_offset as isize
2358+
} else {
2359+
current_offset as isize + elem_type.offset
2360+
},
2361+
size: elem_type.size,
2362+
kind: elem_type.kind,
2363+
child: elem_type.child.clone(),
2364+
});
2365+
}
2366+
2367+
current_offset += element_layout;
2368+
}
2369+
2370+
return TypeTree(types);
2371+
}
2372+
23372373
TypeTree::new()
23382374
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("tuple.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[autodiff_reverse(d_test, Duplicated, Active)]
6+
#[no_mangle]
7+
fn test_tuple(tuple: &(f64, f64, f64)) -> f64 {
8+
tuple.0 + tuple.1 * 2.0 + tuple.2 * 3.0
9+
}
10+
11+
fn main() {
12+
let tuple = (1.0, 2.0, 3.0);
13+
let mut d_tuple = (0.0, 0.0, 0.0);
14+
let _result = d_test(&tuple, &mut d_tuple, 1.0);
15+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
; Check that tuple TypeTree metadata is correctly generated
2+
; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64)
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}"

0 commit comments

Comments
 (0)