Skip to content

Commit eb518c9

Browse files
committed
wip, fixing ci
1 parent b045bbd commit eb518c9

File tree

6 files changed

+187
-16
lines changed

6 files changed

+187
-16
lines changed

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,27 +1384,19 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13841384
_src_align: Align,
13851385
size: RValue<'gcc>,
13861386
flags: MemFlags,
1387-
tt: Option<FncTree>,
1387+
_tt: Option<FncTree>,
13881388
) {
13891389
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13901390
let size = self.intcast(size, self.type_size_t(), false);
13911391
let _is_volatile = flags.contains(MemFlags::VOLATILE);
13921392
let dst = self.pointercast(dst, self.type_i8p());
13931393
let src = self.pointercast(src, self.type_ptr_to(self.type_void()));
13941394
let memcpy = self.context.get_builtin_function("memcpy");
1395-
1396-
// Create the memcpy call
1397-
let call = self.context.new_call(self.location, memcpy, &[dst, src, size]);
1398-
1399-
// TypeTree metadata for memcpy: when Enzyme encounters a memcpy during autodiff,
1400-
if let Some(_tt) = tt {
1401-
// TODO(KMJ-007): implement TypeTree support for gcc backend
1402-
// For now, we just ignore the TypeTree since gcc backend doesn't support autodiff yet
1403-
// When autodiff support is added to gcc backend, this should attach TypeTree information
1404-
// as function attributes similar to how LLVM backend does it.
1405-
}
14061395
// TODO(antoyo): handle aligns and is_volatile.
1407-
self.block.add_eval(self.location, call);
1396+
self.block.add_eval(
1397+
self.location,
1398+
self.context.new_call(self.location, memcpy, &[dst, src, size]),
1399+
);
14081400
}
14091401

14101402
fn memmove(

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
11
#![expect(dead_code)]
22

3-
use libc::{c_char, c_uint};
3+
use libc::{c_char, c_uint, size_t};
44

55
use super::MetadataKindId;
6-
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
6+
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
77
use crate::llvm::{Bool, Builder};
88

9+
// TypeTree types
10+
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
11+
12+
#[repr(C)]
13+
#[derive(Debug, Copy, Clone)]
14+
pub(crate) struct EnzymeTypeTree {
15+
_unused: [u8; 0],
16+
}
17+
18+
#[repr(u32)]
19+
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
20+
pub(crate) enum CConcreteType {
21+
DT_Anything = 0,
22+
DT_Integer = 1,
23+
DT_Pointer = 2,
24+
DT_Half = 3,
25+
DT_Float = 4,
26+
DT_Double = 5,
27+
DT_Unknown = 6,
28+
}
29+
930
#[link(name = "llvm-wrapper", kind = "static")]
1031
unsafe extern "C" {
1132
// Enzyme
@@ -66,12 +87,35 @@ pub(crate) use self::Enzyme_AD::*;
6687
pub(crate) mod Enzyme_AD {
6788
use std::ffi::{CString, c_char};
6889

69-
use libc::c_void;
90+
use libc::{c_void, size_t};
91+
92+
use super::{CConcreteType, CTypeTreeRef, Context};
7093

7194
unsafe extern "C" {
7295
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
7396
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
7497
}
98+
99+
// TypeTree functions
100+
unsafe extern "C" {
101+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
102+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
103+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
104+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
105+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
106+
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
107+
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
108+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
109+
arg1: CTypeTreeRef,
110+
data_layout: *const c_char,
111+
offset: i64,
112+
max_size: i64,
113+
add_offset: u64,
114+
);
115+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
116+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
117+
}
118+
75119
unsafe extern "C" {
76120
static mut EnzymePrintPerf: c_void;
77121
static mut EnzymePrintActivity: c_void;
@@ -140,6 +184,56 @@ pub(crate) use self::Fallback_AD::*;
140184
#[cfg(not(llvm_enzyme))]
141185
pub(crate) mod Fallback_AD {
142186
#![allow(unused_variables)]
187+
188+
use super::{CConcreteType, CTypeTreeRef, Context};
189+
use libc::c_char;
190+
191+
// TypeTree function fallbacks
192+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef {
193+
unimplemented!()
194+
}
195+
196+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
197+
unimplemented!()
198+
}
199+
200+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
201+
unimplemented!()
202+
}
203+
204+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
205+
unimplemented!()
206+
}
207+
208+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
209+
unimplemented!()
210+
}
211+
212+
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
213+
unimplemented!()
214+
}
215+
216+
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
217+
unimplemented!()
218+
}
219+
220+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
221+
arg1: CTypeTreeRef,
222+
data_layout: *const c_char,
223+
offset: i64,
224+
max_size: i64,
225+
add_offset: u64,
226+
) {
227+
unimplemented!()
228+
}
229+
230+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
231+
unimplemented!()
232+
}
233+
234+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
235+
unimplemented!()
236+
}
143237

