Skip to content

Commit 0eff92a

Browse files
authored
Fix compile when memory grows afterward (#126)
1 parent 5c2c0b0 commit 0eff92a

File tree

2 files changed

+54
-40
lines changed

2 files changed

+54
-40
lines changed

packages/core/src/impl.ts

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -636,13 +636,13 @@ interface Layout {
636636
const aligned = ({ size, align }: Layout): number =>
637637
(size + align - 1) & ~(align - 1);
638638

639-
/** An aligned `ArrayBuffer` view, or `undefined` for zero-sized types. */
640-
type View = undefined | Uint8Array | Uint16Array | Uint32Array | Float64Array;
639+
/** An aligned `ArrayBuffer` view. */
640+
type View = Uint8Array | Uint16Array | Uint32Array | Float64Array;
641641

642642
const getView = (buffer: ArrayBuffer, layout: Layout, offset: number): View => {
643643
// this code assumes that the layout is uniquely determined by its `size`
644644
const { size } = layout;
645-
if (size === 0) return undefined;
645+
if (size === 0) throw Error("zero-sized type");
646646
else if (size === 1) return new Uint8Array(buffer, offset);
647647
else if (size === 2) return new Uint16Array(buffer, offset);
648648
else if (size === 4) return new Uint32Array(buffer, offset);
@@ -660,13 +660,13 @@ interface Meta {
660660
*
661661
* The given byte offset is only used for pointer types.
662662
*/
663-
encode: (x: unknown, offset: number) => number;
663+
encode: (x: unknown, pointer: number, buffer: ArrayBuffer) => number;
664664

665665
/** Total memory cost of an object of this type, including sub-allocations. */
666666
cost: number;
667667

668668
/** Return a JS value represented by the Wasm value `x`. */
669-
decode: (x: number) => unknown;
669+
decode: (x: number, buffer: ArrayBuffer) => unknown;
670670
}
671671

672672
/**
@@ -680,7 +680,6 @@ interface Meta {
680680
*/
681681
const getMeta = (
682682
f: Fn,
683-
buffer: ArrayBuffer,
684683
metas: (Meta | undefined)[],
685684
t: number,
686685
): Meta | undefined => {
@@ -724,41 +723,48 @@ const getMeta = (
724723
const n = func.size(func.index(t));
725724
const elem = aligned(layout);
726725
const total = aligned({ size: n * elem, align: 8 });
727-
const view = getView(buffer, layout, 0);
728726
return {
729727
layout: { size: 4, align: 4 },
730728
encode:
731-
view === undefined
732-
? (x, offset) => offset
733-
: (x, offset) => {
734-
let child = offset + total;
729+
layout.size === 0
730+
? (x, pointer) => pointer
731+
: (x, pointer, buffer) => {
732+
const view = getView(buffer, layout, 0);
733+
let child = pointer + total;
735734
for (let i = 0; i < n; ++i) {
736-
view[offset / elem + i] = encode((x as unknown[])[i], child);
735+
view[pointer / elem + i] = encode(
736+
(x as unknown[])[i],
737+
child,
738+
buffer,
739+
);
737740
child += cost;
738741
}
739-
return offset;
742+
return pointer;
740743
},
741744
cost: total + n * cost,
742745
decode:
743-
view === undefined
744-
? () => {
746+
layout.size === 0
747+
? (x, buffer) => {
745748
const arr: unknown[] = [];
746749
// this code assumes that all values of all zero-sized types can
747750
// be represented by zero
748-
for (let i = 0; i < n; ++i) arr.push(decode(0));
751+
for (let i = 0; i < n; ++i) arr.push(decode(0, buffer));
749752
return arr;
750753
}
751-
: (x) => {
754+
: (x, buffer) => {
755+
const view = getView(buffer, layout, 0);
752756
const arr: unknown[] = [];
753-
for (let i = 0; i < n; ++i) arr.push(decode(view[x / elem + i]));
757+
for (let i = 0; i < n; ++i)
758+
arr.push(decode(view[x / elem + i], buffer));
754759
return arr;
755760
},
756761
};
757762
} else if (func.isStruct(t)) {
758763
const keys = func.keys(t);
759764
const members = func.mems(t);
760765
const n = keys.length;
761-
const mems: { key: string; meta: Meta; view?: View; child?: number }[] = [];
766+
const mems: { key: string; meta: Meta; offset?: number; child?: number }[] =
767+
[];
762768
for (let i = 0; i < n; ++i) {
763769
const meta = metas[members[i]];
764770
if (meta === undefined) return undefined;
@@ -774,38 +780,41 @@ const getMeta = (
774780
const { layout } = meta;
775781
const { size, align } = layout;
776782
offset = aligned({ size: offset, align });
777-
mem.view = getView(buffer, layout, offset);
783+
mem.offset = offset;
778784
offset += size;
779785
}
780786
const total = aligned({ size: offset, align: 8 });
781787
return {
782788
layout: { size: 4, align: 4 },
783-
encode: (x, offset) => {
784-
for (const { key, meta, view, child } of mems) {
789+
encode: (x, pointer, buffer) => {
790+
for (const { key, meta, offset, child } of mems) {
785791
// instead of mutating each element of `mems` above to add more data
786792
// and then still having an `if` statement in here, it would be nicer
787793
// to just map over `mems` above to produce an array of closures that
788794
// can be called directly, with the condition on `view === undefined`
789795
// being handled once rather than in every call to `encode` here
790-
if (view !== undefined) {
791-
view[offset / aligned(meta.layout)] = meta.encode(
796+
if (meta.layout.size > 0) {
797+
const view = getView(buffer, meta.layout, offset!);
798+
view[pointer / aligned(meta.layout)] = meta.encode(
792799
(x as any)[key],
793-
offset + total + child!,
800+
pointer + total + child!,
801+
buffer,
794802
);
795803
}
796804
}
797-
return offset;
805+
return pointer;
798806
},
799807
cost: total + cost,
800-
decode: (x) => {
808+
decode: (x, buffer) => {
801809
const obj: any = {};
802-
for (const { key, meta, view } of mems) {
803-
if (view === undefined) {
810+
for (const { key, meta, offset } of mems) {
811+
if (meta.layout.size === 0) {
804812
// this code assumes that all values of all zero-sized types can be
805813
// represented by zero
806-
obj[key] = meta.decode(0);
814+
obj[key] = meta.decode(0, buffer);
807815
} else {
808-
obj[key] = meta.decode(view[x / aligned(meta.layout)]);
816+
const view = getView(buffer, meta.layout, offset!);
817+
obj[key] = meta.decode(view[x / aligned(meta.layout)], buffer);
809818
}
810819
}
811820
return obj;
@@ -834,9 +843,11 @@ export const compile = async <const A extends readonly any[], const R>(
834843
const pages = Number(res.pages);
835844
const imports = res.imports()!;
836845
res.free();
837-
let memory = opts?.memory;
838-
if (memory === undefined) memory = new WebAssembly.Memory({ initial: pages });
846+
let memory: WebAssembly.Memory;
847+
const given = opts?.memory;
848+
if (given === undefined) memory = new WebAssembly.Memory({ initial: pages });
839849
else {
850+
memory = given;
840851
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
841852
const pageSize = 65536;
842853
const delta = pages - memory.buffer.byteLength / pageSize;
@@ -847,11 +858,10 @@ export const compile = async <const A extends readonly any[], const R>(
847858
m: { "": memory },
848859
"": Object.fromEntries(imports.map((g, i) => [i.toString(), g])),
849860
});
850-
const { f: g, m } = instance.exports;
861+
const { f: g } = instance.exports;
851862
const metas: (Meta | undefined)[] = [];
852863
const n = func.numTypes();
853-
for (let t = 0; t < n; ++t)
854-
metas.push(getMeta(f, (m as WebAssembly.Memory).buffer, metas, t));
864+
for (let t = 0; t < n; ++t) metas.push(getMeta(f, metas, t));
855865
let total = 0;
856866
const params = Array.from(func.paramTypes()).map((t) => {
857867
const { encode, cost } = metas[t]!;
@@ -861,8 +871,10 @@ export const compile = async <const A extends readonly any[], const R>(
861871
});
862872
const { decode } = metas[func.retType()]!;
863873
return (...args): any => {
864-
const vals = params.map(({ encode, offset }, i) => encode(args[i], offset));
865-
return decode((g as any)(...vals, total));
874+
const vals = params.map(({ encode, offset }, i) =>
875+
encode(args[i], offset, memory.buffer),
876+
);
877+
return decode((g as any)(...vals, total), memory.buffer);
866878
};
867879
};
868880

packages/core/src/index.test.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,17 +643,19 @@ describe("valid", () => {
643643
const memory = new WebAssembly.Memory({ initial: 0 });
644644
expect(memory.buffer.byteLength).toBe(0);
645645

646-
const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y));
646+
const f = fn([Vec(2, Real)], Vec(2, Real), ([x, y]) => [y, x]);
647647
const fCompiled = await compile(f, { memory });
648648
expect(memory.buffer.byteLength).toBe(pageSize);
649-
expect(fCompiled([2, 3])).toBe(6);
650649

651650
const n = 10000;
652651
const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (a, b) =>
653652
vec(n, Real, (i) => mul(a[i], b[i])),
654653
);
655654
const gCompiled = await compile(g, { memory });
656655
expect(memory.buffer.byteLength).toBeGreaterThan(pageSize);
656+
657+
expect(fCompiled([2, 3])).toEqual([3, 2]);
658+
657659
const a = [];
658660
const b = [];
659661
for (let i = 1; i <= n; ++i) {

0 commit comments

Comments
 (0)