Skip to content

Commit 27a8b79

Browse files
authored
Support index arithmetic (#124)
1 parent cf95b07 commit 27a8b79

File tree

10 files changed

+324
-3
lines changed

10 files changed

+324
-3
lines changed

crates/autodiff/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,17 @@ impl Autodiff<'_> {
240240
},
241241
&Expr::Binary { op, left, right } => match op {
242242
// boring cases
243-
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr {
243+
Binop::And
244+
| Binop::Or
245+
| Binop::Iff
246+
| Binop::Xor
247+
| Binop::INeq
248+
| Binop::ILt
249+
| Binop::ILeq
250+
| Binop::IEq
251+
| Binop::IGt
252+
| Binop::IGeq
253+
| Binop::IAdd => self.code.push(Instr {
244254
var,
245255
expr: Expr::Binary { op, left, right },
246256
}),

crates/core/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,17 @@ pub enum Binop {
210210
Iff,
211211
Xor,
212212

213+
// `Fin` -> `Fin` -> `Bool`
214+
INeq,
215+
ILt,
216+
ILeq,
217+
IEq,
218+
IGt,
219+
IGeq,
220+
221+
// `Fin` -> `Fin` -> `Fin`
222+
IAdd,
223+
213224
// `F64` -> `F64` -> `Bool`
214225
Neq,
215226
Lt,

crates/interp/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,15 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
236236
Binop::Iff => Val::Bool(x.bool() == y.bool()),
237237
Binop::Xor => Val::Bool(x.bool() != y.bool()),
238238

239+
Binop::INeq => Val::Bool(x.fin() != y.fin()),
240+
Binop::ILt => Val::Bool(x.fin() < y.fin()),
241+
Binop::ILeq => Val::Bool(x.fin() <= y.fin()),
242+
Binop::IEq => Val::Bool(x.fin() == y.fin()),
243+
Binop::IGt => Val::Bool(x.fin() > y.fin()),
244+
Binop::IGeq => Val::Bool(x.fin() >= y.fin()),
245+
246+
Binop::IAdd => Val::Fin(x.fin() + y.fin()),
247+
239248
Binop::Neq => Val::Bool(x.f64() != y.f64()),
240249
Binop::Lt => Val::Bool(x.f64() < y.f64()),
241250
Binop::Leq => Val::Bool(x.f64() <= y.f64()),

crates/transpose/src/lib.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,13 @@ impl<'a> Transpose<'a> {
630630
| Binop::Or
631631
| Binop::Iff
632632
| Binop::Xor
633+
| Binop::INeq
634+
| Binop::ILt
635+
| Binop::ILeq
636+
| Binop::IEq
637+
| Binop::IGt
638+
| Binop::IGeq
639+
| Binop::IAdd
633640
| Binop::Neq
634641
| Binop::Lt
635642
| Binop::Leq
@@ -704,7 +711,17 @@ impl<'a> Transpose<'a> {
704711
}
705712
_ => {
706713
let (a, b) = match op {
707-
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right),
714+
Binop::And
715+
| Binop::Or
716+
| Binop::Iff
717+
| Binop::Xor
718+
| Binop::INeq
719+
| Binop::ILt
720+
| Binop::ILeq
721+
| Binop::IEq
722+
| Binop::IGt
723+
| Binop::IGeq
724+
| Binop::IAdd => (left, right),
708725
Binop::Neq
709726
| Binop::Lt
710727
| Binop::Leq

crates/wasm/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,13 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
782782
Binop::Or => self.wasm.instruction(&Instruction::I32Or),
783783
Binop::Iff => self.wasm.instruction(&Instruction::I32Eq),
784784
Binop::Xor => self.wasm.instruction(&Instruction::I32Xor),
785+
Binop::INeq => self.wasm.instruction(&Instruction::I32Ne),
786+
Binop::ILt => self.wasm.instruction(&Instruction::I32LtU),
787+
Binop::ILeq => self.wasm.instruction(&Instruction::I32LeU),
788+
Binop::IEq => self.wasm.instruction(&Instruction::I32Eq),
789+
Binop::IGt => self.wasm.instruction(&Instruction::I32GtU),
790+
Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU),
791+
Binop::IAdd => self.wasm.instruction(&Instruction::I32Add),
785792
Binop::Neq => self.wasm.instruction(&Instruction::F64Ne),
786793
Binop::Lt => self.wasm.instruction(&Instruction::F64Lt),
787794
Binop::Leq => self.wasm.instruction(&Instruction::F64Le),

crates/web/src/lib.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,98 @@ impl Block {
13181318
self.instr(f, t, expr)
13191319
}
13201320

1321+
/// Return the variable ID for a new "index not equal" instruction on `left` and `right`.
1322+
///
1323+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1324+
pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
1325+
let t = id::ty(f.ty_bool());
1326+
let expr = rose::Expr::Binary {
1327+
op: rose::Binop::INeq,
1328+
left: id::var(left),
1329+
right: id::var(right),
1330+
};
1331+
self.instr(f, t, expr)
1332+
}
1333+
1334+
/// Return the variable ID for a new "index less than" instruction on `left` and `right`.
1335+
///
1336+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1337+
pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
1338+
let t = id::ty(f.ty_bool());
1339+
let expr = rose::Expr::Binary {
1340+
op: rose::Binop::ILt,
1341+
left: id::var(left),
1342+
right: id::var(right),
1343+
};
1344+
self.instr(f, t, expr)
1345+
}
1346+
1347+
/// Return the variable ID for a new "index less than or equal" instruction on `left` and
1348+
/// `right`.
1349+
///
1350+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1351+
pub fn ileq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
1352+
let t = id::ty(f.ty_bool());
1353+
let expr = rose::Expr::Binary {
1354+
op: rose::Binop::ILeq,
1355+
left: id::var(left),
1356+
right: id::var(right),
1357+
};
1358+
self.instr(f, t, expr)
1359+
}
1360+
1361+
/// Return the variable ID for a new "index equal" instruction on `left` and `right`.
1362+
///
1363+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1364+
pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
1365+
let t = id::ty(f.ty_bool());
1366+
let expr = rose::Expr::Binary {
1367+
op: rose::Binop::IEq,
1368+
left: id::var(left),
1369+
right: id::var(right),
1370+
};
1371+
self.instr(f, t, expr)
1372+
}
1373+
1374+
/// Return the variable ID for a new "index greater than" instruction on `left` and `right`.
1375+
///
1376+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1377+
pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
1378+
let t = id::ty(f.ty_bool());
1379+
let expr = rose::Expr::Binary {
1380+
op: rose::Binop::IGt,
1381+
left: id::var(left),
1382+
right: id::var(right),
1383+
};
1384+
self.instr(f, t, expr)
1385+
}
1386+
1387+
/// Return the variable ID for a new "index greater than or equal" instruction on `left` and
1388+
/// `right`.
1389+
///
1390+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1391+
pub fn igeq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
1392+
let t = id::ty(f.ty_bool());
1393+
let expr = rose::Expr::Binary {
1394+
op: rose::Binop::IGeq,
1395+
left: id::var(left),
1396+
right: id::var(right),
1397+
};
1398+
self.instr(f, t, expr)
1399+
}
1400+
1401+
/// Return the variable ID for a new "index add" instruction on `left` and `right`.
1402+
///
1403+
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
1404+
pub fn iadd(&mut self, f: &mut FuncBuilder, t: usize, left: usize, right: usize) -> usize {
1405+
let expr = rose::Expr::Binary {
1406+
op: rose::Binop::IAdd,
1407+
left: id::var(left),
1408+
right: id::var(right),
1409+
};
1410+
self.instr(f, id::ty(t), expr)
1411+
}
1412+
13211413
/// Return the variable ID for a new "not equal" instruction on `left` and `right`.
13221414
///
13231415
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.

