Skip to content

Commit 8777369

Browse files
committed
wip, fixing ci
1 parent cf93528 commit 8777369

File tree

6 files changed

+187
-16
lines changed

6 files changed

+187
-16
lines changed

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,27 +1384,19 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13841384
_src_align: Align,
13851385
size: RValue<'gcc>,
13861386
flags: MemFlags,
1387-
tt: Option<FncTree>,
1387+
_tt: Option<FncTree>,
13881388
) {
13891389
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13901390
let size = self.intcast(size, self.type_size_t(), false);
13911391
let _is_volatile = flags.contains(MemFlags::VOLATILE);
13921392
let dst = self.pointercast(dst, self.type_i8p());
13931393
let src = self.pointercast(src, self.type_ptr_to(self.type_void()));
13941394
let memcpy = self.context.get_builtin_function("memcpy");
1395-
1396-
// Create the memcpy call
1397-
let call = self.context.new_call(self.location, memcpy, &[dst, src, size]);
1398-
1399-
// TypeTree metadata for memcpy: when Enzyme encounters a memcpy during autodiff,
1400-
if let Some(_tt) = tt {
1401-
// TODO(KMJ-007): implement TypeTree support for gcc backend
1402-
// For now, we just ignore the TypeTree since gcc backend doesn't support autodiff yet
1403-
// When autodiff support is added to gcc backend, this should attach TypeTree information
1404-
// as function attributes similar to how LLVM backend does it.
1405-
}
14061395
// TODO(antoyo): handle aligns and is_volatile.
1407-
self.block.add_eval(self.location, call);
1396+
self.block.add_eval(
1397+
self.location,
1398+
self.context.new_call(self.location, memcpy, &[dst, src, size]),
1399+
);
14081400
}
14091401

14101402
fn memmove(

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
11
#![expect(dead_code)]
22

3-
use libc::{c_char, c_uint};
3+
use libc::{c_char, c_uint, size_t};
44

55
use super::MetadataKindId;
6-
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
6+
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
77
use crate::llvm::{Bool, Builder};
88

9+
// TypeTree types
10+
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
11+
12+
#[repr(C)]
13+
#[derive(Debug, Copy, Clone)]
14+
pub(crate) struct EnzymeTypeTree {
15+
_unused: [u8; 0],
16+
}
17+
18+
#[repr(u32)]
19+
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
20+
pub(crate) enum CConcreteType {
21+
DT_Anything = 0,
22+
DT_Integer = 1,
23+
DT_Pointer = 2,
24+
DT_Half = 3,
25+
DT_Float = 4,
26+
DT_Double = 5,
27+
DT_Unknown = 6,
28+
}
29+
930
#[link(name = "llvm-wrapper", kind = "static")]
1031
unsafe extern "C" {
1132
// Enzyme
@@ -66,12 +87,35 @@ pub(crate) use self::Enzyme_AD::*;
6687
pub(crate) mod Enzyme_AD {
6788
use std::ffi::{CString, c_char};
6889

69-
use libc::c_void;
90+
use libc::{c_void, size_t};
91+
92+
use super::{CConcreteType, CTypeTreeRef, Context};
7093

7194
unsafe extern "C" {
7295
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
7396
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
7497
}
98+
99+
// TypeTree functions
100+
unsafe extern "C" {
101+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
102+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
103+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
104+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
105+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
106+
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
107+
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
108+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
109+
arg1: CTypeTreeRef,
110+
data_layout: *const c_char,
111+
offset: i64,
112+
max_size: i64,
113+
add_offset: u64,
114+
);
115+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
116+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
117+
}
118+
75119
unsafe extern "C" {
76120
static mut EnzymePrintPerf: c_void;
77121
static mut EnzymePrintActivity: c_void;
@@ -140,6 +184,56 @@ pub(crate) use self::Fallback_AD::*;
140184
#[cfg(not(llvm_enzyme))]
141185
pub(crate) mod Fallback_AD {
142186
#![allow(unused_variables)]
187+
188+
use super::{CConcreteType, CTypeTreeRef, Context};
189+
use libc::c_char;
190+
191+
// TypeTree function fallbacks
192+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef {
193+
unimplemented!()
194+
}
195+
196+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
197+
unimplemented!()
198+
}
199+
200+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
201+
unimplemented!()
202+
}
203+
204+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
205+
unimplemented!()
206+
}
207+
208+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
209+
unimplemented!()
210+
}
211+
212+
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
213+
unimplemented!()
214+
}
215+
216+
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
217+
unimplemented!()
218+
}
219+
220+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
221+
arg1: CTypeTreeRef,
222+
data_layout: *const c_char,
223+
offset: i64,
224+
max_size: i64,
225+
add_offset: u64,
226+
) {
227+
unimplemented!()
228+
}
229+
230+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
231+
unimplemented!()
232+
}
233+
234+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
235+
unimplemented!()
236+
}
143237

