Skip to content

Commit 31541fe

Browse files
committed
autodiff: add TypeTree support for arrays
1 parent 54f9376 commit 31541fe

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2286,6 +2286,46 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22862286
}]);
22872287
}
22882288

2289-
// FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
2289+
if ty.is_array() {
2290+
if let ty::Array(element_ty, len_const) = ty.kind() {
2291+
let len = len_const.try_to_target_usize(tcx).unwrap_or(0);
2292+
if len == 0 {
2293+
return TypeTree::new();
2294+
}
2295+
2296+
let element_tree = typetree_from_ty(tcx, *element_ty);
2297+
2298+
let element_layout = tcx
2299+
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
2300+
.ok()
2301+
.map(|layout| layout.size.bytes_usize())
2302+
.unwrap_or(0);
2303+
2304+
if element_layout == 0 {
2305+
return TypeTree::new();
2306+
}
2307+
2308+
let mut types = Vec::new();
2309+
for i in 0..len {
2310+
let base_offset = (i as usize * element_layout) as isize;
2311+
2312+
for elem_type in &element_tree.0 {
2313+
types.push(Type {
2314+
offset: if elem_type.offset == -1 {
2315+
base_offset
2316+
} else {
2317+
base_offset + elem_type.offset
2318+
},
2319+
size: elem_type.size,
2320+
kind: elem_type.kind,
2321+
child: elem_type.child.clone(),
2322+
});
2323+
}
2324+
}
2325+
2326+
return TypeTree(types);
2327+
}
2328+
}
2329+
22902330
TypeTree::new()
22912331
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
; Check that array TypeTree metadata is correctly generated
2+
; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes)
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[autodiff_reverse(d_test, Duplicated, Active)]
6+
#[no_mangle]
7+
fn test_array(arr: &[f64; 5]) -> f64 {
8+
arr[0] + arr[1] + arr[2] + arr[3] + arr[4]
9+
}
10+
11+
fn main() {
12+
let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
13+
let mut d_arr = [0.0; 5];
14+
let _result = d_test(&arr, &mut d_arr, 1.0);
15+
}

0 commit comments

Comments
 (0)