144238
pub(crate) fn set_inline(val: bool) {
145239
unimplemented!()

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,27 +2680,36 @@ unsafe extern "C" {
26802680

26812681
// ========== ENZYME AUTODIFF FFI FUNCTIONS ==========
26822682

2683+
#[cfg(llvm_enzyme)]
26832684
// Enzyme Type Tree Functions (minimal set for TypeTree support)
26842685
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
2686+
#[cfg(llvm_enzyme)]
26852687
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
2688+
#[cfg(llvm_enzyme)]
26862689
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
2690+
#[cfg(llvm_enzyme)]
26872691
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
2692+
#[cfg(llvm_enzyme)]
26882693
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef);
2694+
#[cfg(llvm_enzyme)]
26892695
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
26902696
arg1: CTypeTreeRef,
26912697
data_layout: *const c_char,
26922698
offset: i64,
26932699
max_size: i64,
26942700
add_offset: u64,
26952701
);
2702+
#[cfg(llvm_enzyme)]
26962703
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
2704+
#[cfg(llvm_enzyme)]
26972705
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
26982706
}
26992707

27002708
// ========== ENZYME TYPES AND ENUMS ==========
27012709

27022710
// Type Tree Support for Autodiff
27032711

2712+
#[cfg(llvm_enzyme)]
27042713
#[repr(u32)]
27052714
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
27062715
pub(crate) enum CConcreteType {
@@ -2713,19 +2722,23 @@ pub(crate) enum CConcreteType {
27132722
DT_Unknown = 6,
27142723
}
27152724

2725+
#[cfg(llvm_enzyme)]
27162726
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
27172727

2728+
#[cfg(llvm_enzyme)]
27182729
#[repr(C)]
27192730
#[derive(Debug, Copy, Clone)]
27202731
pub(crate) struct EnzymeTypeTree {
27212732
_unused: [u8; 0],
27222733
}
27232734

27242735
// TypeTree wrapper for Rust-side type safety and memory management
2736+
#[cfg(llvm_enzyme)]
27252737
pub(crate) struct TypeTree {
27262738
pub(crate) inner: CTypeTreeRef,
27272739
}
27282740

2741+
#[cfg(llvm_enzyme)]
27292742
impl TypeTree {
27302743
pub(crate) fn new() -> TypeTree {
27312744
let inner = unsafe { EnzymeNewTypeTree() };
@@ -2769,13 +2782,15 @@ impl TypeTree {
27692782
}
27702783
}
27712784

2785+
#[cfg(llvm_enzyme)]
27722786
impl Clone for TypeTree {
27732787
fn clone(&self) -> Self {
27742788
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
27752789
TypeTree { inner }
27762790
}
27772791
}
27782792

2793+
#[cfg(llvm_enzyme)]
27792794
impl std::fmt::Display for TypeTree {
27802795
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27812796
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
@@ -2792,12 +2807,14 @@ impl std::fmt::Display for TypeTree {
27922807
}
27932808
}
27942809

2810+
#[cfg(llvm_enzyme)]
27952811
impl std::fmt::Debug for TypeTree {
27962812
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27972813
<Self as std::fmt::Display>::fmt(self, f)
27982814
}
27992815
}
28002816

2817+
#[cfg(llvm_enzyme)]
28012818
impl Drop for TypeTree {
28022819
fn drop(&mut self) {
28032820
unsafe { EnzymeFreeTypeTree(self.inner) }
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double, [-1,32]:Float@double, [-1,40]:Float@double, [-1,48]:Float@double, [-1,56]:Float@double}:{}
2+
3+
CHECK-DAG: call void @llvm.memcpy{{.*}}!enzyme_type
4+
5+
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
6+
7+
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
8+
9+
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
use std::ptr;
5+
6+
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
7+
#[no_mangle]
8+
fn test_memcpy(input: &[f64; 8]) -> f64 {
9+
let mut local_data = [0.0f64; 8];
10+
11+
unsafe {
12+
ptr::copy_nonoverlapping(input.as_ptr(), local_data.as_mut_ptr(), 8);
13+
}
14+
15+
let mut result = 0.0;
16+
for i in 0..8 {
17+
result += local_data[i] * local_data[i];
18+
}
19+
20+
result
21+
}
22+
23+
fn main() {
24+
let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
25+
let mut d_input = [0.0; 8];
26+
let result = test_memcpy(&input);
27+
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
28+
29+
assert_eq!(result, result_d);
30+
println!("Memcpy test passed: result = {}", result);
31+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use std::fs;
5+
6+
use run_make_support::{llvm_filecheck, rfs, rustc};
7+
8+
fn main() {
9+
// Compile the Rust file with the required flags, capturing both stdout and stderr
10+
let output = rustc()
11+
.input("memcpy.rs")
12+
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
13+
.arg("-Zautodiff=NoPostopt")
14+
.opt_level("3")
15+
.arg("-Clto=fat")
16+
.arg("-g")
17+
.run();
18+
19+
let stdout = output.stdout_utf8();
20+
let stderr = output.stderr_utf8();
21+
22+
// Write the outputs to files
23+
rfs::write("memcpy.stdout", stdout);
24+
rfs::write("memcpy.stderr", stderr);
25+
26+
// Run FileCheck on the stdout using the check file
27+
llvm_filecheck().patterns("memcpy.check").stdin_buf(rfs::read("memcpy.stdout")).run();
28+
}

0 commit comments

Comments
 (0)