Skip to content

Commit 017cd75

Browse files
committed
add non-working enzyme_buffer mode, with incorrect logic
1 parent b405a8c commit 017cd75

File tree

6 files changed

+118
-25
lines changed

6 files changed

+118
-25
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ pub enum DiffActivity {
7373
FakeActivitySize(Option<u32>),
7474
/// Batching mode A
7575
Vector,
76-
/// Batching mode B, missing implementation (only available as part of autodiff through dupv)
77-
// Leaf,
78-
/// Batching mode C, scalar.
76+
/// Batching mode B, equivalent to the `*v` activities (e.g. `Dualv`) above
77+
Buffer,
78+
/// "Batching" mode C, scalar. Not batched.
7979
Scalar,
8080
}
8181

@@ -167,6 +167,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
167167
// We just compute derivatives wrt. the inputs, so we can ignore the return value.
168168
activity == DiffActivity::Const
169169
|| activity == DiffActivity::Vector
170+
|| activity == DiffActivity::Buffer
170171
|| activity == DiffActivity::Scalar
171172
}
172173
}
@@ -205,7 +206,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
205206
DiffMode::Batch => {
206207
// Batching is a special case, since we don't compute derivatives wrt. the return value.
207208
// We just compute derivatives wrt. the inputs, so we can ignore the return value.
208-
matches!(activity, Const | Vector)
209+
matches!(activity, Const | Vector | Buffer)
209210
}
210211
};
211212
}
@@ -225,6 +226,7 @@ impl Display for DiffActivity {
225226
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
226227
DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
227228
DiffActivity::Vector => write!(f, "Vector"),
229+
DiffActivity::Buffer => write!(f, "Buffer"),
228230
DiffActivity::Scalar => write!(f, "Scalar"),
229231
}
230232
}
@@ -261,6 +263,7 @@ impl FromStr for DiffActivity {
261263
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
262264
"Scalar" => Ok(DiffActivity::Scalar),
263265
"Vector" => Ok(DiffActivity::Vector),
266+
"Buffer" => Ok(DiffActivity::Buffer),
264267
_ => Err(()),
265268
}
266269
}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ mod llvm_enzyme {
884884
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
885885
panic!("Should not happen");
886886
}
887-
DiffActivity::Vector | DiffActivity::Scalar => todo!()
887+
DiffActivity::Vector | DiffActivity::Scalar | DiffActivity::Buffer => todo!()
888888
}
889889
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
890890
idents.push(ident.clone());

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,18 @@ fn match_args_from_caller_to_enzyme<'ll>(
7676
outer_pos = 1;
7777
}
7878

79+
// Autodiff activities
7980
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
8081
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
8182
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
8283
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
8384
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
8485
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
8586

87+
// Batching activities
8688
let enzyme_scalar = cx.create_metadata("enzyme_scalar".to_string()).unwrap();
8789
let enzyme_vector = cx.create_metadata("enzyme_vector".to_string()).unwrap();
90+
let enzyme_buffer = cx.create_metadata("enzyme_buffer".to_string()).unwrap();
8891

