diff --git a/package-lock.json b/package-lock.json index c13b052..a47e8da 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,7 +12,8 @@ "@sourceacademy/conductor": "^0.2.3", "@types/estree": "^1.0.0", "fast-levenshtein": "^3.0.0", - "mathjs": "^14.4.0" + "mathjs": "^14.4.0", + "wabt": "^1.0.37" }, "devDependencies": { "@rollup/plugin-commonjs": "^28.0.3", @@ -38,6 +39,11 @@ "typescript": "^5.5.3" } }, + "../wasm-util": { + "version": "1.0.0", + "extraneous": true, + "license": "ISC" + }, "node_modules/@ampproject/remapping": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", @@ -5813,6 +5819,23 @@ "node": ">=10.12.0" } }, + "node_modules/wabt": { + "version": "1.0.37", + "resolved": "https://registry.npmjs.org/wabt/-/wabt-1.0.37.tgz", + "integrity": "sha512-2B/TH4ppwtlkUosLtuIimKsTVnqM8aoXxYHnu/WOxiSqa+CGoZXmG+pQyfDQjEKIAc7GqFlJsuCKuK8rIPL1sg==", + "license": "Apache-2.0", + "bin": { + "wasm-decompile": "bin/wasm-decompile", + "wasm-interp": "bin/wasm-interp", + "wasm-objdump": "bin/wasm-objdump", + "wasm-stats": "bin/wasm-stats", + "wasm-strip": "bin/wasm-strip", + "wasm-validate": "bin/wasm-validate", + "wasm2c": "bin/wasm2c", + "wasm2wat": "bin/wasm2wat", + "wat2wasm": "bin/wat2wasm" + } + }, "node_modules/walker": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/walker/-/walker-1.0.8.tgz", diff --git a/package.json b/package.json index 0c3eb6a..3d5c2e1 100644 --- a/package.json +++ b/package.json @@ -10,7 +10,8 @@ "build": "rollup -c --bundleConfigAsCjs", "start": "npm run build && node dist/index.js", "jsdoc": "./scripts/jsdoc.sh", - "test": "jest" + "test": "jest", + "wasm": "ts-node src/wasm-compiler/index.ts" }, "keywords": [ "Python", @@ -49,6 +50,7 @@ "@sourceacademy/conductor": "^0.2.3", "@types/estree": "^1.0.0", "fast-levenshtein": "^3.0.0", - "mathjs": "^14.4.0" + "mathjs": "^14.4.0", + "wabt": "^1.0.37" } } diff --git a/src/conductor/PyWasmEvaluator.ts 
b/src/conductor/PyWasmEvaluator.ts new file mode 100644 index 0000000..b1e3414 --- /dev/null +++ b/src/conductor/PyWasmEvaluator.ts @@ -0,0 +1,23 @@ +// This file is adapted from: +// https://github.com/source-academy/conductor +// Original author(s): Source Academy Team + +import { BasicEvaluator, IRunnerPlugin } from "@sourceacademy/conductor/runner"; +import { compileToWasmAndRun } from "../wasm-compiler/compile"; + +export default class PyEvaluator extends BasicEvaluator { + constructor(conductor: IRunnerPlugin) { + super(conductor); + } + + async evaluateChunk(chunk: string): Promise { + try { + const result = await compileToWasmAndRun(chunk); + this.conductor.sendOutput(result); + } catch (error) { + this.conductor.sendOutput( + `Error: ${error instanceof Error ? error.message : error}` + ); + } + } +} diff --git a/src/tests/wasm-compiler.spec.ts b/src/tests/wasm-compiler.spec.ts new file mode 100644 index 0000000..de61971 --- /dev/null +++ b/src/tests/wasm-compiler.spec.ts @@ -0,0 +1,153 @@ +import { compileToWasmAndRun } from "../wasm-compiler"; + +describe("Environment tests", () => { + it("captures outer variable by reference (mutation after definition visible)", async () => { + const pythonCode = ` +def outer(): + x = 1 + def inner(): + return x + x = 2 + return inner() +outer() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(2)]); + }); + + it("inner variable shadows outer variable", async () => { + const pythonCode = ` +def outer(): + x = 5 + def inner(): + x = 7 + return x + return inner() +outer() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(7)]); + }); + + it("inner function reads variable from outer scope", async () => { + const pythonCode = ` +def outer(): + x = 5 + def inner(): + return x + return inner() +outer() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(5)]); + }); + + it("reassignment in outer 
scope after defining inner is visible to inner", async () => { + const pythonCode = ` +def outer(): + x = 1 + def inner(): + return x + x = 9 + return inner() +outer() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(9)]); + }); + + it("nested closure can access variable from grandparent scope", async () => { + const pythonCode = ` +def grandparent(): + a = 10 + def parent(): + b = 5 + def child(): + return a + b + return child + return parent + +f = grandparent() +g = f() +g() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(15)]); + }); + + it("each call to outer creates a new environment", async () => { + const pythonCode = ` +def make_number(n): + def get(): + return n + return get + +a = make_number(3) +b = make_number(10) +a() + b() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(13)]); + }); + + it("function returned from outer retains access to outer variable", async () => { + const pythonCode = ` +def outer(): + x = 7 + def inner(): + return x + return inner + +f = outer() +f() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(7)]); + }); + + it("returned function reflects reassignment in outer before return", async () => { + const pythonCode = ` +def outer(): + x = 3 + def inner(): + return x + x = 8 + return inner + +f = outer() +f() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(8)]); + }); + + it("different closures capture independent variables", async () => { + const pythonCode = ` +def make_adder(n): + def add(x): + return n + x + return add + +add1 = make_adder(1) +add2 = make_adder(5) +add1(3) + add2(3) +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(12)]); + }); + + it("reusing same closure multiple times uses same environment", async () => { + const pythonCode = ` 
+def outer(): + x = 4 + def inner(): + return x + return inner + +f = outer() +f() + f() +`; + const result = await compileToWasmAndRun(pythonCode); + expect(result).toEqual([0, BigInt(8)]); + }); +}); diff --git a/src/wasm-compiler/.prettierrc b/src/wasm-compiler/.prettierrc new file mode 100644 index 0000000..9df4afa --- /dev/null +++ b/src/wasm-compiler/.prettierrc @@ -0,0 +1,11 @@ +{ + "printWidth": 80, + "overrides": [ + { + "files": "constants.ts", + "options": { + "printWidth": 120 + } + } + ] +} diff --git a/src/wasm-compiler/builderGenerator.ts b/src/wasm-compiler/builderGenerator.ts new file mode 100644 index 0000000..d4d4952 --- /dev/null +++ b/src/wasm-compiler/builderGenerator.ts @@ -0,0 +1,620 @@ +import { ExprNS, StmtNS } from "../ast-types"; +import { TokenType } from "../tokens"; +import { + ALLOC_ENV_FX, + APPLY_FX_NAME, + applyFuncFactory, + ARITHMETIC_OP_FX, + ARITHMETIC_OP_TAG, + BOOL_NOT_FX, + BOOLISE_FX, + COMPARISON_OP_FX, + COMPARISON_OP_TAG, + CURR_ENV, + GET_LEX_ADDR_FX, + GET_PAIR_HEAD_FX, + GET_PAIR_TAIL_FX, + HEAP_PTR, + importedLogs, + LOG_FX, + MAKE_BOOL_FX, + MAKE_CLOSURE_FX, + MAKE_COMPLEX_FX, + MAKE_FLOAT_FX, + MAKE_INT_FX, + MAKE_NONE_FX, + MAKE_PAIR_FX, + MAKE_STRING_FX, + nativeFunctions, + NEG_FX, + PRE_APPLY_FX, + SET_LEX_ADDR_FX, + SET_PAIR_HEAD_FX, + SET_PAIR_TAIL_FX, + SET_PARAM_FX, + TYPE_TAG, +} from "./constants"; +import { f64, global, i32, i64, local, mut, wasm } from "./wasm-util/builder"; +import { WasmInstruction, WasmNumeric, WasmRaw } from "./wasm-util/types"; + +const builtInFunctions: { + name: string; + arity: number; + body: WasmInstruction | WasmInstruction[]; + isVoid: boolean; +}[] = [ + { + name: "print", + arity: 1, + body: wasm + .call(LOG_FX) + .args(wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0))), + isVoid: true, + }, + { + name: "pair", + arity: 2, + body: wasm + .call(MAKE_PAIR_FX) + .args( + wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0)), + 
wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(1)) + ), + isVoid: false, + }, + { + name: "head", + arity: 1, + body: wasm + .call(GET_PAIR_HEAD_FX) + .args(wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0))), + isVoid: false, + }, + { + name: "tail", + arity: 1, + body: wasm + .call(GET_PAIR_TAIL_FX) + .args(wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0))), + isVoid: false, + }, + { + name: "set_head", + arity: 2, + body: wasm + .call(SET_PAIR_HEAD_FX) + .args( + wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0)), + wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(1)) + ), + isVoid: true, + }, + { + name: "set_tail", + arity: 2, + body: wasm + .call(SET_PAIR_TAIL_FX) + .args( + wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0)), + wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(1)) + ), + isVoid: true, + }, + { + name: "bool", + arity: 1, + body: [ + i32.const(TYPE_TAG.BOOL), + wasm + .call(BOOLISE_FX) + .args(wasm.call(GET_LEX_ADDR_FX).args(i32.const(0), i32.const(0))), + ], + isVoid: false, + }, +]; + +type Binding = { name: string; tag: "local" | "nonlocal" }; + +interface BuilderVisitor extends StmtNS.Visitor, ExprNS.Visitor { + visit(stmt: StmtNS.Stmt): S; + visit(stmt: ExprNS.Expr): E; + visit(stmt: StmtNS.Stmt | ExprNS.Expr): S | E; +} + +export class BuilderGenerator + implements BuilderVisitor +{ + private strings: [string, number][] = []; + private heapPointer = 0; + + private environment: Binding[][] = [[]]; + private userFunctions: WasmInstruction[][] = []; + + private getLexAddress(name: string): [number, number] { + for (let i = this.environment.length - 1; i >= 0; i--) { + const curr = this.environment[i]; + const index = curr.findIndex((b) => b.name === name); + + if (index !== -1) { + // check if variable is used before nonlocal declaration + if (curr[index].tag === "nonlocal") { + throw new Error( + `Name ${curr[index].name} is used prior to nonlocal declaration` + ); + } + + return 
[this.environment.length - 1 - i, index]; + } + } + throw new Error(`Name ${name} not defined!`); + } + + private collectDeclarations( + statements: StmtNS.Stmt[], + parameters?: StmtNS.FunctionDef["parameters"] + ): Binding[] { + const bindings: Binding[] = statements + .filter( + (s) => s instanceof StmtNS.Assign || s instanceof StmtNS.FunctionDef + ) + .map((s) => ({ name: s.name.lexeme, tag: "local" })); + + statements + .filter((s) => s instanceof StmtNS.NonLocal) + .map((s) => s.name.lexeme) + .forEach((name) => { + // nonlocal declaration must exist in a nonlocal scope + if ( + !this.environment.find( + (frame, i) => + i !== 0 && frame.find((binding) => binding.name === name) + ) + ) + throw new Error(`No binding for nonlocal ${name} found!`); + + // cannot declare parameter name as nonlocal + if (parameters && parameters.map((p) => p.lexeme).includes(name)) { + throw new Error(`${name} is parameter and nonlocal`); + } + + for (let i = 0; i < bindings.length; i++) { + const binding = bindings[i]; + if (binding.name === name) { + // tag this binding as nonlocal so + // if it's accessed before its nonlocal statement, + // throw error + bindings[i].tag = "nonlocal"; + } + } + }); + + return [ + ...(parameters?.map((p) => ({ name: p.lexeme, tag: "local" as const })) ?? 
+ []), + ...bindings, + ]; + } + + visit(stmt: StmtNS.Stmt): WasmInstruction; + visit(stmt: ExprNS.Expr): WasmNumeric; + visit(stmt: StmtNS.Stmt | ExprNS.Expr): WasmInstruction | WasmNumeric { + return stmt.accept(this); + } + + visitFileInputStmt(stmt: StmtNS.FileInput): WasmInstruction { + if (stmt.statements.length <= 0) { + console.log("No statements found"); + throw new Error("No statements found"); + } + + // declare built-in functions in the global environment before user code + const builtInFuncsDeclarations = builtInFunctions.map( + ({ name, arity, body, isVoid }, i) => { + this.environment[0].push({ name, tag: "local" }); + const tag = this.userFunctions.length; + const newBody = [ + ...(Array.isArray(body) ? body : [body]), + wasm.return( + ...(isVoid ? [wasm.call(MAKE_NONE_FX)] : []), + global.set(CURR_ENV, local.get("$return_env")) + ), + ]; + this.userFunctions.push(newBody); + + return wasm + .call(SET_LEX_ADDR_FX) + .args( + i32.const(0), + i32.const(i), + wasm + .call(MAKE_CLOSURE_FX) + .args( + i32.const(tag), + i32.const(arity), + i32.const(arity), + global.get(CURR_ENV) + ) + ); + } + ); + + this.environment[0].push(...this.collectDeclarations(stmt.statements)); + + const body = stmt.statements.map((s) => this.visit(s)); + + // this matches the format of drop in visitSimpleExpr + const lastInstr = body.at(-1); + const undroppedInstr = + lastInstr?.op === "drop" && + lastInstr.value?.op === "drop" && + lastInstr.value.value; + + // collect all strings, native functions used and user functions + const strings = this.strings.map(([str, add]) => + wasm.data(i32.const(add), str) + ); + + const applyFunction = applyFuncFactory(this.userFunctions); + + // because each variable has a tag and payload = 3 words + const globalEnvLength = this.environment[0].length; + + return wasm + .module() + .imports(wasm.import("js", "memory").memory(1), ...importedLogs) + .globals( + wasm.global(HEAP_PTR, mut.i32).init(i32.const(this.heapPointer)), + 
wasm.global(CURR_ENV, mut.i32).init(i32.const(0)) + ) + .datas(...strings) + .funcs( + ...nativeFunctions, + applyFunction, + + wasm + .func("$main") + .results(...(undroppedInstr ? [i32, i64] : [])) + .body( + global.set( + CURR_ENV, + wasm + .call(ALLOC_ENV_FX) + .args(i32.const(globalEnvLength), i32.const(0), i32.const(0)) + ), + + ...builtInFuncsDeclarations, + + ...(undroppedInstr ? [...body.slice(0, -1), undroppedInstr] : body) + ) + ) + .exports(wasm.export("main").func("$main")) + .build(); + } + + visitSimpleExprStmt(stmt: StmtNS.SimpleExpr): WasmInstruction { + const expr = this.visit(stmt.expression); + return wasm.drop(wasm.drop(expr)); + } + + visitGroupingExpr(expr: ExprNS.Grouping): WasmNumeric { + return this.visit(expr.expression); + } + + visitBinaryExpr(expr: ExprNS.Binary): WasmNumeric { + const left = this.visit(expr.left); + const right = this.visit(expr.right); + + const type = expr.operator.type; + let opTag: number; + if (type === TokenType.PLUS) opTag = ARITHMETIC_OP_TAG.ADD; + else if (type === TokenType.MINUS) opTag = ARITHMETIC_OP_TAG.SUB; + else if (type === TokenType.STAR) opTag = ARITHMETIC_OP_TAG.MUL; + else if (type === TokenType.SLASH) opTag = ARITHMETIC_OP_TAG.DIV; + else throw new Error(`Unsupported binary operator: ${type}`); + + return wasm.call(ARITHMETIC_OP_FX).args(left, right, i32.const(opTag)); + } + + visitCompareExpr(expr: ExprNS.Compare): WasmNumeric { + const left = this.visit(expr.left); + const right = this.visit(expr.right); + + const type = expr.operator.type; + let opTag: number; + if (type === TokenType.DOUBLEEQUAL) opTag = COMPARISON_OP_TAG.EQ; + else if (type === TokenType.NOTEQUAL) opTag = COMPARISON_OP_TAG.NEQ; + else if (type === TokenType.LESS) opTag = COMPARISON_OP_TAG.LT; + else if (type === TokenType.LESSEQUAL) opTag = COMPARISON_OP_TAG.LTE; + else if (type === TokenType.GREATER) opTag = COMPARISON_OP_TAG.GT; + else if (type === TokenType.GREATEREQUAL) opTag = COMPARISON_OP_TAG.GTE; + else throw new 
Error(`Unsupported comparison operator: ${type}`); + + return wasm.call(COMPARISON_OP_FX).args(left, right, i32.const(opTag)); + } + + visitUnaryExpr(expr: ExprNS.Unary): WasmNumeric { + const right = this.visit(expr.right); + + const type = expr.operator.type; + if (type === TokenType.MINUS) return wasm.call(NEG_FX).args(right); + else if (type === TokenType.NOT) return wasm.call(BOOL_NOT_FX).args(right); + else throw new Error(`Unsupported unary operator: ${type}`); + } + + visitBoolOpExpr(expr: ExprNS.BoolOp): WasmNumeric { + const left = this.visit(expr.left); + const right = this.visit(expr.right); + + const type = expr.operator.type; + + // not a wasm function as it needs to short-circuit + if (type === TokenType.AND) { + // if x is false, then x else y + return wasm + .if(i64.eqz(wasm.call(BOOLISE_FX).args(left))) + .results(i32, i64) + .then(left) + .else(right) as unknown as WasmNumeric; // these WILL return WasmNumeric + } else if (type === TokenType.OR) { + // if x is false, then y else x + return wasm + .if(i64.eqz(wasm.call(BOOLISE_FX).args(left))) + .results(i32, i64) + .then(right) + .else(left) as unknown as WasmNumeric; + } else throw new Error(`Unsupported boolean binary operator: ${type}`); + } + + visitTernaryExpr(expr: ExprNS.Ternary): WasmNumeric { + const consequent = this.visit(expr.consequent); + const alternative = this.visit(expr.alternative); + + const predicate = this.visit(expr.predicate); + + return wasm + .if(i32.wrap_i64(wasm.call(BOOLISE_FX).args(predicate))) + .results(i32, i64) + .then(consequent) + .else(alternative) as unknown as WasmNumeric; + } + + visitNoneExpr(expr: ExprNS.None): WasmNumeric { + return wasm.call(MAKE_NONE_FX); + } + + visitBigIntLiteralExpr(expr: ExprNS.BigIntLiteral): WasmNumeric { + const value = BigInt(expr.value); + const min = BigInt("-9223372036854775808"); // -(2^63) + const max = BigInt("9223372036854775807"); // (2^63) - 1 + if (value < min || value > max) { + throw new Error(`BigInt literal out of 
bounds: ${expr.value}`); + } + + return wasm.call(MAKE_INT_FX).args(i64.const(value)); + } + + visitLiteralExpr(expr: ExprNS.Literal): WasmNumeric { + if (typeof expr.value === "number") + return wasm.call(MAKE_FLOAT_FX).args(f64.const(expr.value)); + else if (typeof expr.value === "boolean") + return wasm.call(MAKE_BOOL_FX).args(i32.const(expr.value ? 1 : 0)); + else if (typeof expr.value === "string") { + const str = expr.value; + const len = str.length; // NOTE(review): counts UTF-16 code units, not bytes -- confirm wasm.data emits one byte per unit for non-ASCII text + const toReturn = wasm + .call(MAKE_STRING_FX) + .args(i32.const(this.heapPointer), i32.const(len)); + + this.strings.push([str, this.heapPointer]); + this.heapPointer += len; + return toReturn; + } else { + throw new Error(`Unsupported literal type: ${typeof expr.value}`); + } + } + + visitComplexExpr(expr: ExprNS.Complex): WasmNumeric { + return wasm + .call(MAKE_COMPLEX_FX) + .args(f64.const(expr.value.real), f64.const(expr.value.imag)); + } + + visitAssignStmt(stmt: StmtNS.Assign): WasmInstruction { + const [depth, index] = this.getLexAddress(stmt.name.lexeme); + const expression = this.visit(stmt.value); + + return wasm + .call(SET_LEX_ADDR_FX) + .args(i32.const(depth), i32.const(index), expression); + } + + visitVariableExpr(expr: ExprNS.Variable): WasmNumeric { + const [depth, index] = this.getLexAddress(expr.name.lexeme); + return wasm.call(GET_LEX_ADDR_FX).args(i32.const(depth), i32.const(index)); + } + + visitFunctionDefStmt(stmt: StmtNS.FunctionDef): WasmInstruction { + const [depth, index] = this.getLexAddress(stmt.name.lexeme); + const arity = stmt.parameters.length; + const tag = this.userFunctions.length; + this.userFunctions.push([]); // placeholder + + const newFrame = this.collectDeclarations(stmt.body, stmt.parameters); + + if (tag >= 1 << 16) + throw new Error("Tag cannot be above 16-bit integer limit"); + if (arity >= 1 << 8) + throw new Error("Arity cannot be above 8-bit integer limit"); + if (newFrame.length >= 1 << 8) // envSize occupies only 8 bits of the closure payload, so 256 must be rejected as well + throw new Error("Environment length cannot be above 8-bit integer
limit"); + + this.environment.push(newFrame); + const body = stmt.body.map((s) => this.visit(s)); + this.environment.pop(); + + this.userFunctions[tag] = body; + + return wasm + .call(SET_LEX_ADDR_FX) + .args( + i32.const(depth), + i32.const(index), + wasm + .call(MAKE_CLOSURE_FX) + .args( + i32.const(tag), + i32.const(arity), + i32.const(newFrame.length), + global.get(CURR_ENV) + ) + ); + } + + visitLambdaExpr(expr: ExprNS.Lambda): WasmNumeric { + const arity = expr.parameters.length; + const tag = this.userFunctions.length; + this.userFunctions.push([]); // placeholder + + // no statements allowed in lambdas, so there won't be any new local declarations + // other than parameters + const newFrame = this.collectDeclarations([], expr.parameters); + + if (tag >= 1 << 16) + throw new Error("Tag cannot be above 16-bit integer limit"); + if (arity >= 1 << 8) + throw new Error("Arity cannot be above 8-bit integer limit"); + if (newFrame.length >= 1 << 8) // envSize occupies only 8 bits of the closure payload + throw new Error("Environment length cannot be above 8-bit integer limit"); + + this.environment.push(newFrame); + const body = this.visit(expr.body); + this.environment.pop(); + // hand CURR_ENV back to the caller on exit, the same way visitReturnStmt does for def bodies + this.userFunctions[tag] = [wasm.return(body, global.set(CURR_ENV, local.get("$return_env")))]; + + return wasm + .call(MAKE_CLOSURE_FX) + .args( + i32.const(tag), + i32.const(arity), + i32.const(newFrame.length), + global.get(CURR_ENV) + ); + } + + visitCallExpr(expr: ExprNS.Call): WasmRaw { + const callee = this.visit(expr.callee); + const args = expr.args.map((arg) => this.visit(arg)); + + // PRE_APPLY returns (1, 2) callee tag and value, (3) pointer to new environment + // APPLY expects (1) pointer to return environment, (2, 3) callee tag and value + + // we call PRE_APPLY first, which verifies the callee is a closure and arity matches + // AND creates a new environment for the function call, but does not set CURR_ENV yet + // this is so that we can set the arguments in the new environment first + + // this means we can't use SET_LEX_ADDR_FX because it uses CURR_ENV internally + // so
we manually set the arguments in the new environment using SET_PARAM_FX + + // the SET_PARAM function returns the env address after setting the parameter + // so we can chain the calls together + return wasm.raw` +${global.get(CURR_ENV)} +${wasm.call(PRE_APPLY_FX).args(callee, i32.const(args.length))} + +${args.map( + (arg, i) => + wasm.raw` +(i32.const ${i * 12}) (i32.add) ${arg} (call ${SET_PARAM_FX.name})` +)} + +(global.set ${CURR_ENV}) +(call ${APPLY_FX_NAME}) +`; + } + + visitReturnStmt(stmt: StmtNS.Return): WasmInstruction { + const value = stmt.value; + + return wasm.return( + value ? this.visit(value) : wasm.call(MAKE_NONE_FX), + global.set(CURR_ENV, local.get("$return_env")) + ); + } + + visitNonLocalStmt(stmt: StmtNS.NonLocal): WasmInstruction { + // because of this.collectDeclarations, this nonlocal declaration + // is guaranteed to have a nonlocal (and not global) binding. + // because of this.getLexAddress, it's also guaranteed to not have been + // used illegally before this statement. + // all that's left to do is remove the binding from the compile time environment + // from here onwards (from the local frame). + // if it doesn't exist in the local frame, do nothing as the statement has + // no effect + + const currFrame = this.environment.at(-1); + if (!currFrame) return wasm.nop(); + + // collectDeclarations adds one binding per assignment, so the name can occur + // several times; walk backwards so splicing never skips or hits index -1 + for (let i = currFrame.length - 1; i >= 0; i--) { + if (currFrame[i].name === stmt.name.lexeme) currFrame.splice(i, 1); + } + + return wasm.nop(); + } + + visitIfStmt(stmt: StmtNS.If): WasmInstruction { + const condition = this.visit(stmt.condition); + const body = stmt.body.map((b) => this.visit(b)); + const elseBody = stmt.elseBlock?.map((e) => this.visit(e)); + + return elseBody + ?
wasm + .if(i32.wrap_i64(wasm.call(BOOLISE_FX).args(condition))) + .then(...body) + .else(...elseBody) + : wasm + .if(i32.wrap_i64(wasm.call(BOOLISE_FX).args(condition))) + .then(...body); + } + + visitPassStmt(stmt: StmtNS.Pass): WasmInstruction { + return wasm.nop(); + } + + // UNIMPLEMENTED PYTHON CONSTRUCTS + visitMultiLambdaExpr(expr: ExprNS.MultiLambda): WasmNumeric { + throw new Error("Method not implemented."); + } + visitIndentCreation(stmt: StmtNS.Indent): WasmInstruction { + throw new Error("Method not implemented."); + } + visitDedentCreation(stmt: StmtNS.Dedent): WasmInstruction { + throw new Error("Method not implemented."); + } + visitAnnAssignStmt(stmt: StmtNS.AnnAssign): WasmInstruction { + throw new Error("Method not implemented."); + } + visitBreakStmt(stmt: StmtNS.Break): WasmInstruction { + throw new Error("Method not implemented."); + } + visitContinueStmt(stmt: StmtNS.Continue): WasmInstruction { + throw new Error("Method not implemented."); + } + visitFromImportStmt(stmt: StmtNS.FromImport): WasmInstruction { + throw new Error("Method not implemented."); + } + visitGlobalStmt(stmt: StmtNS.Global): WasmInstruction { + throw new Error("Method not implemented."); + } + visitAssertStmt(stmt: StmtNS.Assert): WasmInstruction { + throw new Error("Method not implemented."); + } + visitWhileStmt(stmt: StmtNS.While): WasmInstruction { + throw new Error("Method not implemented."); + } + visitForStmt(stmt: StmtNS.For): WasmInstruction { + throw new Error("Method not implemented."); + } +} diff --git a/src/wasm-compiler/constants.ts b/src/wasm-compiler/constants.ts new file mode 100644 index 0000000..2922f43 --- /dev/null +++ b/src/wasm-compiler/constants.ts @@ -0,0 +1,871 @@ +import { f64, global, i32, i64, local, memory, wasm } from "./wasm-util/builder"; +import { WasmInstruction } from "./wasm-util/types"; + +// tags +export const TYPE_TAG = { + INT: 0, + FLOAT: 1, + COMPLEX: 2, + BOOL: 3, + STRING: 4, + CLOSURE: 5, + NONE: 6, + UNBOUND: 7, + PAIR: 8, 
+} as const; + +export const ERROR_MAP = { + NEG_NOT_SUPPORT: [0, "Unary minus operator used on unsupported operand."], + LOG_UNKNOWN_TYPE: [1, "Calling log on an unknown runtime type."], + ARITH_OP_UNKNOWN_TYPE: [2, "Calling an arithmetic operation on an unsupported runtime type."], + COMPLEX_COMPARISON: [3, "Using an unsupported comparison operator on complex type."], + COMPARE_OP_UNKNOWN_TYPE: [4, "Calling a comparison operation on unsupported operands."], + CALL_NOT_FX: [5, "Calling a non-function value."], + FUNC_WRONG_ARITY: [6, "Calling function with wrong number of arguments."], + UNBOUND: [7, "Accessing an unbound value."], + HEAD_NOT_PAIR: [8, "Accessing the head of a non-pair value."], + TAIL_NOT_PAIR: [9, "Accessing the tail of a non-pair value."], + BOOL_UNKNOWN_TYPE: [10, "Trying to convert an unknown runtime type to a bool."], + BOOL_UNKNOWN_OP: [11, "Unknown boolean binary operator."], +} as const; + +export const HEAP_PTR = "$_heap_pointer"; +export const CURR_ENV = "$_current_env"; + +// boxing functions + +// store directly in payload +export const MAKE_INT_FX = wasm + .func("$_make_int") + .params({ $value: i64 }) + .results(i32, i64) + .body(i32.const(TYPE_TAG.INT), local.get("$value")); + +// reinterpret bits as int +export const MAKE_FLOAT_FX = wasm + .func("$_make_float") + .params({ $value: f64 }) + .results(i32, i64) + .body(i32.const(TYPE_TAG.FLOAT), i64.reinterpret_f64(local.get("$value"))); + +// payload: heap pointer to 16 bytes; f64 real part at +0 and f64 imaginary part at +8 +export const MAKE_COMPLEX_FX = wasm + .func("$_make_complex") + .params({ $real: f64, $img: f64 }) + .results(i32, i64) + .body( + f64.store(global.get(HEAP_PTR), local.get("$real")), + f64.store(i32.add(global.get(HEAP_PTR), i32.const(8)), local.get("$img")), + + i32.const(TYPE_TAG.COMPLEX), + i64.extend_i32_u(global.get(HEAP_PTR)), + + global.set(HEAP_PTR, i32.add(global.get(HEAP_PTR), i32.const(16))) + ); + +// normalise the i32 flag into an i64 payload of 0 or 1 +export const MAKE_BOOL_FX =
wasm + .func("$_make_bool") + .params({ $value: i32 }) + .results(i32, i64) + .body( + i32.const(TYPE_TAG.BOOL), + wasm + .if(i32.eqz(local.get("$value"))) + .results(i64) + .then(i64.const(0)) + .else(i64.const(1)) + ); + +// upper 32: pointer; lower 32: length +export const MAKE_STRING_FX = wasm + .func("$_make_string") + .params({ $ptr: i32, $len: i32 }) + .results(i32, i64) + .body( + i32.const(TYPE_TAG.STRING), + i64.or(i64.shl(i64.extend_i32_u(local.get("$ptr")), i64.const(32)), i64.extend_i32_u(local.get("$len"))) + ); + +// upper 16: tag; upperMid 8: arity; lowerMid 8: envSize; lower 32: parentEnv +export const MAKE_CLOSURE_FX = wasm + .func("$_make_closure") + .params({ $tag: i32, $arity: i32, $env_size: i32, $parent_env: i32 }) + .results(i32, i64) + .body( + i32.const(TYPE_TAG.CLOSURE), + + i64.or( + i64.or( + i64.or( + i64.shl(i64.extend_i32_u(local.get("$tag")), i64.const(48)), + i64.shl(i64.extend_i32_u(local.get("$arity")), i64.const(40)) + ), + i64.shl(i64.extend_i32_u(local.get("$env_size")), i64.const(32)) + ), + i64.extend_i32_u(local.get("$parent_env")) + ) + ); + +export const MAKE_NONE_FX = wasm.func("$_make_none").results(i32, i64).body(i32.const(TYPE_TAG.NONE), i64.const(0)); + +// pair-related functions + +// payload: heap pointer to a 24-byte cell; head (tag, value) at +0 and tail (tag, value) at +12 +export const MAKE_PAIR_FX = wasm + .func("$_make_pair") + .params({ $tag1: i32, $val1: i64, $tag2: i32, $val2: i64 }) + .results(i32, i64) + .body( + i32.store(global.get(HEAP_PTR), local.get("$tag1")), + i64.store(i32.add(global.get(HEAP_PTR), i32.const(4)), local.get("$val1")), + i32.store(i32.add(global.get(HEAP_PTR), i32.const(12)), local.get("$tag2")), + i64.store(i32.add(global.get(HEAP_PTR), i32.const(16)), local.get("$val2")), + + i32.const(TYPE_TAG.PAIR), + i64.extend_i32_u(global.get(HEAP_PTR)), + + global.set(HEAP_PTR, i32.add(global.get(HEAP_PTR), i32.const(24))) + ); + +export const GET_PAIR_HEAD_FX = wasm + .func("$_get_pair_head") + .params({ $tag: i32, $val: i64 }) +
.results(i32, i64) + .body( + wasm + .if(i32.ne(local.get("$tag"), i32.const(TYPE_TAG.PAIR))) + .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.HEAD_NOT_PAIR[0])), wasm.unreachable()), + + i32.load(i32.wrap_i64(local.get("$val"))), + i64.load(i32.add(i32.wrap_i64(local.get("$val")), i32.const(4))) + ); + +export const GET_PAIR_TAIL_FX = wasm + .func("$_get_pair_tail") + .params({ $tag: i32, $val: i64 }) + .results(i32, i64) + .body( + wasm + .if(i32.ne(local.get("$tag"), i32.const(TYPE_TAG.PAIR))) + .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.TAIL_NOT_PAIR[0])), wasm.unreachable()), + + i32.load(i32.add(i32.wrap_i64(local.get("$val")), i32.const(12))), + i64.load(i32.add(i32.wrap_i64(local.get("$val")), i32.const(16))) + ); + +export const SET_PAIR_HEAD_FX = wasm + .func("$_set_pair_head") + .params({ $pair_tag: i32, $pair_val: i64, $tag: i32, $val: i64 }) + .body( + wasm + .if(i32.ne(local.get("$pair_tag"), i32.const(TYPE_TAG.PAIR))) + .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.HEAD_NOT_PAIR[0])), wasm.unreachable()), + + i32.store(i32.wrap_i64(local.get("$pair_val")), local.get("$tag")), + i64.store(i32.add(i32.wrap_i64(local.get("$pair_val")), i32.const(4)), local.get("$val")) + ); + +export const SET_PAIR_TAIL_FX = wasm + .func("$_set_pair_tail") + .params({ $pair_tag: i32, $pair_val: i64, $tag: i32, $val: i64 }) + .body( + wasm + .if(i32.ne(local.get("$pair_tag"), i32.const(TYPE_TAG.PAIR))) + .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.TAIL_NOT_PAIR[0])), wasm.unreachable()), + + i32.store(i32.add(i32.wrap_i64(local.get("$pair_val")), i32.const(12)), local.get("$tag")), + i64.store(i32.add(i32.wrap_i64(local.get("$pair_val")), i32.const(16)), local.get("$val")) + ); + +// logging functions +export const importedLogs = [ + wasm.import("console", "log").func("$_log_int").params(i64), + wasm.import("console", "log").func("$_log_float").params(f64), + wasm.import("console", 
"log_complex").func("$_log_complex").params(f64, f64), + wasm.import("console", "log_bool").func("$_log_bool").params(i64), + wasm.import("console", "log_string").func("$_log_string").params(i32, i32), + wasm.import("console", "log_closure").func("$_log_closure").params(i32, i32, i32, i32), + wasm.import("console", "log_none").func("$_log_none"), + wasm.import("console", "log_error").func("$_log_error").params(i32), +]; + +export const LOG_FX = wasm + .func("$_log") + .params({ $tag: i32, $value: i64 }) + .body( + wasm + .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.INT))) + .then(wasm.call("$_log_int").args(local.get("$value")), wasm.return()), + wasm + .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.FLOAT))) + .then(wasm.call("$_log_float").args(f64.reinterpret_i64(local.get("$value"))), wasm.return()), + wasm + .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.COMPLEX))) + .then( + wasm + .call("$_log_complex") + .args( + f64.load(i32.wrap_i64(local.get("$value"))), + f64.load(i32.add(i32.wrap_i64(local.get("$value")), i32.const(8))) + ), + wasm.return() + ), + wasm + .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.BOOL))) + .then(wasm.call("$_log_bool").args(local.get("$value")), wasm.return()), + wasm + .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.STRING))) + .then( + wasm + .call("$_log_string") + .args(i32.wrap_i64(i64.shr_u(local.get("$value"), i64.const(32))), i32.wrap_i64(local.get("$value"))), + wasm.return() + ), + wasm + .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.CLOSURE))) + .then( + wasm + .call("$_log_closure") + .args( + i32.and(i32.wrap_i64(i64.shr_u(local.get("$value"), i64.const(48))), i32.const(65535)), + i32.and(i32.wrap_i64(i64.shr_u(local.get("$value"), i64.const(40))), i32.const(255)), + i32.and(i32.wrap_i64(i64.shr_u(local.get("$value"), i64.const(32))), i32.const(255)), + i32.wrap_i64(local.get("$value")) + ), + wasm.return() + ), + wasm.if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.NONE))).then(wasm.call("$_log_none"), 
wasm.return()),
+    wasm
+      .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.PAIR)))
+      .then(
+        wasm.call("$_log").args(wasm.call(GET_PAIR_HEAD_FX).args(local.get("$tag"), local.get("$value"))), // recursive: log the head element
+        wasm.call("$_log").args(wasm.call(GET_PAIR_TAIL_FX).args(local.get("$tag"), local.get("$value"))), // then the tail element
+        wasm.return()
+      ),
+
+    wasm.call("$_log_error").args(i32.const(ERROR_MAP.LOG_UNKNOWN_TYPE[0])),
+    wasm.unreachable()
+  );
+
+// unary operation functions. $_py_neg negates INT (as (x XOR -1) + 1), FLOAT, and both f64 components of COMPLEX; any other tag logs NEG_NOT_SUPPORT and traps.
+export const NEG_FX = wasm
+  .func("$_py_neg")
+  .params({ $x_tag: i32, $x_val: i64 })
+  .results(i32, i64)
+  .body(
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.INT)))
+      .then(
+        wasm.return(wasm.call(MAKE_INT_FX).args(i64.add(i64.xor(local.get("$x_val"), i64.const(-1)), i64.const(1))))
+      ),
+
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.FLOAT)))
+      .then(wasm.return(wasm.call(MAKE_FLOAT_FX).args(f64.neg(f64.reinterpret_i64(local.get("$x_val")))))),
+
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.COMPLEX)))
+      .then(
+        wasm.return(
+          wasm
+            .call(MAKE_COMPLEX_FX)
+            .args(
+              f64.neg(f64.load(i32.wrap_i64(local.get("$x_val")))),
+              f64.neg(f64.load(i32.add(i32.wrap_i64(local.get("$x_val")), i32.const(8))))
+            )
+        )
+      ),
+
+    wasm.call("$_log_error").args(i32.const(ERROR_MAP.NEG_NOT_SUPPORT[0])),
+    wasm.unreachable()
+  );
+
+export const ARITHMETIC_OP_TAG = { ADD: 0, SUB: 1, MUL: 2, DIV: 3 } as const;
+// binary operation function: $_py_arith_op(x, y, op). Tries string concatenation (ADD only), then int, float and complex arithmetic, promoting bool->int->float->complex as it goes; anything left logs ARITH_OP_UNKNOWN_TYPE and traps.
+export const ARITHMETIC_OP_FX = wasm
+  .func("$_py_arith_op")
+  .params({ $x_tag: i32, $x_val: i64, $y_tag: i32, $y_val: i64, $op: i32 })
+  .results(i32, i64)
+  .locals({ $a: f64, $b: f64, $c: f64, $d: f64, $denom: f64 })
+  .body(
+    // if adding, check if both are strings
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$op"), i32.const(ARITHMETIC_OP_TAG.ADD)),
+          i32.and(
+            i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.STRING)),
+            i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.STRING))
+          )
+        )
+      )
+      .then(
+        global.get(HEAP_PTR), // starting address of new string
+
+        memory.copy(
+          global.get(HEAP_PTR),
+          i32.wrap_i64(i64.shr_u(local.get("$x_val"), i64.const(32))),
+          i32.wrap_i64(local.get("$x_val"))
+        ),
+        global.set(HEAP_PTR, i32.add(global.get(HEAP_PTR), i32.wrap_i64(local.get("$x_val")))),
+        memory.copy(
+          global.get(HEAP_PTR),
+          i32.wrap_i64(i64.shr_u(local.get("$y_val"), i64.const(32))),
+          i32.wrap_i64(local.get("$y_val"))
+        ),
+        global.set(HEAP_PTR, i32.add(global.get(HEAP_PTR), i32.wrap_i64(local.get("$y_val")))),
+        i32.add(i32.wrap_i64(local.get("$x_val")), i32.wrap_i64(local.get("$y_val"))),
+        // NOTE(review): args() is empty; presumably MAKE_STRING_FX consumes the bare global.get (address) and i32.add (combined length) emitted above - confirm
+        wasm.return(wasm.call(MAKE_STRING_FX).args())
+      ),
+
+    // if either's bool, convert to int
+    wasm.if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.BOOL))).then(local.set("$x_tag", i32.const(TYPE_TAG.INT))),
+    wasm.if(i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.BOOL))).then(local.set("$y_tag", i32.const(TYPE_TAG.INT))),
+
+    // if both int, use int instr (except for division: use float)
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.INT)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.INT))
+        )
+      )
+      .then(
+        ...wasm.buildBrTableBlocks(
+          wasm.br_table(local.get("$op"), "$add", "$sub", "$mul", "$div"),
+          wasm.return(wasm.call(MAKE_INT_FX).args(i64.add(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_INT_FX).args(i64.sub(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_INT_FX).args(i64.mul(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(
+            wasm
+              .call(MAKE_FLOAT_FX)
+              .args(f64.div(f64.convert_i64_s(local.get("$x_val")), f64.convert_i64_s(local.get("$y_val"))))
+          )
+        )
+      ),
+
+    // else, if either's int, convert to float and set float locals
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.INT)))
+      .then(local.set("$a", f64.convert_i64_s(local.get("$x_val"))), local.set("$x_tag", i32.const(TYPE_TAG.FLOAT)))
+      .else(local.set("$a", f64.reinterpret_i64(local.get("$x_val")))),
+
+    wasm
+      .if(i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.INT)))
+      .then(local.set("$c", f64.convert_i64_s(local.get("$y_val"))), local.set("$y_tag", i32.const(TYPE_TAG.FLOAT)))
+      .else(local.set("$c", f64.reinterpret_i64(local.get("$y_val")))),
+
+    // if both float, use float instr
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.FLOAT)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.FLOAT))
+        )
+      )
+      .then(
+        ...wasm.buildBrTableBlocks(
+          wasm.br_table(local.get("$op"), "$add", "$sub", "$mul", "$div"),
+          wasm.return(wasm.call(MAKE_FLOAT_FX).args(f64.add(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_FLOAT_FX).args(f64.sub(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_FLOAT_FX).args(f64.mul(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_FLOAT_FX).args(f64.div(local.get("$a"), local.get("$c"))))
+        )
+      ),
+
+    // else, if either's complex, load from mem, set locals (default 0)
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.FLOAT)))
+      .then(local.set("$x_tag", i32.const(TYPE_TAG.COMPLEX))) // a float keeps $a as its real part; $b is never written on this path
+      .else(
+        local.set("$a", f64.load(i32.wrap_i64(local.get("$x_val")))),
+        local.set("$b", f64.load(i32.add(i32.wrap_i64(local.get("$x_val")), i32.const(8))))
+      ),
+    wasm
+      .if(i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.FLOAT)))
+      .then(local.set("$y_tag", i32.const(TYPE_TAG.COMPLEX)))
+      .else(
+        local.set("$c", f64.load(i32.wrap_i64(local.get("$y_val")))),
+        local.set("$d", f64.load(i32.add(i32.wrap_i64(local.get("$y_val")), i32.const(8))))
+      ),
+
+    // if both complex, perform complex operations
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.COMPLEX)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.COMPLEX))
+        )
+      )
+      .then(
+        ...wasm.buildBrTableBlocks(
+          wasm.br_table(local.get("$op"), "$add", "$sub", "$mul", "$div"),
+          wasm.return(
+            wasm
+              .call(MAKE_COMPLEX_FX)
+              .args(f64.add(local.get("$a"), local.get("$c")), f64.add(local.get("$b"), local.get("$d")))
+          ),
+          wasm.return(
+            wasm
+
.call(MAKE_COMPLEX_FX)
+              .args(f64.sub(local.get("$a"), local.get("$c")), f64.sub(local.get("$b"), local.get("$d")))
+          ),
+          // (a+bi)*(c+di) = (ac-bd) + (ad+bc)i
+          wasm.return(
+            wasm
+              .call(MAKE_COMPLEX_FX)
+              .args(
+                f64.sub(f64.mul(local.get("$a"), local.get("$c")), f64.mul(local.get("$b"), local.get("$d"))),
+                f64.add(f64.mul(local.get("$b"), local.get("$c")), f64.mul(local.get("$a"), local.get("$d")))
+              )
+          ),
+          // (a+bi)/(c+di) = (ac+bd)/(c^2+d^2) + (bc-ad)/(c^2+d^2)i
+          wasm.return(
+            wasm
+              .call(MAKE_COMPLEX_FX)
+              .args(
+                // real part: (ac+bd) / (c^2+d^2)
+                f64.div(
+                  f64.add(f64.mul(local.get("$a"), local.get("$c")), f64.mul(local.get("$b"), local.get("$d"))),
+                  f64.add(f64.mul(local.get("$c"), local.get("$c")), f64.mul(local.get("$d"), local.get("$d")))
+                ),
+                // imaginary part: (bc-ad) / (c^2+d^2). The divisor is rebuilt here: the previous code cached the
+                // finished real quotient in $denom and divided by that, which gave a wrong imaginary part.
+                f64.div(
+                  f64.sub(f64.mul(local.get("$b"), local.get("$c")), f64.mul(local.get("$a"), local.get("$d"))),
+                  f64.add(f64.mul(local.get("$c"), local.get("$c")), f64.mul(local.get("$d"), local.get("$d")))
+                )
+              )
+          )
+        )
+      ),
+
+    wasm.call("$_log_error").args(i32.const(ERROR_MAP.ARITH_OP_UNKNOWN_TYPE[0])),
+    wasm.unreachable()
+  );
+
+export const COMPARISON_OP_TAG = {
+  EQ: 0,
+  NEQ: 1,
+  LT: 2,
+  LTE: 3,
+  GT: 4,
+  GTE: 5,
+} as const;
+// comparison function: $_py_string_cmp(x_ptr, x_len, y_ptr, y_len) compares bytes as unsigned values; result is < 0 if x orders first, 0 if equal, > 0 if y orders first.
+export const STRING_COMPARE_FX = wasm
+  .func("$_py_string_cmp")
+  .params({ $x_ptr: i32, $x_len: i32, $y_ptr: i32, $y_len: i32 })
+  .results(i32)
+  .locals({ $i: i32, $min_len: i32, $x_char: i32, $y_char: i32, $result: i32 })
+  .body(
+    local.set(
+      "$min_len",
+      wasm.select(local.get("$x_len"), local.get("$y_len"), i32.lt_s(local.get("$x_len"), local.get("$y_len")))
+    ),
+
+    wasm.loop("$loop").body(
+      wasm.if(i32.lt_s(local.get("$i"), local.get("$min_len"))).then(
+        local.set("$x_char", i32.load8_u(i32.add(local.get("$x_ptr"), local.get("$i")))),
+        local.set("$y_char", i32.load8_u(i32.add(local.get("$y_ptr"), local.get("$i")))),
+
+        wasm
+          .if(local.tee("$result", i32.sub(local.get("$x_char"), local.get("$y_char"))))
+          .then(wasm.return(local.get("$result"))), // first differing byte decides the order
+
+        local.set("$i", i32.add(local.get("$i"), i32.const(1))),
+
+        wasm.br("$loop")
+      )
+    ),
+    // shared prefix is identical: the shorter string orders first, so return x_len - y_len (same sign convention as x_char - y_char above)
+    wasm.return(i32.sub(local.get("$x_len"), local.get("$y_len")))
+  );
+// $_py_compare_op(x, y, op): string, int, float, then complex comparison with the same promotion ladder as $_py_arith_op; unrelated types compare unequal for ==/!= and trap for ordering operators.
+export const COMPARISON_OP_FX = wasm
+  .func("$_py_compare_op")
+  .params({ $x_tag: i32, $x_val: i64, $y_tag: i32, $y_val: i64, $op: i32 })
+  .results(i32, i64)
+  .locals({ $a: f64, $b: f64, $c: f64, $d: f64 })
+  .body(
+    // if both are strings
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.STRING)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.STRING))
+        )
+      )
+      .then(
+        local.set(
+          "$x_tag", // reuse x_tag for comparison result
+          wasm
+            .call(STRING_COMPARE_FX)
+            .args(
+              i32.wrap_i64(i64.shr_u(local.get("$x_val"), i64.const(32))),
+              i32.wrap_i64(local.get("$x_val")),
+              i32.wrap_i64(i64.shr_u(local.get("$y_val"), i64.const(32))),
+              i32.wrap_i64(local.get("$y_val"))
+            )
+        ),
+
+        ...wasm.buildBrTableBlocks(
+          wasm.br_table(local.get("$op"), "$eq", "$neq", "$lt", "$lte", "$gt", "$gte"),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.eqz(local.get("$x_tag")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.ne(local.get("$x_tag"), i32.const(0)))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.lt_s(local.get("$x_tag"), i32.const(0)))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.le_s(local.get("$x_tag"), i32.const(0)))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.gt_s(local.get("$x_tag"), i32.const(0)))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.ge_s(local.get("$x_tag"), i32.const(0))))
+        )
+      ),
+
+    // if either are bool, convert to int
+    wasm.if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.BOOL))).then(local.set("$x_tag", i32.const(TYPE_TAG.INT))),
+    wasm.if(i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.BOOL))).then(local.set("$y_tag", i32.const(TYPE_TAG.INT))),
+
+    // if both int, use int comparison
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.INT)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.INT))
+        )
+      )
+      .then(
+        ...wasm.buildBrTableBlocks(
+          wasm.br_table(local.get("$op"), "$eq", "$neq", "$lt", "$lte", "$gt", "$gte"),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.eq(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.ne(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.lt_s(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.le_s(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.gt_s(local.get("$x_val"), local.get("$y_val")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.ge_s(local.get("$x_val"), local.get("$y_val"))))
+        )
+      ),
+
+    // else, if either are int, convert to float and set float locals
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.INT)))
+      .then(local.set("$a", f64.convert_i64_s(local.get("$x_val"))), local.set("$x_tag", i32.const(TYPE_TAG.FLOAT)))
+      .else(local.set("$a", f64.reinterpret_i64(local.get("$x_val")))),
+    wasm
+      .if(i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.INT)))
+      .then(local.set("$c", f64.convert_i64_s(local.get("$y_val"))), local.set("$y_tag", i32.const(TYPE_TAG.FLOAT)))
+      .else(local.set("$c", f64.reinterpret_i64(local.get("$y_val")))),
+
+    // if both float, use float comparison
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.FLOAT)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.FLOAT))
+        )
+      )
+      .then(
+        ...wasm.buildBrTableBlocks(
+          wasm.br_table(local.get("$op"), "$eq", "$neq", "$lt", "$lte", "$gt", "$gte"),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.eq(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.ne(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.lt(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.le(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.gt(local.get("$a"), local.get("$c")))),
+          wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.ge(local.get("$a"), local.get("$c"))))
+        )
+      ),
+
+    // else, if either are complex, load complex from memory and set float locals (default 0)
+    wasm
+      .if(i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.FLOAT)))
+      .then(local.set("$x_tag", i32.const(TYPE_TAG.COMPLEX)))
+      .else(
+        local.set("$a", f64.load(i32.wrap_i64(local.get("$x_val")))),
+        local.set("$b", f64.load(i32.add(i32.wrap_i64(local.get("$x_val")), i32.const(8))))
+      ),
+    wasm
+      .if(i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.FLOAT)))
+      .then(local.set("$y_tag", i32.const(TYPE_TAG.COMPLEX)))
+      .else(
+        local.set("$c", f64.load(i32.wrap_i64(local.get("$y_val")))),
+        local.set("$d", f64.load(i32.add(i32.wrap_i64(local.get("$y_val")), i32.const(8))))
+      ),
+
+    // if both complex, compare real and imaginary parts. only ==, !=
+    wasm
+      .if(
+        i32.and(
+          i32.eq(local.get("$x_tag"), i32.const(TYPE_TAG.COMPLEX)),
+          i32.eq(local.get("$y_tag"), i32.const(TYPE_TAG.COMPLEX))
+        )
+      )
+      .then(
+        wasm
+          .if(i32.eq(local.get("$op"), i32.const(COMPARISON_OP_TAG.EQ)))
+          .then(
+            wasm.return(
+              wasm
+                .call(MAKE_BOOL_FX)
+                .args(i32.and(f64.eq(local.get("$a"), local.get("$c")), f64.eq(local.get("$b"), local.get("$d"))))
+            )
+          )
+          .else(
+            wasm
+              .if(i32.eq(local.get("$op"), i32.const(COMPARISON_OP_TAG.NEQ)))
+              .then(
+                wasm.return(
+                  wasm
+                    .call(MAKE_BOOL_FX)
+                    .args(i32.or(f64.ne(local.get("$a"), local.get("$c")), f64.ne(local.get("$b"), local.get("$d"))))
+                )
+              )
+              .else(wasm.call("$_log_error").args(i32.const(ERROR_MAP.COMPLEX_COMPARISON[0])), wasm.unreachable())
+          )
+      ),
+
+    // else, default to not equal
+    wasm
+      .if(i32.eq(local.get("$op"), i32.const(COMPARISON_OP_TAG.EQ)))
+      .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.const(0))))
+      .else(
+        wasm
+          .if(i32.eq(local.get("$op"), i32.const(COMPARISON_OP_TAG.NEQ)))
+          .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.const(1))))
+      ),
+
+    // other operators: unreachable
+
wasm.call("$_log_error").args(i32.const(ERROR_MAP.COMPARE_OP_UNKNOWN_TYPE[0])),
+    wasm.unreachable()
+  );
+
+// bool related functions
+// $_boolise(tag, val): truthiness of a tagged value, produced through MAKE_BOOL_FX; an unrecognised tag logs BOOL_UNKNOWN_TYPE and traps.
+export const BOOLISE_FX = wasm
+  .func("$_boolise")
+  .params({ $tag: i32, $val: i64 })
+  .results(i64)
+  .body(
+    // None => False
+    wasm
+      .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.NONE)))
+      .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.const(0)))),
+
+    // bool or int => False only if the full 64-bit payload is 0 (i32.wrap_i64 would drop the high bits, making 2**32 falsy)
+    wasm
+      .if(
+        i32.or(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.INT)), i32.eq(local.get("$tag"), i32.const(TYPE_TAG.BOOL)))
+      )
+      .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(i64.ne(local.get("$val"), i64.const(0))))),
+
+    // float/complex => False if equivalent of 0
+    wasm
+      .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.FLOAT)))
+      .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(f64.ne(f64.reinterpret_i64(local.get("$val")), f64.const(0))))),
+    wasm
+      .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.COMPLEX)))
+      .then(
+        wasm.return(
+          wasm
+            .call(MAKE_BOOL_FX)
+            .args(
+              i32.or(
+                f64.ne(f64.load(i32.add(i32.wrap_i64(local.get("$val")), i32.const(8))), f64.const(0)),
+                f64.ne(f64.load(i32.wrap_i64(local.get("$val"))), f64.const(0))
+              )
+            )
+        )
+      ),
+
+    // string => False if length is 0 (the length is the low 32 bits of the payload, so wrapping loses nothing here)
+    wasm
+      .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.STRING)))
+      .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.wrap_i64(local.get("$val"))))),
+
+    // closure/pair => True
+    wasm
+      .if(
+        i32.or(
+          i32.eq(local.get("$tag"), i32.const(TYPE_TAG.CLOSURE)),
+          i32.eq(local.get("$tag"), i32.const(TYPE_TAG.PAIR))
+        )
+      )
+      .then(wasm.return(wasm.call(MAKE_BOOL_FX).args(i32.const(1)))),
+
+    wasm.call("$_log_error").args(i32.const(ERROR_MAP.BOOL_UNKNOWN_TYPE[0])),
+    wasm.unreachable()
+  );
+// $_bool_not(tag, val): returns a BOOL whose payload is 1 when $_boolise yields 0, and 0 otherwise.
+export const BOOL_NOT_FX = wasm
+  .func("$_bool_not")
+  .params({ $tag: i32, $val: i64 })
+  .results(i32, i64)
+  .body(
+    i32.const(TYPE_TAG.BOOL),
+    i64.extend_i32_u(i64.eqz(wasm.call(BOOLISE_FX).args(local.get("$tag"), local.get("$val"))))
+  );
+// Operator codes for boolean and/or; `as const` matches ARITHMETIC_OP_TAG and COMPARISON_OP_TAG.
+export const BOOL_BINARY_OP_TAG = { AND: 0, OR: 1 } as const;
+// $_alloc_env(size, parent, arity): reserves a new environment at HEAP_PTR, advances HEAP_PTR past it and leaves the env start address as the result.
+// Layout: 4 bytes for the parent env address, then `size` slots of 12 bytes each (4-byte tag + 8-byte payload).
+// Only the `size - arity` local slots are filled with UNBOUND; the first `arity` (parameter) slots are not touched here because the caller populates them.
+// NOTE(review): memory.fill writes a single byte value, so a 4-byte tag reads back as TYPE_TAG.UNBOUND only if all bytes of that constant are equal (e.g. 0) - confirm.
+export const ALLOC_ENV_FX = wasm
+  .func("$_alloc_env")
+  .params({ $size: i32, $parent: i32, $arity: i32 })
+  .results(i32)
+  .body(
+    global.get(HEAP_PTR), // return the start of the new env, set CURR_ENV AFTER
+    i32.store(global.get(HEAP_PTR), local.get("$parent")),
+    global.set(HEAP_PTR, i32.add(global.get(HEAP_PTR), i32.const(4))),
+
+    memory.fill(
+      i32.add(global.get(HEAP_PTR), i32.mul(local.get("$arity"), i32.const(12))),
+      i32.const(TYPE_TAG.UNBOUND),
+      i32.mul(i32.sub(local.get("$size"), local.get("$arity")), i32.const(12))
+    ),
+    global.set(HEAP_PTR, i32.add(global.get(HEAP_PTR), i32.mul(local.get("$size"), i32.const(12))))
+  );
+// $_pre_apply(tag, val, arity): traps with CALL_NOT_FX unless tag is CLOSURE and with FUNC_WRONG_ARITY unless bits 40-47 of val equal arity; then returns tag, val and a fresh env (size = bits 32-39, parent = low 32 bits of val).
+export const PRE_APPLY_FX = wasm
+  .func("$_pre_apply")
+  .params({ $tag: i32, $val: i64, $arity: i32 })
+  .results(i32, i64, i32)
+  .body(
+    wasm
+      .if(i32.ne(local.get("$tag"), i32.const(TYPE_TAG.CLOSURE)))
+      .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.CALL_NOT_FX[0])), wasm.unreachable()),
+
+    wasm
+      .if(
+        i32.ne(i32.and(i32.wrap_i64(i64.shr_u(local.get("$val"), i64.const(40))), i32.const(255)), local.get("$arity"))
+      )
+      .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.FUNC_WRONG_ARITY[0])), wasm.unreachable()),
+
+    local.get("$tag"),
+    local.get("$val"),
+    wasm
+      .call(ALLOC_ENV_FX)
+      .args(
+        i32.and(i32.wrap_i64(i64.shr_u(local.get("$val"), i64.const(32))), i32.const(255)),
+        i32.wrap_i64(local.get("$val")),
+        local.get("$arity")
+      )
+  );
+
+export const APPLY_FX_NAME = "$_apply";
+export const applyFuncFactory = (bodies: WasmInstruction[][]) =>
+  wasm
+    .func(APPLY_FX_NAME)
+    .params({ $return_env: i32, $tag: i32, $val: i64 })
+    .results(i32, i64)
+
.body(
+      ...wasm.buildBrTableBlocks(
+        wasm.br_table(i32.wrap_i64(i64.shr_u(local.get("$val"), i64.const(48))), ...Array(bodies.length).keys()), // branch index = bits 48 and up of $val; one label per entry of `bodies`
+        ...bodies.map((body) => [
+          ...body,
+          wasm.return(wasm.call(MAKE_NONE_FX), global.set(CURR_ENV, local.get("$return_env"))), // appended to every body: a None result, with CURR_ENV set back to $return_env
+        ])
+      )
+    );
+// $_get_lex_addr(depth, index): starts at CURR_ENV, follows the parent address stored at offset 0 `depth` times, then returns the tag at env+4+12*index and the payload at env+8+12*index; logs UNBOUND and traps if that tag is TYPE_TAG.UNBOUND.
+export const GET_LEX_ADDR_FX = wasm
+  .func("$_get_lex_addr")
+  .params({ $depth: i32, $index: i32 })
+  .results(i32, i64)
+  .locals({ $env: i32, $tag: i32 })
+  .body(
+    local.set("$env", global.get(CURR_ENV)),
+
+    wasm.loop("$loop").body(
+      wasm.if(i32.eqz(local.get("$depth"))).then(
+        local.set(
+          "$tag",
+          i32.load(i32.add(i32.add(local.get("$env"), i32.const(4)), i32.mul(local.get("$index"), i32.const(12))))
+        ),
+
+        wasm
+          .if(i32.eq(local.get("$tag"), i32.const(TYPE_TAG.UNBOUND)))
+          .then(wasm.call("$_log_error").args(i32.const(ERROR_MAP.UNBOUND[0])), wasm.unreachable()),
+
+        wasm.return(
+          local.get("$tag"),
+          i64.load(i32.add(i32.add(local.get("$env"), i32.const(8)), i32.mul(local.get("$index"), i32.const(12))))
+        )
+      ),
+
+      local.set("$env", i32.load(local.get("$env"))),
+      local.set("$depth", i32.sub(local.get("$depth"), i32.const(1))),
+      wasm.br("$loop")
+    ),
+
+    wasm.unreachable()
+  );
+// $_set_lex_addr(depth, index, tag, value): same environment walk as $_get_lex_addr, then stores tag at env+4+12*index and value at env+8+12*index; no UNBOUND check is made.
+export const SET_LEX_ADDR_FX = wasm
+  .func("$_set_lex_addr")
+  .params({ $depth: i32, $index: i32, $tag: i32, $value: i64 })
+  .locals({ $env: i32 })
+  .body(
+    local.set("$env", global.get(CURR_ENV)),
+
+    wasm.loop("$loop").body(
+      wasm
+        .if(i32.eqz(local.get("$depth")))
+        .then(
+          i32.store(
+            i32.add(i32.add(local.get("$env"), i32.const(4)), i32.mul(local.get("$index"), i32.const(12))),
+            local.get("$tag")
+          ),
+          i64.store(
+            i32.add(i32.add(local.get("$env"), i32.const(8)), i32.mul(local.get("$index"), i32.const(12))),
+            local.get("$value")
+          ),
+          wasm.return()
+        ),
+
+      local.set("$env", i32.load(local.get("$env"))),
+      local.set("$depth", i32.sub(local.get("$depth"), i32.const(1))),
+      wasm.br("$loop")
+    ),
+
+    wasm.unreachable()
+  );
+// $_set_param(addr, tag, value): stores tag at addr+4 and value at addr+8, then returns addr unchanged.
+export const SET_PARAM_FX = wasm
+  .func("$_set_param")
+  .params({ $addr: i32, $tag: i32, $value: i64 })
+  .results(i32)
+  .body(
+    i32.store(i32.add(local.get("$addr"), i32.const(4)), local.get("$tag")),
+    i64.store(i32.add(local.get("$addr"), i32.const(8)), local.get("$value")),
+
+    local.get("$addr")
+  );
+// Function builders exported as one list; the function produced by applyFuncFactory is not part of it.
+export const nativeFunctions = [
+  MAKE_INT_FX,
+  MAKE_FLOAT_FX,
+  MAKE_COMPLEX_FX,
+  MAKE_BOOL_FX,
+  MAKE_STRING_FX,
+  MAKE_CLOSURE_FX,
+  MAKE_NONE_FX,
+  MAKE_PAIR_FX,
+  GET_PAIR_HEAD_FX,
+  GET_PAIR_TAIL_FX,
+  SET_PAIR_HEAD_FX,
+  SET_PAIR_TAIL_FX,
+  LOG_FX,
+  NEG_FX,
+  ARITHMETIC_OP_FX,
+  STRING_COMPARE_FX,
+  COMPARISON_OP_FX,
+  BOOLISE_FX,
+  BOOL_NOT_FX,
+  ALLOC_ENV_FX,
+  PRE_APPLY_FX,
+  GET_LEX_ADDR_FX,
+  SET_LEX_ADDR_FX,
+  SET_PARAM_FX,
+];
diff --git a/src/wasm-compiler/index.ts b/src/wasm-compiler/index.ts
new file mode 100644
index 0000000..c6844d9
--- /dev/null
+++ b/src/wasm-compiler/index.ts
@@ -0,0 +1,60 @@
+import assert from "assert";
+import wabt from "wabt";
+import { Parser } from "../parser";
+import { Tokenizer } from "../tokenizer";
+import { BuilderGenerator } from "./builderGenerator";
+import { ERROR_MAP } from "./constants";
+import { WatGenerator } from "./wasm-util/watGenerator";
+// Tokenizes and parses `code`, lowers the AST to WAT text, assembles it with wabt, instantiates the module with console/memory imports and returns the result of its exported `main`.
+export async function compileToWasmAndRun(code: string) {
+  const script = code + "\n";
+  const tokenizer = new Tokenizer(script);
+  const tokens = tokenizer.scanEverything();
+  const pyParser = new Parser(script, tokens);
+  const ast = pyParser.parse();
+  // AST -> builder IR
+  const builderGenerator = new BuilderGenerator();
+  const watIR = builderGenerator.visit(ast);
+  // builder IR -> WAT source text
+  const watGenerator = new WatGenerator();
+  const wat = watGenerator.visit(watIR);
+  // WAT text -> wasm binary
+  const w = await wabt();
+  const wasm = w.parseWat("a", wat).toBinary({}).buffer as BufferSource;
+
+  const memory = new WebAssembly.Memory({ initial: 1 });
+  // host imports: the "console" callbacks plus the memory created above, which log_string also reads from
+  const result = await WebAssembly.instantiate(wasm, {
+    console: {
+      log: console.log,
+      log_complex: (real: number, imag: number) =>
+        console.log(`${real} ${imag >= 0 ? "+" : "-"} ${Math.abs(imag)}j`),
+      log_bool: (value: bigint) =>
+        console.log(value === BigInt(0) ? "False" : "True"),
+      log_string: (offset: number, length: number) =>
+        console.log(
+          new TextDecoder("utf8").decode(
+            new Uint8Array(memory.buffer, offset, length)
+          )
+        ),
+      log_closure: (
+        tag: number,
+        arity: number,
+        envSize: number,
+        parentEnv: number
+      ) =>
+        console.log(
+          `Closure (tag: ${tag}, arity: ${arity}, envSize: ${envSize}, parentEnv: ${parentEnv})`
+        ),
+      log_none: () => console.log("None"),
+      log_error: (tag: number) =>
+        console.error(Object.values(ERROR_MAP).find(([i]) => i === tag)?.[1]), // prints element [1] of the ERROR_MAP entry whose element [0] equals tag
+      log_pair: () => console.log(),
+    },
+    js: { memory },
+  });
+
+  // run the exported main function
+  assert(typeof result.instance.exports.main === "function");
+  return result.instance.exports.main() as [number, number]; // NOTE(review): the spec added in this patch expects the second element to equal BigInt values, so `[number, number]` looks inaccurate - confirm
+}
diff --git a/src/wasm-compiler/wasm-util/builder.ts b/src/wasm-compiler/wasm-util/builder.ts
new file mode 100644
index 0000000..5453f5e
--- /dev/null
+++ b/src/wasm-compiler/wasm-util/builder.ts
@@ -0,0 +1,742 @@
+// ------------------------ WASM Builder API ----------------------------
+
+import {
+  f32ConversionOp,
+  f64ConversionOp,
+  floatBinaryOp,
+  floatComparisonOp,
+  floatConversionOp,
+  floatUnaryOp,
+  i32ConversionOp,
+  i64ConversionOp,
+  intBinaryOp,
+  intComparisonOp,
+  intConversionOp,
+  intTestOp,
+  WasmRaw,
+  type WasmBinaryOp,
+  type WasmBlock,
+  type WasmBlockType,
+  type WasmBr,
+  type WasmBrTable,
+  type WasmCall,
+  type WasmComparisonOp,
+  type WasmConst,
+  type WasmConversionOp,
+  type WasmData,
+  type WasmDrop,
+  type WasmExport,
+  type WasmFloatNumericType,
+  type WasmFunction,
+  type WasmFuncType,
+  type WasmGlobal,
+  type WasmGlobalFor,
+  type WasmGlobalGet,
+  type WasmGlobalSet,
+  type WasmIf,
+  type WasmImport,
+  type WasmInstruction,
+  type WasmIntNumericType,
+  type WasmIntTestOp,
+  type WasmLabel,
+  type WasmLoadOp,
+  type WasmLocalGet,
+  type WasmLocalSet,
+  type WasmLocalTee,
+  type WasmLoop,
+  type 
WasmMemoryCopy, + type WasmMemoryFill, + type WasmModule, + type WasmNop, + type WasmNumeric, + type WasmNumericFor, + type WasmNumericType, + type WasmReturn, + type WasmSelect, + type WasmStoreOp, + type WasmUnaryOp, + type WasmUnreachable, +} from "./types"; +import { typedFromEntries } from "./util"; + +type BuilderAsType = { + "~type": T; +}; + +const binaryOp = < + T extends WasmNumericType, + const Op extends (( + | WasmBinaryOp + | WasmComparisonOp + )["op"] extends `${T}.${infer S}` + ? S + : never)[] +>( + type: T, + ops: Op +) => + typedFromEntries( + ops.map((op) => { + const fn = (left: WasmNumericFor, right: WasmNumericFor) => ({ + op: `${type}.${op}`, + left, + right, + }); + return [op, fn]; + }) as { + [K in keyof Op]: [ + Op[K], + ( + ...args: Extract extends { + left: infer L; + right: infer R; + } + ? [left: L, right: R] + : never + ) => Extract + ]; + } + ); + +const unaryOp = < + T extends WasmNumericType, + const Op extends (( + | WasmConversionOp + | (T extends WasmIntNumericType + ? WasmIntTestOp | WasmLoadOp + : T extends WasmFloatNumericType + ? WasmUnaryOp + : never) + )["op"] extends `${T}.${infer S}` + ? S + : never)[] +>( + type: T, + ops: Op +) => + typedFromEntries( + ops.map((op) => { + const fn = (right: WasmNumericFor) => ({ + op: `${type}.${op}`, + right, + }); + return [op, fn]; + }) as { + [K in keyof Op]: [ + Op[K], + ( + ...args: Extract, { op: `${T}.${Op[K]}` }> extends { + right: infer R; + } + ? [right: R] + : never + ) => Extract, { op: `${T}.${Op[K]}` }> + ]; + } + ); + +type Builder = { + [K in WasmInstruction["op"] as K extends `${T}.${infer S}` ? S : never]: ( + ...args: never[] + ) => Extract; +} & (T extends WasmNumericType ? 
BuilderAsType : unknown); + +const loadHelper = + (op: Op) => + (address: WasmNumericFor<"i32">) => ({ op, address }); + +const i32 = { + const: (value: number | bigint): WasmConst<"i32"> => ({ + op: "i32.const", + value: BigInt(value), + }), + ...binaryOp("i32", [...intBinaryOp, ...intComparisonOp]), + ...unaryOp("i32", [...i32ConversionOp, ...intConversionOp, ...intTestOp]), + load: loadHelper("i32.load"), + load8_s: loadHelper("i32.load8_s"), + load8_u: loadHelper("i32.load8_u"), + load16_s: loadHelper("i32.load16_s"), + load16_u: loadHelper("i32.load16_u"), + store: ( + address: WasmNumericFor<"i32">, + value: WasmNumericFor<"i32"> + ): WasmStoreOp<"i32"> => ({ op: "i32.store", address, value }), + + "~type": "i32", +} satisfies Builder<"i32">; + +const i64 = { + const: (value: number | bigint): WasmConst<"i64"> => ({ + op: "i64.const", + value: BigInt(value), + }), + ...binaryOp("i64", [...intBinaryOp, ...intComparisonOp]), + ...unaryOp("i64", [...i64ConversionOp, ...intConversionOp, ...intTestOp]), + load: loadHelper("i64.load"), + load8_s: loadHelper("i64.load8_s"), + load8_u: loadHelper("i64.load8_u"), + load16_s: loadHelper("i64.load16_s"), + load16_u: loadHelper("i64.load16_u"), + load32_s: loadHelper("i64.load32_s"), + load32_u: loadHelper("i64.load32_u"), + store: ( + address: WasmNumericFor<"i32">, + value: WasmNumericFor<"i64"> + ): WasmStoreOp<"i64"> => ({ op: "i64.store", address, value }), + + "~type": "i64", +} satisfies Builder<"i64">; + +const f32 = { + const: (value: number): WasmConst<"f32"> => ({ + op: "f32.const", + value, + }), + ...binaryOp("f32", [...floatBinaryOp, ...floatComparisonOp]), + ...unaryOp("f32", [ + ...f32ConversionOp, + ...floatConversionOp, + ...floatUnaryOp, + ]), + load: (address: WasmNumericFor<"i32">): WasmLoadOp<"f32"> => ({ + op: "f32.load", + address, + }), + store: ( + address: WasmNumericFor<"i32">, + value: WasmNumericFor<"f32"> + ): WasmStoreOp<"f32"> => ({ op: "f32.store", address, value }), + + "~type": "f32", 
+} satisfies Builder<"f32">; + +const f64 = { + const: (value: number): WasmConst<"f64"> => ({ + op: "f64.const", + value, + }), + ...binaryOp("f64", [...floatBinaryOp, ...floatComparisonOp]), + ...unaryOp("f64", [ + ...f64ConversionOp, + ...floatConversionOp, + ...floatUnaryOp, + ]), + load: (address: WasmNumericFor<"i32">): WasmLoadOp<"f64"> => ({ + op: "f64.load", + address, + }), + store: ( + address: WasmNumericFor<"i32">, + value: WasmNumericFor<"f64"> + ): WasmStoreOp<"f64"> => ({ op: "f64.store", address, value }), + + "~type": "f64", +} satisfies Builder<"f64">; + +const local = { + get: (label: WasmLabel | number): WasmLocalGet => ({ + op: "local.get", + label, + }), + set: (label: WasmLabel | number, right: WasmNumeric): WasmLocalSet => ({ + op: "local.set", + label, + right, + }), + tee: (label: WasmLabel | number, right: WasmNumeric): WasmLocalTee => ({ + op: "local.tee", + label, + right, + }), +} satisfies Builder<"local">; + +const global = { + get: (label: WasmLabel): WasmGlobalGet => ({ op: "global.get", label }), + set: (label: WasmLabel, right: WasmNumeric): WasmGlobalSet => ({ + op: "global.set", + label, + right, + }), +} satisfies Builder<"global">; + +const memory = { + copy: ( + destination: WasmNumericFor<"i32">, + source: WasmNumericFor<"i32">, + size: WasmNumericFor<"i32"> + ): WasmMemoryCopy => ({ op: "memory.copy", destination, source, size }), + + fill: ( + address: WasmNumericFor<"i32">, + value: WasmNumericFor<"i32">, + numOfBytes: WasmNumericFor<"i32"> + ): WasmMemoryFill => ({ op: "memory.fill", address, value, numOfBytes }), +} satisfies Builder<"memory">; + +type WasmBlockTypeHelper = { + params(...params: BuilderAsType[]): WasmBlockTypeHelper; + results(...results: BuilderAsType[]): WasmBlockTypeHelper; + locals(...locals: BuilderAsType[]): WasmBlockTypeHelper; + + body(...instrs: WasmInstruction[]): T; +}; + +type WasmIfBlockTypeHelper = { + params(...params: BuilderAsType[]): WasmIfBlockTypeHelper; + results(...results: 
BuilderAsType[]): WasmIfBlockTypeHelper; + locals(...locals: BuilderAsType[]): WasmIfBlockTypeHelper; + + then(...thenInstrs: WasmInstruction[]): WasmIf & { + else(...elseInstrs: WasmInstruction[]): WasmIf; + }; +}; + +type WasmFuncTypeHelper = { + params(params: Record): WasmFuncTypeHelper; + locals(locals: Record): WasmFuncTypeHelper; + results(...results: BuilderAsType[]): WasmFuncTypeHelper; + + body(...instrs: WasmInstruction[]): WasmFunction; +}; + +type WasmModuleHelper = { + imports(...imports: WasmImport[]): WasmModuleHelper; + globals(...globals: WasmGlobal[]): WasmModuleHelper; + datas(...datas: WasmData[]): WasmModuleHelper; + funcs(...funcs: WasmFunction[]): WasmModuleHelper; + startFunc(startFunc: WasmLabel): Omit; + exports(...exports: WasmExport[]): WasmModuleHelper; + + build(): WasmModule; +}; + +const blockLoopHelper = + (type: T) => + ( + label?: WasmLabel + ): WasmBlockTypeHelper => { + const blockType: Required = { + paramTypes: [], + resultTypes: [], + localTypes: [], + }; + return { + params(...params) { + blockType.paramTypes.push(...params.map((p) => p["~type"])); + return this; + }, + locals(...locals) { + blockType.localTypes.push(...locals.map((l) => l["~type"])); + return this; + }, + results(...results) { + blockType.resultTypes.push(...results.map((r) => r["~type"])); + return this; + }, + + body(...instrs: WasmInstruction[]) { + return { op: type, label, blockType, body: instrs } as T extends "block" + ? 
WasmBlock + : WasmLoop; + }, + }; + }; + +type BuilderMutableType = { + "~type": `mut ${T}`; +}; +const mut: { [T in WasmNumericType]: BuilderMutableType } = { + i32: { "~type": "mut i32" }, + i64: { "~type": "mut i64" }, + f32: { "~type": "mut f32" }, + f64: { "~type": "mut f64" }, +}; + +const wasm = { + block: blockLoopHelper("block"), + loop: blockLoopHelper("loop"), + if: (predicate: WasmNumeric, label?: WasmLabel): WasmIfBlockTypeHelper => { + const blockType: Required = { + paramTypes: [], + resultTypes: [], + localTypes: [], + }; + return { + params(...params) { + blockType.paramTypes.push(...params.map((p) => p["~type"])); + return this; + }, + results(...results) { + blockType.resultTypes.push(...results.map((r) => r["~type"])); + return this; + }, + locals(...locals) { + blockType.localTypes.push(...locals.map((l) => l["~type"])); + return this; + }, + + then: (...thenInstrs) => ({ + op: "if", + predicate, + label, + blockType, + thenBody: thenInstrs, + + else(...elseInstrs) { + return { ...this, elseBody: elseInstrs }; + }, + }), + }; + }, + drop: (value?: WasmInstruction): WasmDrop => ({ op: "drop", value }), + unreachable: (): WasmUnreachable => ({ op: "unreachable" }), + nop: (): WasmNop => ({ op: "nop" }), + br: (label: WasmLabel): WasmBr => ({ op: "br", label }), + br_table: ( + value: WasmNumeric, + ...labels: (WasmLabel | number)[] + ): WasmBrTable => ({ op: "br_table", labels, value }), + + call: ( + functionName: WasmLabel | WasmFunction + ): WasmCall & { args(...args: WasmNumeric[]): WasmCall } => ({ + op: "call", + function: + typeof functionName === "string" ? 
functionName : functionName.name, + arguments: [], + args(...args: WasmNumeric[]): WasmCall { + return { ...this, arguments: args }; + }, + }), + + return: (...values: WasmInstruction[]): WasmReturn => ({ + op: "return", + values, + }), + select: ( + first: WasmNumeric, + second: WasmNumeric, + condition: WasmNumeric + ): WasmSelect => ({ op: "select", first, second, condition }), + + import: (moduleName: string, itemName: string) => ({ + memory: (initial: number, maximum?: number): WasmImport => ({ + op: "import", + moduleName, + itemName, + externType: { type: "memory", limits: { initial, maximum } }, + }), + + func(name: WasmLabel) { + const funcType: Required = { + paramTypes: [], + resultTypes: [], + localTypes: [], + }; + const importInstr: WasmImport = { + op: "import", + moduleName, + itemName, + externType: { type: "func", name, funcType }, + }; + + return { + ...importInstr, + params(...params: BuilderAsType[]) { + funcType.paramTypes.push(...params.map((p) => p["~type"])); + return this; + }, + locals(...locals: BuilderAsType[]) { + funcType.localTypes.push(...locals.map((l) => l["~type"])); + return this; + }, + results(...results: BuilderAsType[]) { + funcType.resultTypes.push(...results.map((r) => r["~type"])); + return this; + }, + }; + }, + }), + + global: ( + name: WasmLabel, + valueType: BuilderAsType | BuilderMutableType + ) => ({ + init: (initialValue: WasmNumericFor): WasmGlobalFor => ({ + op: "global", + name, + valueType: valueType["~type"], + initialValue, + }), + }), + + data: (offset: WasmNumericFor<"i32">, data: string): WasmData => ({ + op: "data", + offset, + data, + }), + + func(name: WasmLabel): WasmFuncTypeHelper { + const funcType: WasmFuncType = { + paramTypes: {}, + resultTypes: [], + localTypes: {}, + }; + return { + params(params) { + for (const [key, val] of Object.entries(params)) + funcType.paramTypes[key as WasmLabel] = val["~type"]; + + return this; + }, + locals(locals) { + for (const [key, val] of Object.entries(locals)) 
+ funcType.localTypes[key as WasmLabel] = val["~type"]; + return this; + }, + results(...results) { + funcType.resultTypes.push(...results.map((r) => r["~type"])); + return this; + }, + + body(...instrs) { + return { op: "func", name, funcType, body: instrs }; + }, + }; + }, + + export: (name: string) => ({ + memory: (index: number): WasmExport => ({ + op: "export", + name, + externType: { type: "memory", index }, + }), + func: (identifier: WasmLabel): WasmExport => ({ + op: "export", + name, + externType: { type: "func", identifier }, + }), + }), + + module(): WasmModuleHelper { + const definitions: Omit = { + imports: [], + globals: [], + datas: [], + funcs: [], + startFunc: undefined, + exports: [], + }; + return { + imports(...imports) { + definitions.imports.push(...imports); + return this; + }, + globals(...globals) { + definitions.globals.push(...globals); + return this; + }, + datas(...datas) { + definitions.datas.push(...datas); + return this; + }, + funcs(...funcs) { + definitions.funcs.push(...funcs); + return this; + }, + startFunc(startFunc) { + definitions.startFunc = { op: "start", functionName: startFunc }; + return this; + }, + exports(...exports) { + definitions.exports.push(...exports); + return this; + }, + + build() { + return { op: "module", ...definitions }; + }, + }; + }, + + // not a WASM instruction, but a helper to build br_table with blocks + buildBrTableBlocks: ( + { labels, value }: WasmBrTable, + ...bodies: (WasmInstruction | WasmInstruction[])[] + ) => { + if (labels.length !== bodies.length) { + throw new Error( + `Number of labels in br_table (${labels.length}) does not match number of blocks (${bodies.length})` + ); + } + + const buildBlock = (index: number): [WasmBlock, ...WasmInstruction[]] => { + const body = bodies[index]; + if (!body) { + throw new Error( + `No body found for block at index ${index} in br_table` + ); + } + + return [ + wasm + .block(typeof labels[index] === "string" ? 
labels[index] : undefined) + .body( + ...(index === 0 + ? [wasm.br_table(value, ...labels)] + : buildBlock(index - 1)) + ), + ...(Array.isArray(body) ? body : [body]), + ]; + }; + + return buildBlock(bodies.length - 1); + }, + + raw: ( + codeFragments: TemplateStringsArray, + ...interpolations: (number | string | WasmInstruction | WasmInstruction[])[] + ): WasmRaw => ({ op: "raw", codeFragments, interpolations }), +}; + +// This maps all WASM instructions to a visitor method name that will +// be used in the interface for the watGenerator + +const instrToMethodMap = { + // numerics + "i64.const": "visitConstOp", + "f32.const": "visitConstOp", + "i32.const": "visitConstOp", + "f64.const": "visitConstOp", + + // numerics: binary ops and comparisons + ...typedFromEntries( + [...intBinaryOp, ...intComparisonOp].map( + (op) => [`i32.${op}`, "visitBinaryOp"] as const + ) + ), + ...typedFromEntries( + [...intBinaryOp, ...intComparisonOp].map( + (op) => [`i64.${op}`, "visitBinaryOp"] as const + ) + ), + ...typedFromEntries( + [...floatBinaryOp, ...floatComparisonOp].map( + (op) => [`f32.${op}`, "visitBinaryOp"] as const + ) + ), + ...typedFromEntries( + [...floatBinaryOp, ...floatComparisonOp].map( + (op) => [`f64.${op}`, "visitBinaryOp"] as const + ) + ), + + ...typedFromEntries( + [...i32ConversionOp, ...intConversionOp, ...intTestOp].map( + (op) => [`i32.${op}`, "visitUnaryOp"] as const + ) + ), + ...typedFromEntries( + [...i64ConversionOp, ...intConversionOp, ...intTestOp].map( + (op) => [`i64.${op}`, "visitUnaryOp"] as const + ) + ), + ...typedFromEntries( + [...f32ConversionOp, ...floatConversionOp, ...floatUnaryOp].map( + (op) => [`f32.${op}`, "visitUnaryOp"] as const + ) + ), + ...typedFromEntries( + [...f64ConversionOp, ...floatConversionOp, ...floatUnaryOp].map( + (op) => [`f64.${op}`, "visitUnaryOp"] as const + ) + ), + + // memory + "i32.load": "visitLoadOp", + "i64.load": "visitLoadOp", + "f32.load": "visitLoadOp", + "f64.load": "visitLoadOp", + "i32.load8_s": 
"visitLoadOp", + "i32.load8_u": "visitLoadOp", + "i32.load16_s": "visitLoadOp", + "i32.load16_u": "visitLoadOp", + "i64.load8_s": "visitLoadOp", + "i64.load8_u": "visitLoadOp", + "i64.load16_s": "visitLoadOp", + "i64.load16_u": "visitLoadOp", + "i64.load32_s": "visitLoadOp", + "i64.load32_u": "visitLoadOp", + + "i32.store": "visitStoreOp", + "i64.store": "visitStoreOp", + "f32.store": "visitStoreOp", + "f64.store": "visitStoreOp", + + "memory.copy": "visitMemoryCopyOp", + "memory.fill": "visitMemoryFillOp", + + // control + unreachable: "visitUnreachableOp", + drop: "visitDropOp", + nop: "visitNopOp", + block: "visitBlockOp", + loop: "visitLoopOp", + if: "visitIfOp", + br: "visitBrOp", + br_table: "visitBrTableOp", + call: "visitCallOp", + return: "visitReturnOp", + select: "visitSelectOp", + + // variables + "local.get": "visitVariableGetOp", + "global.get": "visitVariableGetOp", + "local.set": "visitVariableSetOp", + "global.set": "visitVariableSetOp", + "local.tee": "visitVariableSetOp", + + // module + import: "visitImportOp", + global: "visitGlobalOp", + data: "visitDataOp", + func: "visitFuncOp", + export: "visitExportOp", + start: "visitStartOp", + module: "visitModuleOp", + + raw: "visitRaw", +} as const satisfies Record; + +// ------------------------ WASM Visitor Interface ---------------------------- + +// This collects all the visitor method names (the values in the above object) +// and maps it to an actual method which takes as argument the specific +// WasmInstruction type corresponding to the instruction string, and returns +// the WAT string. +// Exception: For WasmNumeric unary and binary operations, since there are +// so many specific WasmInstruction types, we generalise them. + +type WatVisitor = { + [K in keyof typeof instrToMethodMap as (typeof instrToMethodMap)[K]]: ( + instruction: (typeof instrToMethodMap)[K] extends "visitUnaryOp" + ? { op: string; right: WasmInstruction } + : (typeof instrToMethodMap)[K] extends "visitBinaryOp" + ? 
{ op: string; left: WasmInstruction; right: WasmInstruction } + : Extract + ) => string; +}; + +export { + f32, + f64, + global, + i32, + i64, + instrToMethodMap, + local, + memory, + mut, + wasm, + type WatVisitor, +}; diff --git a/src/wasm-compiler/wasm-util/types.ts b/src/wasm-compiler/wasm-util/types.ts new file mode 100644 index 0000000..0b89515 --- /dev/null +++ b/src/wasm-compiler/wasm-util/types.ts @@ -0,0 +1,420 @@ +// ------------------------ WASM Numeric Types & Constants ---------------------------- + +export type WasmLabel = `$${string}`; + +export type WasmIntNumericType = "i32" | "i64"; +export type WasmFloatNumericType = "f32" | "f64"; +export type WasmNumericType = WasmIntNumericType | WasmFloatNumericType; + +export const floatUnaryOp = [ + "neg", + "abs", + "sqrt", + "ceil", + "floor", + "trunc", + "nearest", +] as const; +export const intBinaryOp = [ + "add", + "sub", + "mul", + "div_s", + "div_u", + "and", + "or", + "xor", + "shl", + "shr_s", + "shr_u", +] as const; +export const floatBinaryOp = ["add", "sub", "mul", "div"] as const; +export const intTestOp = ["eqz"] as const; +export const intComparisonOp = [ + "eq", + "ne", + "lt_s", + "lt_u", + "gt_s", + "gt_u", + "le_s", + "le_u", + "ge_s", + "ge_u", +] as const; +export const floatComparisonOp = ["eq", "ne", "lt", "gt", "le", "ge"] as const; +export const intConversionOp = [ + "trunc_f32_s", + "trunc_f32_u", + "trunc_f64_s", + "trunc_f64_u", +] as const; +export const i32ConversionOp = ["wrap_i64", "reinterpret_f32"] as const; +export const i64ConversionOp = [ + "extend_i32_s", + "extend_i32_u", + "reinterpret_f64", +] as const; +export const floatConversionOp = [ + "convert_i32_s", + "convert_i32_u", + "convert_i64_s", + "convert_i64_u", +] as const; +export const f32ConversionOp = ["demote_f64", "reinterpret_i32"] as const; +export const f64ConversionOp = ["promote_f32", "reinterpret_i64"] as const; +export const intLoadNarrowOp = ["8_s", "8_u", "16_s", "16_u"] as const; +export const 
i64LoadNarrowOp = ["32_s", "32_u"] as const; + +export type FloatUnaryOp = (typeof floatUnaryOp)[number]; +export type IntBinaryOp = (typeof intBinaryOp)[number]; +export type FloatBinaryOp = (typeof floatBinaryOp)[number]; +export type IntTestOp = (typeof intTestOp)[number]; +export type IntComparisonOp = (typeof intComparisonOp)[number]; +export type FloatComparisonOp = (typeof floatComparisonOp)[number]; +export type IntConversionOp = (typeof intConversionOp)[number]; +export type I32ConversionOp = (typeof i32ConversionOp)[number]; +export type I64ConversionOp = (typeof i64ConversionOp)[number]; +export type FloatConversionOp = (typeof floatConversionOp)[number]; +export type F32ConversionOp = (typeof f32ConversionOp)[number]; +export type F64ConversionOp = (typeof f64ConversionOp)[number]; +export type IntLoadNarrowOp = (typeof intLoadNarrowOp)[number]; +export type I64LoadNarrowOp = (typeof i64LoadNarrowOp)[number]; + +// ------------------------ WASM Numeric Instructions ---------------------------- + +export type WasmConst = { + op: `${T}.const`; + value: T extends WasmIntNumericType ? bigint : number; +}; + +export type WasmUnaryOp = { + [Op in FloatUnaryOp]: { op: `${T}.${Op}`; right: WasmNumericFor }; +}[FloatUnaryOp]; + +export type WasmBinaryOp = + T extends WasmIntNumericType + ? { + [Op in IntBinaryOp]: { + op: `${T}.${Op}`; + left: WasmNumericFor; + right: WasmNumericFor; + }; + }[IntBinaryOp] + : T extends WasmFloatNumericType + ? { + [Op in FloatBinaryOp]: { + op: `${T}.${Op}`; + left: WasmNumericFor; + right: WasmNumericFor; + }; + }[FloatBinaryOp] + : never; + +export type WasmIntTestOp = { + [Op in IntTestOp]: { op: `${T}.${Op}`; right: WasmNumericFor }; +}[IntTestOp]; + +export type WasmComparisonOp = + T extends WasmIntNumericType + ? { + [Op in IntComparisonOp]: { + op: `${T}.${Op}`; + left: WasmNumericFor; + right: WasmNumericFor; + }; + }[IntComparisonOp] + : T extends WasmFloatNumericType + ? 
{ + [Op in FloatComparisonOp]: { + op: `${T}.${Op}`; + left: WasmNumericFor; + right: WasmNumericFor; + }; + }[FloatComparisonOp] + : never; + +type ExtractConversion = I extends `${string}_${infer T}` + ? T extends WasmNumericType + ? T + : T extends `${infer U}_${string}` + ? U + : never + : never; + +type WasmConversionOpHelper = I extends + | `i32.${I32ConversionOp | IntConversionOp}` + | `i64.${I64ConversionOp | IntConversionOp}` + | `f32.${F32ConversionOp | FloatConversionOp}` + | `f64.${F64ConversionOp | FloatConversionOp}` + ? { op: I; right: WasmNumericFor> } + : never; + +export type WasmConversionOp = + WasmConversionOpHelper<`${T}.${T extends "i32" + ? I32ConversionOp | IntConversionOp + : T extends "i64" + ? I64ConversionOp | IntConversionOp + : T extends "f32" + ? F32ConversionOp | FloatConversionOp + : F64ConversionOp | FloatConversionOp}`>; + +export type WasmLoadOp = { + op: `${T}.load`; + address: WasmNumericFor<"i32">; +}; +export type WasmLoadNarrowOp = T extends "i32" + ? { + [Op in IntLoadNarrowOp]: { + op: `${T}.load${Op}`; + address: WasmNumericFor<"i32">; + }; + }[IntLoadNarrowOp] + : { + [Op in IntLoadNarrowOp | I64LoadNarrowOp]: { + op: `${T}.load${Op}`; + address: WasmNumericFor<"i32">; + }; + }[IntLoadNarrowOp | I64LoadNarrowOp]; +export type WasmLoad = + | WasmLoadOp<"i32"> + | WasmLoadOp<"i64"> + | WasmLoadOp<"f32"> + | WasmLoadOp<"f64"> + | WasmLoadNarrowOp<"i32"> + | WasmLoadNarrowOp<"i64">; + +export type WasmStoreOp = { + op: `${T}.store`; + address: WasmNumericFor<"i32">; + value: WasmNumericFor; +}; +export type WasmStore = + | WasmStoreOp<"i32"> + | WasmStoreOp<"i64"> + | WasmStoreOp<"f32"> + | WasmStoreOp<"f64">; + +export type WasmNumericFor = + | WasmConst + | (T extends WasmFloatNumericType ? WasmUnaryOp : never) + | WasmBinaryOp + | (T extends "i32" + ? 
WasmIntTestOp | WasmComparisonOp + : never) + | WasmConversionOp + | WasmRaw + + // below are not numeric instructions, but the results of these are numeric + | WasmLoad + | WasmLocalGet + | WasmGlobalGet + | WasmLocalTee + | WasmCall // call generates numeric[], but for type simplicity we assume just 1 + | WasmSelect; + +export type WasmNumeric = + | WasmNumericFor<"i32"> + | WasmNumericFor<"i64"> + | WasmNumericFor<"f32"> + | WasmNumericFor<"f64">; + +// ------------------------ WASM Variable Instructions ---------------------------- + +export type WasmLocalSet = { + op: "local.set"; + label: WasmLabel | number; + right: WasmNumeric; +}; +export type WasmLocalGet = { op: "local.get"; label: WasmLabel | number }; +export type WasmLocalTee = { + op: "local.tee"; + label: WasmLabel | number; + right: WasmNumeric; +}; +export type WasmGlobalSet = { + op: "global.set"; + label: WasmLabel; + right: WasmNumeric; +}; +export type WasmGlobalGet = { op: "global.get"; label: WasmLabel }; + +type WasmVariable = + | WasmLocalSet + | WasmLocalGet + | WasmLocalTee + | WasmGlobalSet + | WasmGlobalGet + | WasmRaw; + +// ------------------------ WASM Memory Instructions ---------------------------- +// Technically WasmStoreOp and WasmLoadOp are memory instructions, but they are defined +// together with numerics for typing. 
+ +export type WasmMemoryCopy = { + op: "memory.copy"; + destination: WasmNumericFor<"i32">; + source: WasmNumericFor<"i32">; + size: WasmNumericFor<"i32">; +}; + +export type WasmMemoryFill = { + op: "memory.fill"; + address: WasmNumericFor<"i32">; + value: WasmNumericFor<"i32">; + numOfBytes: WasmNumericFor<"i32">; +}; + +type WasmMemory = + | WasmMemoryCopy + | WasmMemoryFill + | WasmLoad + | WasmStore + | WasmRaw; + +// ------------------------ WASM Control Instructions ---------------------------- + +export type WasmUnreachable = { op: "unreachable" }; +export type WasmDrop = { op: "drop"; value: WasmInstruction | undefined }; +export type WasmNop = { op: "nop" }; + +export type WasmBlockType = { + paramTypes: WasmNumericType[]; + resultTypes: WasmNumericType[]; + localTypes?: WasmNumericType[]; +}; + +export type WasmBlockBase = { + label: WasmLabel | undefined; + blockType: WasmBlockType; +}; +export type WasmBlock = WasmBlockBase & { + op: "block"; + body: WasmInstruction[]; +}; +export type WasmLoop = WasmBlockBase & { op: "loop"; body: WasmInstruction[] }; +export type WasmIf = WasmBlockBase & { + op: "if"; + predicate: WasmNumeric; + thenBody: WasmInstruction[]; + elseBody?: WasmInstruction[]; +}; +export type WasmBr = { op: "br"; label: WasmLabel }; +export type WasmBrTable = { + op: "br_table"; + labels: (WasmLabel | number)[]; + value: WasmNumeric; +}; +export type WasmCall = { + op: "call"; + function: WasmLabel; + arguments: WasmNumeric[]; +}; +export type WasmReturn = { op: "return"; values: WasmInstruction[] }; +export type WasmSelect = { + op: "select"; + first: WasmNumeric; + second: WasmNumeric; + condition: WasmNumeric; +}; + +type WasmControl = + | WasmUnreachable + | WasmDrop + | WasmNop + | WasmBlock + | WasmLoop + | WasmIf + | WasmBr + | WasmBrTable + | WasmCall + | WasmReturn + | WasmSelect + | WasmRaw; + +// ------------------------ WASM Module Instructions ---------------------------- + +export type WasmLocals = Record; +export type 
WasmFuncType = { + paramTypes: WasmLocals; + resultTypes: WasmNumericType[]; + localTypes: WasmLocals; +}; + +export type WasmExternType = + | { type: "memory"; limits: { initial: number; maximum: number | undefined } } + | { type: "func"; name: WasmLabel; funcType: WasmBlockType }; +export type WasmImport = { + op: "import"; + moduleName: string; + itemName: string; + externType: WasmExternType; +}; + +export type WasmGlobalFor = { + op: "global"; + name: WasmLabel; + valueType: T | `mut ${T}`; + initialValue: WasmNumericFor; +}; +export type WasmGlobal = + | WasmGlobalFor<"i32"> + | WasmGlobalFor<"i64"> + | WasmGlobalFor<"f32"> + | WasmGlobalFor<"f64">; + +export type WasmData = { + op: "data"; + offset: WasmNumericFor<"i32">; + data: string; +}; +export type WasmFunction = { + op: "func"; + name: WasmLabel; + funcType: WasmFuncType; + body: WasmInstruction[]; +}; +export type WasmStart = { op: "start"; functionName: WasmLabel }; + +export type WasmExternIdx = + | { type: "func"; identifier: WasmLabel } + | { type: "memory"; index: number }; +export type WasmExport = { + op: "export"; + name: string; + externType: WasmExternIdx; +}; + +export type WasmModule = { + op: "module"; + imports: WasmImport[]; + globals: WasmGlobal[]; + datas: WasmData[]; + funcs: WasmFunction[]; + startFunc: WasmStart | undefined; + exports: WasmExport[]; +}; + +export type WasmModuleInstruction = + | WasmImport + | WasmGlobal + | WasmData + | WasmFunction + | WasmExport + | WasmStart + | WasmModule + | WasmRaw; + +// meant to be used with wasm.raw (tagged template) +export type WasmRaw = { + op: "raw"; + codeFragments: TemplateStringsArray; + interpolations: (number | string | WasmInstruction | WasmInstruction[])[]; +}; + +export type WasmInstruction = + | WasmNumeric + | WasmMemory + | WasmControl + | WasmVariable + | WasmModuleInstruction; diff --git a/src/wasm-compiler/wasm-util/util.ts b/src/wasm-compiler/wasm-util/util.ts new file mode 100644 index 0000000..91cda3f --- /dev/null 
+++ b/src/wasm-compiler/wasm-util/util.ts @@ -0,0 +1,5 @@ +export const typedFromEntries = < + const T extends readonly [PropertyKey, unknown][] +>( + entries: T +) => Object.fromEntries(entries) as { [K in T[number] as K[0]]: K[1] }; diff --git a/src/wasm-compiler/wasm-util/watGenerator.ts b/src/wasm-compiler/wasm-util/watGenerator.ts new file mode 100644 index 0000000..7035084 --- /dev/null +++ b/src/wasm-compiler/wasm-util/watGenerator.ts @@ -0,0 +1,276 @@ +import { instrToMethodMap, type WatVisitor } from "./builder"; +import type { + WasmBlock, + WasmBlockType, + WasmBr, + WasmBrTable, + WasmCall, + WasmConst, + WasmData, + WasmDrop, + WasmExport, + WasmFunction, + WasmGlobal, + WasmGlobalGet, + WasmGlobalSet, + WasmIf, + WasmImport, + WasmInstruction, + WasmLoad, + WasmLocalGet, + WasmLocalSet, + WasmLocalTee, + WasmLoop, + WasmMemoryCopy, + WasmMemoryFill, + WasmModule, + WasmNop, + WasmNumericType, + WasmRaw, + WasmReturn, + WasmSelect, + WasmStart, + WasmStore, + WasmUnreachable, +} from "./types"; + +export class WatGenerator implements WatVisitor { + // Dispatch method + visit(instruction: WasmInstruction): string { + return this[instrToMethodMap[instruction.op]](instruction as any); + } + // Numeric visitor methods + visitConstOp(instruction: WasmConst): string { + return `(${instruction.op} ${instruction.value})`; + } + visitUnaryOp(instruction: { op: string; right: WasmInstruction }): string { + const right = this.visit(instruction.right); + return `(${instruction.op} ${right})`; + } + visitBinaryOp(instruction: { + op: string; + left: WasmInstruction; + right: WasmInstruction; + }): string { + const left = this.visit(instruction.left); + const right = this.visit(instruction.right); + return `(${instruction.op} ${left} ${right})`; + } + + // Memory visitor methods + visitStoreOp(instruction: WasmStore): string { + const address = this.visit(instruction.address); + const value = this.visit(instruction.value); + return `(${instruction.op} ${address} 
${value})`; + } + visitLoadOp(instruction: WasmLoad): string { + const address = this.visit(instruction.address); + return `(${instruction.op} ${address})`; + } + visitMemoryCopyOp(instruction: WasmMemoryCopy): string { + const dest = this.visit(instruction.destination); + const src = this.visit(instruction.source); + const size = this.visit(instruction.size); + return `(${instruction.op} ${dest} ${src} ${size})`; + } + visitMemoryFillOp(instruction: WasmMemoryFill): string { + const addr = this.visit(instruction.address); + const value = this.visit(instruction.value); + const size = this.visit(instruction.numOfBytes); + return `(${instruction.op} ${addr} ${value} ${size})`; + } + + // Control visitor methods + private visitBlockType(type: WasmBlockType): string { + const params = type.paramTypes.map((param) => `(param ${param})`).join(" "); + const results = type.resultTypes + .map((result) => `(result ${result})`) + .join(" "); + const locals = + type.localTypes?.map((local) => `(local ${local})`).join(" ") ?? ""; + return `${params} ${results} ${locals}`; + } + + visitBlockOp(instruction: WasmBlock): string { + const label = instruction.label ?? ""; + const typeStr = this.visitBlockType(instruction.blockType); + const body = instruction.body.map((instr) => this.visit(instr)).join(" "); + return `(${instruction.op} ${label} ${typeStr} ${body})`; + } + visitLoopOp(instruction: WasmLoop): string { + const label = instruction.label ?? ""; + const typeStr = this.visitBlockType(instruction.blockType); + const body = instruction.body.map((instr) => this.visit(instr)).join(" "); + return `(${instruction.op} ${label} ${typeStr} ${body})`; + } + visitIfOp(instruction: WasmIf): string { + const label = instruction.label ?? 
""; + const typeStr = this.visitBlockType(instruction.blockType); + const condition = this.visit(instruction.predicate); + const thenBody = instruction.thenBody + .map((instr) => this.visit(instr)) + .join(" "); + const elseBody = instruction.elseBody + ?.map((instr) => this.visit(instr)) + .join(" "); + + if (elseBody) { + return `(if ${label} ${typeStr} ${condition} (then ${thenBody}) (else ${elseBody}))`; + } else { + return `(if ${label} ${typeStr} ${condition} (then ${thenBody}))`; + } + } + visitUnreachableOp(instruction: WasmUnreachable): string { + return `(${instruction.op})`; + } + visitDropOp(instruction: WasmDrop): string { + const value = instruction.value ? this.visit(instruction.value) : ""; + return `(${instruction.op} ${value})`; + } + visitNopOp(instruction: WasmNop): string { + return `(${instruction.op})`; + } + visitBrOp(instruction: WasmBr): string { + return `(${instruction.op} ${instruction.label})`; + } + visitBrTableOp(instruction: WasmBrTable): string { + const value = this.visit(instruction.value); + const labels = instruction.labels.join(" "); + return `(${instruction.op} ${labels} ${value})`; + } + visitCallOp(instruction: WasmCall): string { + const args = instruction.arguments.map((arg) => this.visit(arg)).join(" "); + return `(${instruction.op} ${instruction.function} ${args})`; + } + visitReturnOp(instruction: WasmReturn): string { + const values = instruction.values + .map((value) => this.visit(value)) + .join(" "); + return `(${instruction.op} ${values})`; + } + visitSelectOp(instruction: WasmSelect): string { + const first = this.visit(instruction.first); + const second = this.visit(instruction.second); + const condition = this.visit(instruction.condition); + return `(${instruction.op} ${first} ${second} ${condition})`; + } + + // Variable visitor methods + visitVariableGetOp(instruction: WasmLocalGet | WasmGlobalGet): string { + return `(${instruction.op} ${instruction.label})`; + } + visitVariableSetOp( + instruction: 
WasmLocalSet | WasmLocalTee | WasmGlobalSet + ): string { + const right = this.visit(instruction.right); + return `(${instruction.op} ${instruction.label} ${right})`; + } + + // Module visitor methods + visitImportOp(instruction: WasmImport): string { + let externTypeStr: string; + + if (instruction.externType.type === "func") { + const params = instruction.externType.funcType.paramTypes + .map((type) => `(param ${type})`) + .join(" "); + + const results = instruction.externType.funcType.resultTypes + .map((type) => `(result ${type})`) + .join(" "); + + externTypeStr = `(func ${instruction.externType.name} ${params} ${results})`; + } else if (instruction.externType.type === "memory") { + const min = instruction.externType.limits.initial; + const max = instruction.externType.limits.maximum ?? ""; + externTypeStr = `(memory ${min} ${max})`; + } else { + const _exhaustiveCheck: never = instruction.externType; + throw new Error(`Unsupported import type: ${_exhaustiveCheck}`); + } + + return `(${instruction.op} "${instruction.moduleName}" "${instruction.itemName}" ${externTypeStr})`; + } + visitGlobalOp(instruction: WasmGlobal): string { + const init = this.visit(instruction.initialValue); + return `(${instruction.op} ${instruction.name} (${instruction.valueType}) ${init})`; + } + visitDataOp(instruction: WasmData): string { + const offset = this.visit(instruction.offset); + return `(${instruction.op} ${offset} "${instruction.data}")`; + } + visitFuncOp(instruction: WasmFunction): string { + const params = Object.entries(instruction.funcType.paramTypes) + .map(([name, type]) => `(param ${name} ${type})`) + .join(" "); + + const results = instruction.funcType.resultTypes.length + ? 
`(result ${instruction.funcType.resultTypes.join(" ")})` + : ""; + + const locals = Object.entries(instruction.funcType.localTypes) + .map(([name, type]) => `(local ${name} ${type})`) + .join(" "); + + const body = instruction.body.map((instr) => this.visit(instr)).join(" "); + + return `(${instruction.op} ${instruction.name} ${params} ${results} ${locals} ${body})`; + } + visitExportOp(instruction: WasmExport): string { + let externTypeStr: string; + + if (instruction.externType.type === "func") { + externTypeStr = `(func ${instruction.externType.identifier})`; + } else if (instruction.externType.type === "memory") { + externTypeStr = `(memory ${instruction.externType.index})`; + } else { + const _exhaustiveCheck: never = instruction.externType; + throw new Error(`Unsupported export type: ${_exhaustiveCheck}`); + } + return `(${instruction.op} "${instruction.name}" ${externTypeStr})`; + } + visitStartOp(instruction: WasmStart): string { + return `(${instruction.op} ${instruction.functionName})`; + } + + visitModuleOp(instruction: WasmModule): string { + const imports = instruction.imports + .map((i) => ` ${this.visit(i)}\n`) + .join(""); + + const globals = instruction.globals + .map((g) => ` ${this.visit(g)}\n`) + .join(""); + + const datas = instruction.datas.map((d) => ` ${this.visit(d)}\n`).join(""); + + const funcs = instruction.funcs.map((f) => ` ${this.visit(f)}\n`).join(""); + + const startFunc = instruction.startFunc + ? 
this.visit(instruction.startFunc) + : ""; + + const exports = instruction.exports + .map((e) => ` ${this.visit(e)}\n`) + .join(""); + + return `(${instruction.op}\n${imports}\n${globals}\n${datas}\n${funcs}\n${startFunc}\n${exports})`; + } + + visitRaw(instruction: WasmRaw): string { + let code = ""; + for (let i = 0; i < instruction.interpolations.length; i++) { + code += instruction.codeFragments[i]; + const interp = instruction.interpolations[i]; + if (typeof interp === "string" || typeof interp === "number") { + code += interp.toString(); + } else if (Array.isArray(interp)) { + code += interp.map((instr) => this.visit(instr)).join(" "); + } else { + code += this.visit(interp); + } + } + code += instruction.codeFragments[instruction.interpolations.length]; + return code; + } +} diff --git a/test.py b/test.py deleted file mode 100644 index dc233a3..0000000 --- a/test.py +++ /dev/null @@ -1 +0,0 @@ -add_one = lambda : None