Skip to content

Commit 731a98a

Browse files
committed
autodiff: struct support in typetree
1 parent b62d3d7 commit 731a98a

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,5 +2376,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23762376
return TypeTree(types);
23772377
}
23782378

2379+
if let ty::Adt(adt_def, args) = ty.kind() {
2380+
if adt_def.is_struct() {
2381+
let struct_layout =
2382+
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
2383+
if let Ok(layout) = struct_layout {
2384+
let mut types = Vec::new();
2385+
2386+
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
2387+
let field_ty = field_def.ty(tcx, args);
2388+
let field_tree = typetree_from_ty(tcx, field_ty);
2389+
2390+
let field_offset = layout.fields.offset(field_idx).bytes_usize();
2391+
2392+
for elem_type in &field_tree.0 {
2393+
types.push(Type {
2394+
offset: if elem_type.offset == -1 {
2395+
field_offset as isize
2396+
} else {
2397+
field_offset as isize + elem_type.offset
2398+
},
2399+
size: elem_type.size,
2400+
kind: elem_type.kind,
2401+
child: elem_type.child.clone(),
2402+
});
2403+
}
2404+
}
2405+
2406+
return TypeTree(types);
2407+
}
2408+
}
2409+
}
2410+
23792411
TypeTree::new()
23802412
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
; Check that struct TypeTree metadata is correctly generated
2+
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[repr(C)]
6+
struct Point {
7+
x: f64,
8+
y: f64,
9+
z: f64,
10+
}
11+
12+
#[autodiff_reverse(d_test, Duplicated, Active)]
13+
#[no_mangle]
14+
fn test_struct(point: &Point) -> f64 {
15+
point.x + point.y * 2.0 + point.z * 3.0
16+
}
17+
18+
fn main() {
19+
let point = Point { x: 1.0, y: 2.0, z: 3.0 };
20+
let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
21+
let _result = d_test(&point, &mut d_point, 1.0);
22+
}

0 commit comments

Comments
 (0)