Skip to content

Commit 375e14e

Browse files
committed
Add TypeTree metadata attachment for autodiff
- Add F128 support to TypeTree Kind enum - Implement TypeTree FFI bindings and conversion functions - Add typetree.rs module for metadata attachment to LLVM functions - Integrate TypeTree generation with autodiff intrinsic pipeline - Support scalar types: f32, f64, integers, f16, f128 - Attach enzyme_type attributes as LLVM string metadata for Enzyme Signed-off-by: Karan Janthe <[email protected]>
1 parent e1258e7 commit 375e14e

File tree

7 files changed

+343
-14
lines changed

7 files changed

+343
-14
lines changed

compiler/rustc_ast/src/expand/typetree.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub enum Kind {
3131
Half,
3232
Float,
3333
Double,
34+
F128,
3435
Unknown,
3536
}
3637

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ptr;
22

33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
4+
use rustc_ast::expand::typetree::FncTree;
45
use rustc_codegen_ssa::common::TypeKind;
56
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
67
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
@@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
294295
fn_args: &[&'ll Value],
295296
attrs: AutoDiffAttrs,
296297
dest: PlaceRef<'tcx, &'ll Value>,
298+
fnc_tree: FncTree,
297299
) {
298300
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299301
let mut ad_name: String = match attrs.mode {
@@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
370372
fn_args,
371373
);
372374

375+
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
376+
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
377+
}
378+
373379
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
374380

375381
builder.store_to_place(call, dest.val);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,9 @@ fn codegen_autodiff<'ll, 'tcx>(
12131213
&mut diff_attrs.input_activity,
12141214
);
12151215

1216+
let fnc_tree =
1217+
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
1218+
12161219
// Build body
12171220
generate_enzyme_call(
12181221
bx,
@@ -1223,6 +1226,7 @@ fn codegen_autodiff<'ll, 'tcx>(
12231226
&val_arr,
12241227
diff_attrs.clone(),
12251228
result,
1229+
fnc_tree,
12261230
);
12271231
}
12281232

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ mod llvm_util;
6767
mod mono_item;
6868
mod type_;
6969
mod type_of;
70+
mod typetree;
7071
mod va_arg;
7172
mod value;
7273

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 181 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,35 @@
33
use libc::{c_char, c_uint};
44

55
use super::MetadataKindId;
6-
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
6+
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
77
use crate::llvm::{Bool, Builder};
88

9+
// TypeTree types
10+
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
11+
12+
#[repr(C)]
13+
#[derive(Debug, Copy, Clone)]
14+
pub(crate) struct EnzymeTypeTree {
15+
_unused: [u8; 0],
16+
}
17+
18+
#[repr(u32)]
19+
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
20+
#[allow(non_camel_case_types)]
21+
pub(crate) enum CConcreteType {
22+
DT_Anything = 0,
23+
DT_Integer = 1,
24+
DT_Pointer = 2,
25+
DT_Half = 3,
26+
DT_Float = 4,
27+
DT_Double = 5,
28+
DT_Unknown = 6,
29+
}
30+
31+
pub(crate) struct TypeTree {
32+
pub(crate) inner: CTypeTreeRef,
33+
}
34+
935
#[link(name = "llvm-wrapper", kind = "static")]
1036
unsafe extern "C" {
1137
// Enzyme
@@ -68,10 +94,33 @@ pub(crate) mod Enzyme_AD {
6894

6995
use libc::c_void;
7096

97+
use super::{CConcreteType, CTypeTreeRef, Context};
98+
7199
unsafe extern "C" {
72100
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
73101
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
74102
}
103+
104+
// TypeTree functions
105+
unsafe extern "C" {
106+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
107+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
108+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
109+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
110+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
111+
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
112+
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
113+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
114+
arg1: CTypeTreeRef,
115+
data_layout: *const c_char,
116+
offset: i64,
117+
max_size: i64,
118+
add_offset: u64,
119+
);
120+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
121+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
122+
}
123+
75124
unsafe extern "C" {
76125
static mut EnzymePrintPerf: c_void;
77126
static mut EnzymePrintActivity: c_void;
@@ -141,6 +190,57 @@ pub(crate) use self::Fallback_AD::*;
141190
pub(crate) mod Fallback_AD {
142191
#![allow(unused_variables)]
143192

193+
use libc::c_char;
194+
195+
use super::{CConcreteType, CTypeTreeRef, Context};
196+
197+
// TypeTree function fallbacks
198+
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
199+
unimplemented!()
200+
}
201+
202+
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
203+
unimplemented!()
204+
}
205+
206+
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
207+
unimplemented!()
208+
}
209+
210+
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
211+
unimplemented!()
212+
}
213+
214+
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
215+
unimplemented!()
216+
}
217+
218+
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
219+
unimplemented!()
220+
}
221+
222+
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
223+
unimplemented!()
224+
}
225+
226+
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
227+
arg1: CTypeTreeRef,
228+
data_layout: *const c_char,
229+
offset: i64,
230+
max_size: i64,
231+
add_offset: u64,
232+
) {
233+
unimplemented!()
234+
}
235+
236+
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
237+
unimplemented!()
238+
}
239+
240+
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
241+
unimplemented!()
242+
}
243+
144244
pub(crate) fn set_inline(val: bool) {
145245
unimplemented!()
146246
}
@@ -169,3 +269,83 @@ pub(crate) mod Fallback_AD {
169269
unimplemented!()
170270
}
171271
}
272+
273+
impl TypeTree {
274+
pub(crate) fn new() -> TypeTree {
275+
let inner = unsafe { EnzymeNewTypeTree() };
276+
TypeTree { inner }
277+
}
278+
279+
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
280+
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
281+
TypeTree { inner }
282+
}
283+
284+
pub(crate) fn merge(self, other: Self) -> Self {
285+
unsafe {
286+
EnzymeMergeTypeTree(self.inner, other.inner);
287+
}
288+
drop(other);
289+
self
290+
}
291+
292+
#[must_use]
293+
pub(crate) fn shift(
294+
self,
295+
layout: &str,
296+
offset: isize,
297+
max_size: isize,
298+
add_offset: usize,
299+
) -> Self {
300+
let layout = std::ffi::CString::new(layout).unwrap();
301+
302+
unsafe {
303+
EnzymeTypeTreeShiftIndiciesEq(
304+
self.inner,
305+
layout.as_ptr(),
306+
offset as i64,
307+
max_size as i64,
308+
add_offset as u64,
309+
);
310+
}
311+
312+
self
313+
}
314+
}
315+
316+
impl Clone for TypeTree {
317+
fn clone(&self) -> Self {
318+
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
319+
TypeTree { inner }
320+
}
321+
}
322+
323+
impl std::fmt::Display for TypeTree {
324+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325+
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
326+
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
327+
match cstr.to_str() {
328+
Ok(x) => write!(f, "{}", x)?,
329+
Err(err) => write!(f, "could not parse: {}", err)?,
330+
}
331+
332+
// delete C string pointer
333+
unsafe {
334+
EnzymeTypeTreeToStringFree(ptr);
335+
}
336+
337+
Ok(())
338+
}
339+
}
340+
341+
impl std::fmt::Debug for TypeTree {
342+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343+
<Self as std::fmt::Display>::fmt(self, f)
344+
}
345+
}
346+
347+
impl Drop for TypeTree {
348+
fn drop(&mut self) {
349+
unsafe { EnzymeFreeTypeTree(self.inner) }
350+
}
351+
}

0 commit comments

Comments
 (0)