Skip to content

Commit ff9ea61

Browse files
committed
Fix dead code elimination for Expr::Slice
1 parent 16fe5e6 commit ff9ea61

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

crates/wasm/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> {
139139
vars.follow(index, instr.var);
140140
}
141141
&Expr::Member { tuple, .. } => vars.follow(tuple, instr.var),
142-
&Expr::Slice { .. } => vars.live(instr.var),
142+
&Expr::Slice { index, .. } => {
143+
vars.live(instr.var);
144+
vars.follow(index, instr.var);
145+
}
143146
&Expr::Field { .. } => vars.live(instr.var),
144147
&Expr::Unary { arg, .. } => vars.follow(arg, instr.var),
145148
&Expr::Binary { left, right, .. } => {

packages/core/src/index.test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,4 +934,12 @@ describe("valid", () => {
934934
const h = await compile(g);
935935
expect(h()).toEqual([[0]]);
936936
});
937+
938+
test("compile gradient with dynamic index", async () => {
939+
const T = struct({ v: Vec(1, Real), i: 1 });
940+
const f = fn([T], Real, ({ v, i }) => v[i]);
941+
const g = fn([T], T, (x) => vjp(f)(x).grad(1));
942+
const h = await compile(g);
943+
expect(h({ v: [2], i: 0 })).toEqual({ v: [1], i: 0 });
944+
});
937945
});

0 commit comments

Comments
 (0)