Skip to content

Commit 48eb14e

Browse files
authored
Allow compiled functions to share memory (#125)
1 parent 8e3b8fd commit 48eb14e

File tree

4 files changed

+84
-22
lines changed

4 files changed

+84
-22
lines changed

crates/wasm/src/lib.rs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
};
88
use wasm_encoder::{
99
BlockType, CodeSection, EntityType, ExportSection, Function, FunctionSection, ImportSection,
10-
Instruction, MemArg, MemorySection, MemoryType, Module, TypeSection, ValType,
10+
Instruction, MemArg, MemoryType, Module, TypeSection, ValType,
1111
};
1212

1313
/// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be.
@@ -976,18 +976,23 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
976976

977977
/// A WebAssembly module for a graph of functions.
978978
///
979-
/// The module exports its memory with name `"m"` and its entrypoint function with name `"f"`. The
980-
/// function takes one parameter in addition to its original parameters, which must be an
981-
/// 8-byte-aligned pointer to the start of the memory region it can use for allocation. The memory
982-
/// is the exact number of pages necessary to accommodate the function's own memory allocation as
983-
/// well as memory allocation for all of its parameters, with each node in each parameter's memory
984-
/// allocation tree being 8-byte aligned. That is, the function's last argument should be just large
985-
/// enough to accommodate those allocations for all the parameters; in that case, no memory will be
979+
/// The module exports its entrypoint function with name `"f"`. The function takes one parameter in
980+
/// addition to its original parameters, which must be an 8-byte-aligned pointer to the start of the
981+
/// memory region it can use for allocation.
982+
///
983+
/// Under module name `"m"`, the module imports a memory whose minimum number of pages is the exact
984+
/// number of pages necessary to accommodate the function's own memory allocation as well as memory
985+
/// allocation for all of its parameters, with each node in each parameter's memory allocation tree
986+
/// being 8-byte aligned. That is, the function's last argument should be just large enough to
987+
/// accommodate those allocations for all the parameters; in that case, no memory will be
986988
/// incorrectly overwritten and no out-of-bounds memory accesses will occur.
987989
pub struct Wasm<O> {
988990
/// The bytes of the WebAssembly module binary.
989991
pub bytes: Vec<u8>,
990992

993+
/// The minimum number of pages required by the imported memory.
994+
pub pages: u64,
995+
991996
/// All the opaque functions that the WebAssembly module must import, in order.
992997
///
993998
/// The module name for each import is the empty string, and the field name is the base-ten
@@ -1390,7 +1395,6 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
13901395
type_section.function(params.into_vec(), results.into_vec());
13911396
}
13921397

1393-
let mut memory_section = MemorySection::new();
13941398
let page_size = 65536; // https://webassembly.github.io/spec/core/exec/runtime.html#page-size
13951399
let cost = funcs.last().map_or(0, |((def, _), (_, def_types, _))| {
13961400
def.params
@@ -1400,12 +1404,16 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
14001404
.sum()
14011405
}) + costs.last().unwrap_or(&0);
14021406
let pages = ((cost + page_size - 1) / page_size).into(); // round up to a whole number of pages
1403-
memory_section.memory(MemoryType {
1404-
minimum: pages,
1405-
maximum: Some(pages),
1406-
memory64: false,
1407-
shared: false,
1408-
});
1407+
import_section.import(
1408+
"m",
1409+
"",
1410+
MemoryType {
1411+
minimum: pages,
1412+
maximum: None,
1413+
memory64: false,
1414+
shared: false,
1415+
},
1416+
);
14091417

14101418
let mut export_section = ExportSection::new();
14111419
export_section.export(
@@ -1419,11 +1427,11 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
14191427
module.section(&type_section);
14201428
module.section(&import_section);
14211429
module.section(&function_section);
1422-
module.section(&memory_section);
14231430
module.section(&export_section);
14241431
module.section(&code_section);
14251432
Wasm {
14261433
bytes: module.finish(),
1434+
pages,
14271435
imports,
14281436
}
14291437
}