crates/web/src/pprint.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> {
171171
Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?,
172172
Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?,
173173
Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?,
174+
Binop::INeq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
175+
Binop::ILt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
176+
Binop::ILeq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
177+
Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?,
178+
Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?,
179+
Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?,
180+
Binop::IAdd => writeln!(f, "x{} + x{}", left.var(), right.var())?,
174181
Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
175182
Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
176183
Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,

packages/core/src/impl.ts

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ type Zero = typeof zeroSymbol;
9090
export type Tan = Zero | Var;
9191

9292
/** An abstract natural number, which can be used to index into a vector. */
93-
type Nat = number | symbol;
93+
export type Nat = number | symbol;
9494

9595
/** The portion of an abstract vector that can be directly indexed. */
9696
interface VecIndex<T> {
@@ -954,6 +954,56 @@ export const xor = (p: Bool, q: Bool): Bool => {
954954
return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q)));
955955
};
956956

957+
/** Return an abstract boolean for if `i` is not equal to `j`. */
958+
export const ineq = (ty: Nats, i: Nat, j: Nat): Bool => {
959+
const ctx = getCtx();
960+
const t = tyId(ctx, ty);
961+
return newVar(ctx.block.ineq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
962+
};
963+
964+
/** Return an abstract boolean for if `i` is less than `j`. */
965+
export const ilt = (ty: Nats, i: Nat, j: Nat): Bool => {
966+
const ctx = getCtx();
967+
const t = tyId(ctx, ty);
968+
return newVar(ctx.block.ilt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
969+
};
970+
971+
/** Return an abstract boolean for if `i` is less than or equal to `j`. */
972+
export const ileq = (ty: Nats, i: Nat, j: Nat): Bool => {
973+
const ctx = getCtx();
974+
const t = tyId(ctx, ty);
975+
return newVar(ctx.block.ileq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
976+
};
977+
978+
/** Return an abstract boolean for if `i` is equal to `j`. */
979+
export const ieq = (ty: Nats, i: Nat, j: Nat): Bool => {
980+
const ctx = getCtx();
981+
const t = tyId(ctx, ty);
982+
return newVar(ctx.block.ieq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
983+
};
984+
985+
/** Return an abstract boolean for if `i` is greater than `j`. */
986+
export const igt = (ty: Nats, i: Nat, j: Nat): Bool => {
987+
const ctx = getCtx();
988+
const t = tyId(ctx, ty);
989+
return newVar(ctx.block.igt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
990+
};
991+
992+
/** Return an abstract boolean for if `i` is greater than or equal to `j`. */
993+
export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => {
994+
const ctx = getCtx();
995+
const t = tyId(ctx, ty);
996+
return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
997+
};
998+
999+
/** Return the abstract index `i` plus the abstract index `y`. */
1000+
export const iadd = (ty: Nats, i: Nat, j: Nat): Nat => {
1001+
const ctx = getCtx();
1002+
const t = tyId(ctx, ty);
1003+
const k = ctx.block.iadd(ctx.func, t, valId(ctx, t, i), valId(ctx, t, j));
1004+
return idVal(ctx, t, k) as Nat;
1005+
};
1006+
9571007
/** Return an abstract value selecting between `then` and `els` via `cond`. */
9581008
export const select = <const T>(
9591009
cond: Bool,

0 commit comments

Comments
 (0)