8992
while activity_pos < inputs.len() {
9093
let diff_activity = inputs[activity_pos as usize];
@@ -103,16 +106,17 @@ fn match_args_from_caller_to_enzyme<'ll>(
103106
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
104107
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
105108
DiffActivity::Vector => (enzyme_vector, true),
109+
DiffActivity::Buffer => (enzyme_buffer, false),
106110
DiffActivity::Scalar => (enzyme_scalar, true),
107111
};
108-
let no_autodiff_only_batching = matches!(diff_activity, DiffActivity::Scalar | DiffActivity::Vector);
112+
let no_autodiff_only_batching = matches!(diff_activity, DiffActivity::Scalar | DiffActivity::Vector | DiffActivity::Buffer);
109113
let outer_arg = outer_args[outer_pos];
110114
args.push(cx.get_metadata_value(activity));
111115
if matches!(diff_activity, DiffActivity::Dualv) {
112116
let next_outer_arg = outer_args[outer_pos + 1];
113117
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
114118
DiffActivity::FakeActivitySize(Some(s)) => s.into(),
115-
_ => bug!("incorrect Dualv handling recognized."),
119+
_ => bug!("incorrect Dualv/Batching handling recognized."),
116120
};
117121
// stride: sizeof(T) * n_elems.
118122
// n_elems is the next integer.
@@ -127,7 +131,53 @@ fn match_args_from_caller_to_enzyme<'ll>(
127131
};
128132
args.push(mul);
129133
}
134+
if matches!(diff_activity, DiffActivity::Buffer) {
135+
// There are various cases.
136+
// A) We look at a scalar float.
137+
// B) We look at a vector/array of floats passed by value (byVal). Not sure if this is valid.
138+
// C) We look at a ptr as part of a slice.
139+
// D) We look at a ptr as part of a raw pointer or reference.
140+
141+
let mut elem_offset = cx.get_const_i64(width.into());
142+
let outer_ty = cx.val_ty(outer_arg);
143+
dbg!(&outer_ty);
144+
let bit_width = if cx.is_float_type(outer_ty) {
145+
cx.float_width(outer_ty)
146+
} else if cx.is_vec_or_array_type(outer_ty) {
147+
let elem_ty = cx.element_type(outer_ty);
148+
assert!(cx.is_float_type(elem_ty));
149+
let num_vec_elements = cx.vector_length(outer_ty);
150+
assert!(num_vec_elements == width as usize);
151+
dbg!(&num_vec_elements);
152+
cx.float_width(elem_ty)
153+
} else if cx.is_ptr_type(outer_ty) {
154+
if is_slice(activity_pos, inputs) {
155+
elem_offset = outer_args[outer_pos + 1];
156+
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
157+
DiffActivity::FakeActivitySize(Some(s)) => s.into(),
158+
_ => bug!("incorrect Dualv/Buffer handling recognized."),
159+
};
160+
elem_bytes_size as usize * 8
161+
} else {
162+
// raw pointer or ref, hence `num_elem` = 1
163+
unimplemented!()
164+
}
165+
} else {
166+
bug!("expected float or vector type, found {:?}", outer_ty);
167+
};
168+
let elem_bytes_size = bit_width as u64 / 8;
169+
let mul = unsafe {
170+
llvm::LLVMBuildMul(
171+
builder.llbuilder,
172+
cx.get_const_i64(elem_bytes_size),
173+
elem_offset,
174+
UNNAMED,
175+
)
176+
};
177+
args.push(mul);
178+
}
130179
args.push(outer_arg);
180+
dbg!(&args);
131181
if duplicated {
132182
// We know that duplicated args by construction have a following argument,
133183
// so this can not be out of bounds.
@@ -136,17 +186,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
136186
// FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
137187
// vectors behind references (&Vec<T>) are already supported. Users can not pass a
138188
// Vec by value for reverse mode, so this would only help forward mode autodiff.
139-
let slice = {
140-
if activity_pos + 1 >= inputs.len() {
141-
// If there is no arg following our ptr, it also can't be a slice,
142-
// since that would lead to a ptr, int pair.
143-
false
144-
} else {
145-
let next_activity = inputs[activity_pos + 1];
146-
// We analyze the MIR types and add this dummy activity if we visit a slice.
147-
matches!(next_activity, DiffActivity::FakeActivitySize(_))
148-
}
149-
};
189+
let slice = is_slice(activity_pos, &inputs);
150190
if slice {
151191
// A duplicated slice will have the following two outer_fn arguments:
152192
// (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
@@ -209,6 +249,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
209249
activity_pos += 1;
210250
}
211251
}
252+
dbg!("ending");
253+
}
254+
255+
fn is_slice(activity_pos: usize, inputs: &[DiffActivity]) -> bool {
256+
if activity_pos + 1 >= inputs.len() {
257+
// If there is no arg following our ptr, it also can't be a slice,
258+
// since that would lead to a ptr, int pair.
259+
false
260+
} else {
261+
let next_activity = inputs[activity_pos + 1];
262+
// We analyze the MIR types and add this dummy activity if we visit a slice.
263+
matches!(next_activity, DiffActivity::FakeActivitySize(_))
264+
}
212265
}
213266

214267
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
@@ -426,6 +479,7 @@ fn generate_enzyme_call<'ll>(
426479

427480
let call = builder.call(enzyme_ty, ad_fn, &args, None);
428481

482+
dbg!(&call);
429483
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
430484
// metadata attached to it, but we just created this code out of thin air. Given that the
431485
// differentiated function already has partly confusing metadata, and given that this
@@ -472,6 +526,7 @@ fn generate_enzyme_call<'ll>(
472526
} else {
473527
builder.ret(call);
474528
}
529+
dbg!("Still alive");
475530