crates/web/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,14 @@ impl Func {
330330

331331
/// Compile the call graph subtended by this function to WebAssembly.
332332
pub fn compile(&self) -> Wasm {
333-
let rose_wasm::Wasm { bytes, imports } = rose_wasm::compile(self.node());
333+
let rose_wasm::Wasm {
334+
bytes,
335+
pages,
336+
imports,
337+
} = rose_wasm::compile(self.node());
334338
Wasm {
335339
bytes: Some(bytes),
340+
pages,
336341
imports: Some(
337342
imports
338343
.into_keys()
@@ -488,6 +493,7 @@ impl Func {
488493
#[wasm_bindgen]
489494
pub struct Wasm {
490495
bytes: Option<Vec<u8>>,
496+
pub pages: u64,
491497
imports: Option<Vec<js_sys::Function>>,
492498
}
493499

packages/core/src/impl.ts

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -814,19 +814,39 @@ const getMeta = (
814814
} else return undefined;
815815
};
816816

817-
/** Concretize the abstract function `f` using the compiler. */
817+
interface CompileOptions {
818+
memory?: WebAssembly.Memory;
819+
}
820+
821+
/**
822+
* Concretize the abstract function `f` using the compiler.
823+
*
824+
* Creates a new memory if `opts.memory` is not provided, otherwise attempts to
825+
* grow the provided memory to be large enough.
826+
*/
818827
export const compile = async <const A extends readonly any[], const R>(
819828
f: Fn & ((...args: A) => R),
829+
opts?: CompileOptions,
820830
): Promise<(...args: JsArgs<A>) => ToJs<R>> => {
821831
const func = f[inner];
822832
const res = func.compile();
823833
const bytes = res.bytes()!;
834+
const pages = Number(res.pages);
824835
const imports = res.imports()!;
825836
res.free();
826-
const instance = await WebAssembly.instantiate(
827-
await WebAssembly.compile(bytes),
828-
{ "": Object.fromEntries(imports.map((g, i) => [i.toString(), g])) },
829-
);
837+
let memory = opts?.memory;
838+
if (memory === undefined) memory = new WebAssembly.Memory({ initial: pages });
839+
else {
840+
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
841+
const pageSize = 65536;
842+
const delta = pages - memory.buffer.byteLength / pageSize;
843+
if (delta > 0) memory.grow(delta);
844+
}
845+
const mod = await WebAssembly.compile(bytes);
846+
const instance = await WebAssembly.instantiate(mod, {
847+
m: { "": memory },
848+
"": Object.fromEntries(imports.map((g, i) => [i.toString(), g])),
849+
});
830850
const { f: g, m } = instance.exports;
831851
const metas: (Meta | undefined)[] = [];
832852
const n = func.numTypes();

packages/core/src/index.test.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,34 @@ describe("valid", () => {
636636
expect(g(2, 3)).toBeCloseTo(-0.7785390719815313);
637637
});
638638

639+
test("compile with shared memory", async () => {
640+
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
641+
const pageSize = 65536;
642+
643+
const memory = new WebAssembly.Memory({ initial: 0 });
644+
expect(memory.buffer.byteLength).toBe(0);
645+
646+
const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y));
647+
const fCompiled = await compile(f, { memory });
648+
expect(memory.buffer.byteLength).toBe(pageSize);
649+
expect(fCompiled([2, 3])).toBe(6);
650+
651+
const n = 10000;
652+
const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (a, b) =>
653+
vec(n, Real, (i) => mul(a[i], b[i])),
654+
);
655+
const gCompiled = await compile(g, { memory });
656+
expect(memory.buffer.byteLength).toBeGreaterThan(pageSize);
657+
const a = [];
658+
const b = [];
659+
for (let i = 1; i <= n; ++i) {
660+
a.push(i);
661+
b.push(1 / i);
662+
}
663+
const c = gCompiled(a, b);
664+
for (let i = 0; i < n; ++i) expect(c[i]).toBeCloseTo(1);
665+
});
666+
639667
test("compile opaque function", async () => {
640668
const f = opaque([Real], Real, Math.sin);
641669
const g = await compile(f);

0 commit comments

Comments
 (0)