Skip to content

Commit b405a8c

Browse files
committed
working batching1 tests using &f32 and Vector, failing test batching2 using Simd and Batching
1 parent d2c44cd commit b405a8c

File tree

8 files changed

+182
-49
lines changed

8 files changed

+182
-49
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ pub enum DiffMode {
2929
Forward,
3030
/// The target function, to be created using reverse mode AD.
3131
Reverse,
32+
/// The target function, to be created using batching.
33+
Batch,
3234
}
3335

3436
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
@@ -69,6 +71,12 @@ pub enum DiffActivity {
6971
/// length of a slice/vec. This is used for safety checks on slices.
7072
/// The integer (if given) specifies the size of the slice element in bytes.
7173
FakeActivitySize(Option<u32>),
74+
/// Batching mode A: lowered to Enzyme's `enzyme_vector` activity.
75+
Vector,
76+
// Batching mode B (`Leaf`): not implemented yet (only available as part of autodiff through dupv).
77+
// Leaf,
78+
/// Batching mode C (scalar): lowered to Enzyme's `enzyme_scalar` activity.
79+
Scalar,
7280
}
7381

7482
impl DiffActivity {
@@ -130,6 +138,7 @@ impl Display for DiffMode {
130138
DiffMode::Source => write!(f, "Source"),
131139
DiffMode::Forward => write!(f, "Forward"),
132140
DiffMode::Reverse => write!(f, "Reverse"),
141+
DiffMode::Batch => write!(f, "Batch"),
133142
}
134143
}
135144
}
@@ -153,6 +162,13 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
153162
|| activity == DiffActivity::Active
154163
|| activity == DiffActivity::ActiveOnly
155164
}
165+
DiffMode::Batch => {
166+
// Batch mode accepts `Const`, `Vector`, or `Scalar` as the return activity.
166+
// NOTE(review): batching appears to compute batched primal results only, no derivatives — confirm.
168+
activity == DiffActivity::Const
169+
|| activity == DiffActivity::Vector
170+
|| activity == DiffActivity::Scalar
171+
}
156172
}
157173
}
158174

@@ -186,6 +202,11 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
186202
DiffMode::Reverse => {
187203
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
188204
}
205+
DiffMode::Batch => {
206+
// Batch mode accepts only `Const` and `Vector` as input activities.
207+
// NOTE(review): `Scalar` is accepted as a return activity but not as an input activity — confirm this is intended.
208+
matches!(activity, Const | Vector)
209+
}
189210
};
190211
}
191212

@@ -203,6 +224,8 @@ impl Display for DiffActivity {
203224
DiffActivity::Duplicated => write!(f, "Duplicated"),
204225
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
205226
DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
227+
DiffActivity::Vector => write!(f, "Vector"),
228+
DiffActivity::Scalar => write!(f, "Scalar"),
206229
}
207230
}
208231
}
@@ -216,6 +239,7 @@ impl FromStr for DiffMode {
216239
"Source" => Ok(DiffMode::Source),
217240
"Forward" => Ok(DiffMode::Forward),
218241
"Reverse" => Ok(DiffMode::Reverse),
242+
"Batch" => Ok(DiffMode::Batch),
219243
_ => Err(()),
220244
}
221245
}
@@ -235,6 +259,8 @@ impl FromStr for DiffActivity {
235259
"DualvOnly" => Ok(DiffActivity::DualvOnly),
236260
"Duplicated" => Ok(DiffActivity::Duplicated),
237261
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
262+
"Scalar" => Ok(DiffActivity::Scalar),
263+
"Vector" => Ok(DiffActivity::Vector),
238264
_ => Err(()),
239265
}
240266
}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,7 @@ mod llvm_enzyme {
884884
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
885885
panic!("Should not happen");
886886
}
887+
DiffActivity::Vector | DiffActivity::Scalar => todo!()
887888
}
888889
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
889890
idents.push(ident.clone());

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
8383
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
8484
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
8585

