Skip to content

Commit 0d1dcd6

Browse files
committed
Parse math expressions
1 parent 737fe75 commit 0d1dcd6

File tree

9 files changed

+462
-66
lines changed

9 files changed

+462
-66
lines changed

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ prettier: npm
5151
packages: core site wasm
5252

5353
# run JavaScript tests
54-
test-js: test-core
54+
test-js: test-core test-site
5555

5656
## `packages/core`
5757

@@ -68,9 +68,14 @@ test-core: npm wasm
6868

6969
site-deps: npm core
7070

71+
# build
7172
site: site-deps
7273
npm run --workspace=@rose-lang/site build
7374

75+
# test
76+
test-site: site-deps
77+
npm run --workspace=@rose-lang/site test -- run --no-threads
78+
7479
## `packages/wasm`
7580

7681
# build

packages/site/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
></a>
3131
</div>
3232
<div class="example">
33-
<input class="textbox" value="x^y" readonly />
33+
<input class="textbox" value="x^y" id="textbox" />
3434
<canvas width="300" height="300" id="canvas"></canvas>
3535
</div>
3636
<div></div>

packages/site/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"scripts": {
1111
"build": "vite build",
1212
"dev": "vite",
13-
"preview": "vite preview"
13+
"preview": "vite preview",
14+
"test": "vitest"
1415
}
1516
}

packages/site/src/func.ts

Lines changed: 0 additions & 58 deletions
This file was deleted.

packages/site/src/main.ts

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,57 @@
1-
import all, { Info } from "./func.js";
1+
import { Real, Vec, compile, fn, jvp, vec, vjp } from "rose";
2+
import { Expr, parse } from "./parse.js";
23

34
type Vec2 = [number, number];
45