144238
pub(crate) fn set_inline(val: bool) {
145239
unimplemented!()

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2676,27 +2676,36 @@ unsafe extern "C" {
26762676

26772677
// ========== ENZYME AUTODIFF FFI FUNCTIONS ==========
26782678

2679+
#[cfg(llvm_enzyme)]
26792680
// Enzyme Type Tree Functions (minimal set for TypeTree support)
26802681
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
2682+
#[cfg(llvm_enzyme)]
26812683
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
2684+
#[cfg(llvm_enzyme)]
26822685
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
2686+
#[cfg(llvm_enzyme)]
26832687
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
2688+
#[cfg(llvm_enzyme)]
26842689
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef);
2690+
#[cfg(llvm_enzyme)]
26852691
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
26862692
arg1: CTypeTreeRef,
26872693
data_layout: *const c_char,
26882694
offset: i64,
26892695
max_size: i64,
26902696
add_offset: u64,
26912697
);
2698+
#[cfg(llvm_enzyme)]
26922699
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
2700+
#[cfg(llvm_enzyme)]
26932701
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
26942702
}
26952703

26962704
// ========== ENZYME TYPES AND ENUMS ==========
26972705

26982706
// Type Tree Support for Autodiff
26992707

2708+
#[cfg(llvm_enzyme)]
27002709
#[repr(u32)]
27012710
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
27022711
pub(crate) enum CConcreteType {
@@ -2709,19 +2718,23 @@ pub(crate) enum CConcreteType {
27092718
DT_Unknown = 6,
27102719
}
27112720

2721+
#[cfg(llvm_enzyme)]
27122722
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
27132723

2724+
#[cfg(llvm_enzyme)]
27142725
#[repr(C)]
27152726
#[derive(Debug, Copy, Clone)]
27162727
pub(crate) struct EnzymeTypeTree {
27172728
_unused: [u8; 0],
27182729
}
27192730

27202731
// TypeTree wrapper for Rust-side type safety and memory management
2732+
#[cfg(llvm_enzyme)]
27212733
pub(crate) struct TypeTree {
27222734
pub(crate) inner: CTypeTreeRef,
27232735
}
27242736

2737+
#[cfg(llvm_enzyme)]
27252738
impl TypeTree {
27262739
pub(crate) fn new() -> TypeTree {
27272740
let inner = unsafe { EnzymeNewTypeTree() };
@@ -2765,13 +2778,15 @@ impl TypeTree {
27652778
}
27662779
}
27672780

2781+
#[cfg(llvm_enzyme)]
27682782
impl Clone for TypeTree {
27692783
fn clone(&self) -> Self {
27702784
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
27712785
TypeTree { inner }
27722786
}
27732787
}
27742788

2789+
#[cfg(llvm_enzyme)]
27752790
impl std::fmt::Display for TypeTree {
27762791
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27772792
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
@@ -2788,12 +2803,14 @@ impl std::fmt::Display for TypeTree {
27882803
}
27892804
}
27902805

2806+
#[cfg(llvm_enzyme)]
27912807
impl std::fmt::Debug for TypeTree {
27922808
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27932809
<Self as std::fmt::Display>::fmt(self, f)
27942810
}
27952811
}
27962812

2813+
#[cfg(llvm_enzyme)]
27972814
impl Drop for TypeTree {
27982815
fn drop(&mut self) {
27992816
unsafe { EnzymeFreeTypeTree(self.inner) }
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double, [-1,32]:Float@double, [-1,40]:Float@double, [-1,48]:Float@double, [-1,56]:Float@double}:{}
2+
3+
CHECK-DAG: call void @llvm.memcpy{{.*}}!enzyme_type
4+
5+
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
6+
7+
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
8+
9+
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
use std::ptr;
5+
6+
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
7+
#[no_mangle]
8+
fn test_memcpy(input: &[f64; 8]) -> f64 {
9+
let mut local_data = [0.0f64; 8];
10+
11+
unsafe {
12+
ptr::copy_nonoverlapping(input.as_ptr(), local_data.as_mut_ptr(), 8);
13+
}
14+
15+
let mut result = 0.0;
16+
for i in 0..8 {
17+
result += local_data[i] * local_data[i];
18+
}
19+
20+
result
21+
}
22+
23+
fn main() {
24+
let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
25+
let mut d_input = [0.0; 8];
26+
let result = test_memcpy(&input);
27+
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
28+
29+
assert_eq!(result, result_d);
30+
println!("Memcpy test passed: result = {}", result);
31+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use std::fs;
5+
6+
use run_make_support::{llvm_filecheck, rfs, rustc};
7+
8+
fn main() {
9+
// Compile the Rust file with the required flags, capturing both stdout and stderr
10+
let output = rustc()
11+
.input("memcpy.rs")
12+
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
13+
.arg("-Zautodiff=NoPostopt")
14+
.opt_level("3")
15+
.arg("-Clto=fat")
16+
.arg("-g")
17+
.run();
18+
19+
let stdout = output.stdout_utf8();
20+
let stderr = output.stderr_utf8();
21+
22+
// Write the outputs to files
23+
rfs::write("memcpy.stdout", stdout);
24+
rfs::write("memcpy.stderr", stderr);
25+
26+
// Run FileCheck on the stdout using the check file
27+
llvm_filecheck().patterns("memcpy.check").stdin_buf(rfs::read("memcpy.stdout")).run();
28+
}

0 commit comments

Comments
 (0)