86+
let enzyme_scalar = cx.create_metadata("enzyme_scalar".to_string()).unwrap();
87+
let enzyme_vector = cx.create_metadata("enzyme_vector".to_string()).unwrap();
88+
8689
while activity_pos < inputs.len() {
8790
let diff_activity = inputs[activity_pos as usize];
8891
// Duplicated arguments received a shadow argument, into which enzyme will write the
@@ -99,7 +102,10 @@ fn match_args_from_caller_to_enzyme<'ll>(
99102
DiffActivity::Duplicated => (enzyme_dup, true),
100103
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
101104
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
105+
DiffActivity::Vector => (enzyme_vector, true),
106+
DiffActivity::Scalar => (enzyme_scalar, true),
102107
};
108+
let no_autodiff_only_batching = matches!(diff_activity, DiffActivity::Scalar | DiffActivity::Vector);
103109
let outer_arg = outer_args[outer_pos];
104110
args.push(cx.get_metadata_value(activity));
105111
if matches!(diff_activity, DiffActivity::Dualv) {
@@ -178,8 +184,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
178184
outer_pos += 2;
179185
activity_pos += 1;
180186

187+
dbg!(&width);
188+
dbg!(&outer_pos);
189+
dbg!(&activity_pos);
190+
dbg!(&args);
191+
let limit = if no_autodiff_only_batching {
192+
// Usually we have one primal arg + width shadow args.
193+
// Here we have `width` primal args, so one less than normal.
194+
width as usize - 1
195+
} else {
196+
width as usize
197+
};
181198
// Now, if width > 1, we need to account for that
182-
for _ in 1..width {
199+
for _ in 1..limit {
183200
let next_outer_arg = outer_args[outer_pos];
184201
args.push(next_outer_arg);
185202
outer_pos += 1;
@@ -269,6 +286,12 @@ fn compute_enzyme_fn_ty<'ll>(
269286
DiffMode::Reverse => {
270287
todo!("Handle sret for reverse mode");
271288
}
289+
DiffMode::Batch => {
290+
let arr_ty = unsafe {
291+
llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64)
292+
};
293+
ret_ty = arr_ty;
294+
}
272295
_ => {
273296
bug!("unreachable");
274297
}
@@ -299,6 +322,7 @@ fn generate_enzyme_call<'ll>(
299322
let mut ad_name: String = match attrs.mode {
300323
DiffMode::Forward => "__enzyme_fwddiff",
301324
DiffMode::Reverse => "__enzyme_autodiff",
325+
DiffMode::Batch => "__enzyme_batch",
302326
_ => panic!("logic bug in autodiff, unrecognized mode"),
303327
}
304328
.to_string();

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
855855
let mode = match mode.as_str() {
856856
"Forward" => DiffMode::Forward,
857857
"Reverse" => DiffMode::Reverse,
858+
"Batch" => DiffMode::Batch,
858859
_ => {
859860
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
860861
}

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ symbols! {
191191
BTreeEntry,
192192
BTreeMap,
193193
BTreeSet,
194+
Batching,
194195
BinaryHeap,
195196
Borrow,
196197
BorrowMut,

tests/codegen/autodiff/batching.rs

Lines changed: 0 additions & 48 deletions
This file was deleted.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//
5+
#![feature(rustc_attrs)]
6+
#![feature(prelude_import)]
7+
#![feature(panic_internals)]
8+
#![no_std]
9+
//@ needs-enzyme
10+
#![feature(autodiff)]
11+
#[prelude_import]
12+
use ::std::prelude::rust_2015::*;
13+
#[macro_use]
14+
extern crate std;
15+
//@ pretty-mode:expanded
16+
//@ pretty-compare-only
17+
//@ pp-exact:batching.pp
18+
19+
20+
// Test that batch mode (`Batch`) autodiff macros are lowered correctly.
21+
use std::arch::asm;
22+
use std::autodiff::autodiff;
23+
24+
// CHECK: ; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind willreturn memory(argmem: read) uwtable
25+
// CHECK-NEXT: define dso_local noundef float @square(ptr noalias nocapture noundef readonly align 4 dereferenceable(4) %x) unnamed_addr #2 {
26+
// CHECK-NEXT: start:
27+
// CHECK-NEXT: %_2 = load float, ptr %x, align 4, !noundef !4
28+
// CHECK-NEXT: %_0 = fmul float %_2, %_2
29+
// CHECK-NEXT: ret float %_0
30+
// CHECK-NEXT: }
31+
32+
// CHECK: ; Function Attrs: alwaysinline nonlazybind uwtable
33+
// CHECK-NEXT: define dso_local void @d_square2(ptr dead_on_unwind noalias nocapture noundef writable writeonly sret([16 x i8]) align 4 dereferenceable(16) initializes((0, 16)) %_0, ptr noalias nocapture noundef readonly align 4 dereferenceable(4) %x, ptr noalias noundef readonly align 4 dereferenceable(4) %bx_0, ptr noalias noundef readonly align 4 dereferenceable(4) %bx_1, ptr noalias noundef readonly align 4 dereferenceable(4) %bx_2) unnamed_addr #3 personality ptr @rust_eh_personality {
34+
// CHECK-NEXT: start:
35+
// CHECK-NEXT: %0 = insertvalue [4 x ptr] undef, ptr %x, 0
36+
// CHECK-NEXT: %1 = insertvalue [4 x ptr] %0, ptr %bx_0, 1
37+
// CHECK-NEXT: %2 = insertvalue [4 x ptr] %1, ptr %bx_1, 2
38+
// CHECK-NEXT: %3 = insertvalue [4 x ptr] %2, ptr %bx_2, 3
39+
// CHECK-NEXT: %4 = call [4 x float] @batch_square([4 x ptr] %3)
40+
// CHECK-NEXT: %.elt = extractvalue [4 x float] %4, 0
41+
// CHECK-NEXT: store float %.elt, ptr %_0, align 4
42+
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 4
43+
// CHECK-NEXT: %.elt2 = extractvalue [4 x float] %4, 1
44+
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 4
45+
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 8
46+
// CHECK-NEXT: %.elt4 = extractvalue [4 x float] %4, 2
47+
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
48+
// CHECK-NEXT: %_0.repack5 = getelementptr inbounds nuw i8, ptr %_0, i64 12
49+
// CHECK-NEXT: %.elt6 = extractvalue [4 x float] %4, 3
50+
// CHECK-NEXT: store float %.elt6, ptr %_0.repack5, align 4
51+
// CHECK-NEXT: ret void
52+
// CHECK-NEXT: }
53+
54+
/// Generated from:
55+
/// ```
56+
/// #[batching(d_square2, 4, Vector, Vector)]
57+
/// fn square(x: &f32) -> f32 {
58+
/// x * x
59+
/// }
/// ```
60+
61+
#[no_mangle]
62+
#[rustc_autodiff]
63+
#[inline(never)]
64+
fn square(x: &f32) -> f32 {
65+
x * x
66+
}
67+
#[rustc_autodiff(Batch, 4, Vector, Vector)]
68+
#[no_mangle]
69+
#[inline(never)]
70+
fn d_square2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32) -> [f32; 4usize] {
71+
unsafe {
72+
asm!("NOP", options(nomem));
73+
};
74+
::core::hint::black_box(square(x));
75+
::core::hint::black_box((bx_0, bx_1, bx_2));
76+
::core::hint::black_box(<[f32; 4usize]>::default())
77+
}
78+
79+
80+
fn main() {}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//
5+
#![feature(rustc_attrs)]
6+
#![feature(prelude_import)]
7+
#![feature(panic_internals)]
8+
#![no_std]
9+
//@ needs-enzyme
10+
#![feature(autodiff)]
11+
#[prelude_import]
12+
use ::std::prelude::rust_2015::*;
13+
#[macro_use]
14+
extern crate std;
15+
//@ pretty-mode:expanded
16+
//@ pretty-compare-only
17+
//@ pp-exact:batching.pp
18+
19+
20+
// Test that batch mode (`Batch`) autodiff macros are expanded correctly.
21+
use std::arch::asm;
22+
use std::autodiff::autodiff;
23+
24+
/// Generated from:
25+
/// ```
26+
/// #[batching(d_square2, 4, Batching, Batching)]
27+
/// fn square(x: f32) -> f32 {
28+
/// x * x
29+
/// }
/// ```
30+
31+
#[no_mangle]
32+
#[rustc_autodiff]
33+
#[inline(never)]
34+
fn square(x: f32) -> f32 {
35+
x * x
36+
}
37+
#[rustc_autodiff(Batch, 4, Batching, Batching)]
38+
#[no_mangle]
39+
#[inline(never)]
40+
fn d_square2(x: Simd<f32,4>) -> Simd<f32,4> {
41+
unsafe {
42+
asm!("NOP", options(nomem));
43+
};
44+
::core::hint::black_box(x);
45+
::core::hint::black_box(());
46+
::core::hint::black_box(Default::default())
47+
}
48+
fn main() {}

0 commit comments

Comments
 (0)