6+
interface Info {
7+
val: number;
8+
grad: Vec2;
9+
hess: [Vec2, Vec2];
10+
}
11+
12+
type Func = (x: number, y: number) => Info;
13+
14+
const autodiff = async (root: Expr): Promise<Func> => {
15+
const Vec2 = Vec(2, Real);
16+
const f = fn([Vec2], Real, (v) => {
17+
const emit = (e: Expr): Real => {
18+
switch (e.kind) {
19+
case "const":
20+
return e.val;
21+
case "var":
22+
return v[e.idx];
23+
case "unary":
24+
return e.f(emit(e.arg));
25+
case "binary":
26+
return e.f(emit(e.lhs), emit(e.rhs));
27+
}
28+
};
29+
return emit(root);
30+
});
31+
32+
const Mat2 = Vec(2, Vec2);
33+
const g = fn([Vec2], Vec2, (v) => vjp(f)(v).grad(1));
34+
const h = fn([Vec2], Mat2, ([x, y]) => {
35+
const d = jvp(g);
36+
const a = d([
37+
{ re: x, du: 1 },
38+
{ re: y, du: 0 },
39+
]);
40+
const b = d([
41+
{ re: x, du: 0 },
42+
{ re: y, du: 1 },
43+
]);
44+
return [vec(2, Real, (i) => a[i].du), vec(2, Real, (i) => b[i].du)];
45+
});
46+
47+
return (await compile(
48+
fn([Real, Real], { val: Real, grad: Vec2, hess: Mat2 }, (x, y) => {
49+
const v = [x, y];
50+
return { val: f(v), grad: g(v), hess: h(v) };
51+
}),
52+
)) as unknown as Func;
53+
};
54+
555
interface Parabola {
656
/** coefficient of square term */
757
a: number;
@@ -75,7 +125,13 @@ const bezier = (
75125
): [Vec2, Vec2, Vec2] => {
76126
const l1 = pointSlope(parabola, x1);
77127
const l2 = pointSlope(parabola, x2);
78-
const [x3, y3] = intersectPointSlope(l1, l2);
128+
let [x3, y3] = intersectPointSlope(l1, l2);
129+
if (!(Number.isFinite(x3) && Number.isFinite(y3))) {
130+
const [x1, y1] = l1.point;
131+
const [x2, y2] = l2.point;
132+
x3 = (x1 + x2) / 2;
133+
y3 = (y1 + y2) / 2;
134+
}
79135
return [l1.point, [x3, y3], l2.point];
80136
};
81137

@@ -168,15 +224,31 @@ const toWorld = ([x, y]: Vec2): Vec3 => {
168224
return matVecMul(world, [x, y, z]);
169225
};
170226

171-
let point: Vec2;
227+
let func: Func;
228+
let point: Vec2 = [0.5, 0.5];
172229
let info: Info;
173230

174231
const setPoint = (newPoint: Vec2) => {
175232
point = newPoint;
176-
info = all(...point);
233+
info = func(...point);
177234
};
178235

179-
setPoint([0.5, 0.5]);
236+
const textbox = document.getElementById("textbox") as HTMLInputElement;
237+
const setFunc = async () => {
238+
let root: Expr = { kind: "const", val: NaN };
239+
try {
240+
root = parse(textbox.value);
241+
textbox.classList.remove("error");
242+
} catch (e) {
243+
textbox.classList.add("error");
244+
}
245+
func = await autodiff(root);
246+
setPoint(point);
247+
};
248+
await setFunc();
249+
textbox.addEventListener("input", async () => {
250+
await setFunc();
251+
});
180252

181253
const roseColor = "#C33358";
182254

packages/site/src/math.ts

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import { Dual, Real, add, div, fn, mul, neg, opaque, sqrt, sub } from "rose";
2+
3+
export const acos = opaque([Real], Real, Math.acos);
4+
export const acosh = opaque([Real], Real, Math.acosh);
5+
export const asin = opaque([Real], Real, Math.asin);
6+
export const asinh = opaque([Real], Real, Math.asinh);
7+
export const atan = opaque([Real], Real, Math.atan);
8+
export const atanh = opaque([Real], Real, Math.atanh);
9+
export const cbrt = opaque([Real], Real, Math.cbrt);
10+
export const cos = opaque([Real], Real, Math.cos);
11+
export const cosh = opaque([Real], Real, Math.cosh);
12+
export const exp = opaque([Real], Real, Math.exp);
13+
export const expm1 = opaque([Real], Real, Math.expm1);
14+
export const log = opaque([Real], Real, Math.log);
15+
export const log10 = opaque([Real], Real, Math.log10);
16+
export const log1p = opaque([Real], Real, Math.log1p);
17+
export const log2 = opaque([Real], Real, Math.log2);
18+
export const pow = opaque([Real, Real], Real, Math.pow);
19+
export const sin = opaque([Real], Real, Math.sin);
20+
export const sinh = opaque([Real], Real, Math.sinh);
21+
export const tan = opaque([Real], Real, Math.tan);
22+
export const tanh = opaque([Real], Real, Math.tanh);
23+
24+
acos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
25+
const y = acos(x);
26+
const dy = div(dx, neg(sqrt(sub(1, mul(x, x)))));
27+
return { re: y, du: dy };
28+
});
29+
30+
acosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
31+
const y = acosh(x);
32+
const dy = div(dx, mul(sqrt(sub(x, 1)), sqrt(add(x, 1))));
33+
return { re: y, du: dy };
34+
});
35+
36+
asin.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }) => {
37+
const y = asin(x);
38+
const dy = div(dx, sqrt(sub(1, mul(x, x))));
39+
return { re: y, du: dy };
40+
});
41+
42+
asinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
43+
const y = asinh(x);
44+
const dy = div(dx, sqrt(add(1, mul(x, x))));
45+
return { re: y, du: dy };
46+
});
47+
48+
atan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
49+
const y = atan(x);
50+
const dy = div(dx, add(1, mul(x, x)));
51+
return { re: y, du: dy };
52+
});
53+
54+
atanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
55+
const y = atanh(x);
56+
const dy = div(dx, sub(1, mul(x, x)));
57+
return { re: y, du: dy };
58+
});
59+
60+
cbrt.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
61+
const y = cbrt(x);
62+
const dy = mul(dx, div(1 / 3, mul(y, y)));
63+
return { re: y, du: dy };
64+
});
65+
66+
cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
67+
const y = cos(x);
68+
const dy = mul(dx, neg(sin(x)));
69+
return { re: y, du: dy };
70+
});
71+
72+
cosh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
73+
const y = cosh(x);
74+
const dy = mul(dx, sinh(x));
75+
return { re: y, du: dy };
76+
});
77+
78+
exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
79+
const y = exp(x);
80+
const dy = mul(dx, y);
81+
return { re: y, du: dy };
82+
});
83+
84+
expm1.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
85+
const y = expm1(x);
86+
const dy = mul(dx, add(y, 1));
87+
return { re: y, du: dy };
88+
});
89+
90+
log.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
91+
const y = log(x);
92+
const dy = div(dx, x);
93+
return { re: y, du: dy };
94+
});
95+
96+
log10.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
97+
const y = log10(x);
98+
const dy = mul(dx, div(Math.LOG10E, x));
99+
return { re: y, du: dy };
100+
});
101+
102+
log1p.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
103+
const y = log1p(x);
104+
const dy = div(dx, add(1, x));
105+
return { re: y, du: dy };
106+
});
107+
108+
log2.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
109+
const y = log2(x);
110+
const dy = mul(dx, div(Math.LOG2E, x));
111+
return { re: y, du: dy };
112+
});
113+
114+
pow.jvp = fn([Dual, Dual], Dual, ({ re: x, du: dx }, { re: y, du: dy }) => {
115+
const z = pow(x, y);
116+
const dz = mul(add(mul(dx, div(y, x)), mul(dy, log(x))), z);
117+
return { re: z, du: dz };
118+
});
119+
120+
sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
121+
const y = sin(x);
122+
const dy = mul(dx, cos(x));
123+
return { re: y, du: dy };
124+
});
125+
126+
sinh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
127+
const y = sinh(x);
128+
const dy = mul(dx, cosh(x));
129+
return { re: y, du: dy };
130+
});
131+
132+
tan.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
133+
const y = tan(x);
134+
const dy = mul(dx, add(1, mul(y, y)));
135+
return { re: y, du: dy };
136+
});
137+
138+
tanh.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
139+
const y = tanh(x);
140+
const dy = mul(dx, sub(1, mul(y, y)));
141+
return { re: y, du: dy };
142+
});

packages/site/src/parse.test.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { add } from "rose";
2+
import { expect, test } from "vitest";
3+
import { Expr, parse } from "./parse.js";
4+
5+
test("add", () => {
6+
const expected: Expr = {
7+
kind: "binary",
8+
f: add,
9+
lhs: { kind: "var", idx: 0 },
10+
rhs: { kind: "var", idx: 1 },
11+
};
12+
expect(parse("x+y")).toEqual(expected);
13+
});

0 commit comments

Comments
 (0)