Skip to content

Commit 83aa06f

Browse files
committed
fixing typetree metadata issue
Signed-off-by: Karan Janthe <[email protected]>
1 parent 461f7c2 commit 83aa06f

File tree

4 files changed

+40
-14
lines changed

4 files changed

+40
-14
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,9 @@ pub(crate) fn differentiate<'ll>(
504504
},
505505
));
506506
};
507+
// Attach TypeTree metadata to the source function before calling Enzyme
508+
let fnc_tree = FncTree { args: item.inputs.clone(), ret: item.output.clone() };
509+
add_tt(cx.llmod, cx.llcx, fn_def, fnc_tree);
507510

508511
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
509512
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; Check that enzyme_type attributes are present in the LLVM IR function definition
2+
; This verifies our TypeTree system correctly attaches metadata for Enzyme
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
5+
6+
; Check that the differentiated function also has proper enzyme_type attributes
7+
8+
CHECK: @diffetest_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"{{.*}}"enzyme_type"="{[]:Pointer}"

tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ fn test_memcpy(input: &[f64; 8]) -> f64 {
1212
ptr::copy_nonoverlapping(input.as_ptr(), local_data.as_mut_ptr(), 8);
1313
}
1414

15-
let mut result = 0.0;
16-
for i in 0..8 {
17-
result += local_data[i] * local_data[i];
18-
}
19-
20-
result
15+
// type tree does not support loops
16+
local_data[0] * local_data[0]
17+
+ local_data[1] * local_data[1]
18+
+ local_data[2] * local_data[2]
19+
+ local_data[3] * local_data[3]
20+
+ local_data[4] * local_data[4]
21+
+ local_data[5] * local_data[5]
22+
+ local_data[6] * local_data[6]
23+
+ local_data[7] * local_data[7]
2124
}
2225

2326
fn main() {
Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
//@ needs-enzyme
22
//@ ignore-cross-compile
33

4-
use std::fs;
5-
64
use run_make_support::{llvm_filecheck, rfs, rustc};
75

86
fn main() {
9-
// Compile the Rust file with the required flags, capturing both stdout and stderr
7+
// First, compile to LLVM IR to check for enzyme_type attributes
8+
let _ir_output = rustc()
9+
.input("memcpy.rs")
10+
.arg("-Zautodiff=Enable")
11+
.arg("-Zautodiff=NoPostopt")
12+
.opt_level("3")
13+
.arg("-Clto=fat")
14+
.arg("--emit=llvm-ir")
15+
.arg("-o")
16+
.arg("main.ll")
17+
.run();
18+
19+
// Then compile with TypeTree analysis output for the existing checks
1020
let output = rustc()
1121
.input("memcpy.rs")
1222
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
@@ -18,11 +28,13 @@ fn main() {
1828

1929
let stdout = output.stdout_utf8();
2030
let stderr = output.stderr_utf8();
31+
let ir_content = rfs::read_to_string("main.ll");
32+
33+
rfs::write("memcpy.stdout", &stdout);
34+
rfs::write("memcpy.stderr", &stderr);
35+
rfs::write("main.ir", &ir_content);
2136

22-
// Write the outputs to files
23-
rfs::write("memcpy.stdout", stdout);
24-
rfs::write("memcpy.stderr", stderr);
37+
llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
2538

26-
// Run FileCheck on the stdout using the check file
27-
llvm_filecheck().patterns("memcpy.check").stdin_buf(rfs::read("memcpy.stdout")).run();
39+
llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
2840
}

0 commit comments

Comments
 (0)