Skip to content

Commit 527229a

Browse files
committed
add batch arg lowering to enzyme, update test
1 parent c97d7f9 commit 527229a

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,21 @@ fn generate_enzyme_call<'ll>(
208208
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
209209
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
210210
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer);
211-
let next_outer_arg2 = outer_args[outer_pos + 2];
212-
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
213-
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer);
214-
let next_outer_arg3 = outer_args[outer_pos + 3];
215-
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
216-
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer);
217-
args.push(next_outer_arg2);
211+
212+
for _ in 0..attrs.width {
213+
let next_outer_arg2 = outer_args[outer_pos + 2];
214+
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
215+
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer);
216+
let next_outer_arg3 = outer_args[outer_pos + 3];
217+
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
218+
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer);
219+
args.push(next_outer_arg2);
220+
}
221+
222+
218223
args.push(cx.get_metadata_value(enzyme_const));
219224
args.push(next_outer_arg);
220-
outer_pos += 4;
225+
outer_pos += 2 + 2 * attrs.width as usize;
221226
activity_pos += 2;
222227
} else {
223228
// A duplicated pointer will have the following two outer_fn arguments:
@@ -235,6 +240,14 @@ fn generate_enzyme_call<'ll>(
235240
args.push(next_outer_arg);
236241
outer_pos += 2;
237242
activity_pos += 1;
243+
244+
// In batch mode (width > 1), forward the remaining `width - 1` outer arguments as well.
245+
for _ in 1..attrs.width {
246+
let next_outer_arg = outer_args[outer_pos];
247+
args.push(next_outer_arg);
248+
outer_pos += 1;
249+
}
250+
238251
}
239252
} else {
240253
// We do not differentiate with respect to this argument.

tests/codegen/autodiffv.rs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,40 @@ fn square(x: &f64) -> f64 {
1111
x * x
1212
}
1313

14-
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
14+
// CHECK:define internal fastcc void @diffe4square([4 x ptr] %"x'"
1515
// CHECK-NEXT:invertstart:
16-
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
17-
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
18-
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
19-
// CHECK-NEXT: %2 = fadd fast double %1, %0
20-
// CHECK-NEXT: store double %2, ptr %"x'", align 8
21-
// CHECK-NEXT: ret double %_0
16+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
17+
// CHECK-NEXT: %1 = load double, ptr %0, align 8, !alias.scope !15950, !noalias !15953
18+
// CHECK-NEXT: %2 = fadd fast double %1, 6.000000e+00
19+
// CHECK-NEXT: store double %2, ptr %0, align 8, !alias.scope !15950, !noalias !15953
20+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 1
21+
// CHECK-NEXT: %4 = load double, ptr %3, align 8, !alias.scope !15958, !noalias !15959
22+
// CHECK-NEXT: %5 = fadd fast double %4, 6.000000e+00
23+
// CHECK-NEXT: store double %5, ptr %3, align 8, !alias.scope !15958, !noalias !15959
24+
// CHECK-NEXT: %6 = extractvalue [4 x ptr] %"x'", 2
25+
// CHECK-NEXT: %7 = load double, ptr %6, align 8, !alias.scope !15960, !noalias !15961
26+
// CHECK-NEXT: %8 = fadd fast double %7, 6.000000e+00
27+
// CHECK-NEXT: store double %8, ptr %6, align 8, !alias.scope !15960, !noalias !15961
28+
// CHECK-NEXT: %9 = extractvalue [4 x ptr] %"x'", 3
29+
// CHECK-NEXT: %10 = load double, ptr %9, align 8, !alias.scope !15962, !noalias !15963
30+
// CHECK-NEXT: %11 = fadd fast double %10, 6.000000e+00
31+
// CHECK-NEXT: store double %11, ptr %9, align 8, !alias.scope !15962, !noalias !15963
32+
// CHECK-NEXT: ret void
2233
// CHECK-NEXT:}
2334

2435
fn main() {
2536
let x = 3.0;
2637
let output = square(&x);
2738
assert_eq!(9.0, output);
2839

29-
let mut df_dx = 0.0;
30-
let output_ = d_square(&x, &mut df_dx, 1.0);
40+
let mut df_dx1 = 0.0;
41+
let mut df_dx2 = 0.0;
42+
let mut df_dx3 = 0.0;
43+
let mut df_dx4 = 0.0;
44+
let output_ = d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4, 1.0);
3145
assert_eq!(output, output_);
32-
assert_eq!(6.0, df_dx);
46+
assert_eq!(6.0, df_dx1);
47+
assert_eq!(6.0, df_dx2);
48+
assert_eq!(6.0, df_dx3);
49+
assert_eq!(6.0, df_dx4);
3350
}

0 commit comments

Comments
 (0)