476531
// Let's crash in case that we messed something up above and generated invalid IR.
477532
llvm::LLVMRustVerifyFunction(
@@ -531,6 +586,7 @@ pub(crate) fn differentiate<'ll>(
531586

532587
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
533588
}
589+
dbg!("lowered all");
534590

535591
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
536592

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,23 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
154154
)
155155
}
156156
}
157+
158+
pub(crate) fn is_float_type(&self, ty: &'ll Type) -> bool {
159+
matches!(
160+
self.type_kind(ty),
161+
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::X86_FP80
162+
| TypeKind::FP128 | TypeKind::PPC_FP128
163+
)
164+
}
165+
166+
pub(crate) fn is_vec_or_array_type(&self, ty: &'ll Type) -> bool {
167+
matches!(self.type_kind(ty),
168+
TypeKind::Array | TypeKind::Vector | TypeKind::ScalableVector)
169+
}
170+
171+
pub(crate) fn is_ptr_type(&self, ty: &'ll Type) -> bool {
172+
matches!(self.type_kind(ty), TypeKind::Pointer)
173+
}
157174
}
158175

159176
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
6060
| DiffActivity::Duplicated => {
6161
DiffActivity::FakeActivitySize(Some(elem_size))
6262
}
63+
DiffActivity::Buffer => {
64+
DiffActivity::FakeActivitySize(Some(elem_size))
65+
}
6366
DiffActivity::Const => DiffActivity::Const,
6467
_ => bug!("unexpected activity for ptr/ref"),
6568
};
Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,62 @@
1-
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
1+
//@ compile-flags: -Zautodiff=Enable,PrintModAfter -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
//
55
#![feature(rustc_attrs)]
66
#![feature(prelude_import)]
77
#![feature(panic_internals)]
8+
#![feature(portable_simd)]
89
#![no_std]
910
//@ needs-enzyme
1011
#![feature(autodiff)]
1112
#[prelude_import]
1213
use ::std::prelude::rust_2015::*;
1314
#[macro_use]
1415
extern crate std;
16+
use std::simd::Simd;
1517
//@ pretty-mode:expanded
1618
//@ pretty-compare-only
1719
//@ pp-exact:batching.pp
1820

21+
// CHECK: __enzyme
1922

2023
// Test that forward mode ad macros are expanded correctly.
2124
use std::arch::asm;
2225
use std::autodiff::autodiff;
2326

2427
// Generated from:
2528
//// ```
26-
/// #[batching(d_square2, 4, Batching, Batching)]
29+
/// #[batching(d_square2, 4, Buffer, Buffer)]
2730
/// fn square(x: f32) -> f32 {
2831
/// x * x
2932
/// }
3033
3134
#[no_mangle]
3235
#[rustc_autodiff]
3336
#[inline(never)]
34-
fn square(x: f32) -> f32 {
35-
x * x
37+
fn square(x: &[f32]) -> f32 {
38+
x[0] * x[0]
3639
}
37-
#[rustc_autodiff(Batch, 4, Batching, Batching)]
40+
#[rustc_autodiff(Batch, 4, Buffer, Buffer)]
3841
#[no_mangle]
3942
#[inline(never)]
40-
fn d_square2(x: Simd<f32,4>) -> Simd<f32,4> {
43+
fn d_square2(x: &[f32]) -> Simd<f32,4> {
4144
unsafe {
4245
asm!("NOP", options(nomem));
4346
};
44-
::core::hint::black_box(x);
47+
::core::hint::black_box(square(x));
4548
::core::hint::black_box(());
4649
::core::hint::black_box(Default::default())
4750
}
51+
//#[rustc_autodiff(Batch, 4, Batching, Batching)]
52+
//#[no_mangle]
53+
//#[inline(never)]
54+
//fn d_square2(x: Simd<f32,4>) -> Simd<f32,4> {
55+
// unsafe {
56+
// asm!("NOP", options(nomem));
57+
// };
58+
// ::core::hint::black_box(square(x[0]));
59+
// ::core::hint::black_box(());
60+
// ::core::hint::black_box(Default::default())
61+
//}
4862
fn main() {}

0 commit comments

Comments
 (0)