Skip to content

Commit 574f0b9

Browse files
committed
autodiff: struct support in typetree
1 parent 7c5fbfb commit 574f0b9

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23702370
return TypeTree(types);
23712371
}
23722372

2373+
if let ty::Adt(adt_def, args) = ty.kind() {
2374+
if adt_def.is_struct() {
2375+
let struct_layout =
2376+
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
2377+
if let Ok(layout) = struct_layout {
2378+
let mut types = Vec::new();
2379+
2380+
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
2381+
let field_ty = field_def.ty(tcx, args);
2382+
let field_tree = typetree_from_ty(tcx, field_ty);
2383+
2384+
let field_offset = layout.fields.offset(field_idx).bytes_usize();
2385+
2386+
for elem_type in &field_tree.0 {
2387+
types.push(Type {
2388+
offset: if elem_type.offset == -1 {
2389+
field_offset as isize
2390+
} else {
2391+
field_offset as isize + elem_type.offset
2392+
},
2393+
size: elem_type.size,
2394+
kind: elem_type.kind,
2395+
child: elem_type.child.clone(),
2396+
});
2397+
}
2398+
}
2399+
2400+
return TypeTree(types);
2401+
}
2402+
}
2403+
}
2404+
23732405
TypeTree::new()
23742406
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
; Check that struct TypeTree metadata is correctly generated
2+
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[repr(C)]
6+
struct Point {
7+
x: f64,
8+
y: f64,
9+
z: f64,
10+
}
11+
12+
#[autodiff_reverse(d_test, Duplicated, Active)]
13+
#[no_mangle]
14+
fn test_struct(point: &Point) -> f64 {
15+
point.x + point.y * 2.0 + point.z * 3.0
16+
}
17+
18+
fn main() {
19+
let point = Point { x: 1.0, y: 2.0, z: 3.0 };
20+
let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
21+
let _result = d_test(&point, &mut d_point, 1.0);
22+
}

0 commit comments

Comments
 (0)