diff --git a/Project.toml b/Project.toml index 9b0a850..bca0359 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" NodeCall = "84d67a5e-1aa2-4817-a501-81e0ebf70bff" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +WasmCompiler = "d6c8c267-1a52-47cb-afc2-bc072b58c422" libjlnode_jll = "e3ef64f0-f261-50ad-b884-29c8a26457f8" libnode_jll = "76d26698-d9ba-5ca1-ae24-4ac9393d02a0" @@ -22,7 +23,7 @@ CodeInfoTools = "0.3" CompTime = "0.1" Dictionaries = "0.3" GPUCompiler = "0.24" -NodeCall = "1.1" +NodeCall = "1.1" Reexport = "1.2" Unrolled = "0.1" julia = "1.7" diff --git a/src/WebAssemblyCompiler.jl b/src/WebAssemblyCompiler.jl index 5d841f2..6a8946e 100644 --- a/src/WebAssemblyCompiler.jl +++ b/src/WebAssemblyCompiler.jl @@ -1,6 +1,7 @@ module WebAssemblyCompiler using Binaryen_jll +using WasmCompiler include("../lib/LibBinaryen.jl") using .LibBinaryen diff --git a/src/_compile.jl b/src/_compile.jl index 1eebff7..a8edbcd 100644 --- a/src/_compile.jl +++ b/src/_compile.jl @@ -10,30 +10,26 @@ function _compile(ctx::CompilerContext, x::Core.Argument; kw...) end # If at the top level or if it's not a callable struct, # we don't include the fun as the first argument. - BinaryenLocalGet(ctx.mod, argmap(ctx, x.n) - 1, - gettype(ctx, type)) + WasmCompiler.InstOperands(WasmCompiler.local_get(argmap(ctx, x.n)), []) end function _compile(ctx::CompilerContext, x::Core.SSAValue; kw...) # These come after the function arguments. bt = basetype(ctx, x) if Base.issingletontype(bt) getglobal(ctx, _compile(ctx, nothing)) else - BinaryenLocalGet(ctx.mod, ctx.varmap[x.id], - gettype(ctx, ssatype(ctx, x.id))) + InstOperands(WC.local_get(ctx.varmap[x.id]), []) end - # localid = ctx.varmap[x.id] - # BinaryenLocalGet(ctx.mod, localid, ctx.locals[localid]) end -_compile(ctx::CompilerContext, x::Float64; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralFloat64(x)) -_compile(ctx::CompilerContext, x::Float32; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralFloat32(x)) -_compile(ctx::CompilerContext, x::Int64; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt64(x)) -_compile(ctx::CompilerContext, x::Int32; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt32(x)) -_compile(ctx::CompilerContext, x::UInt8; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt32(x)) -_compile(ctx::CompilerContext, x::Int8; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt32(x)) -_compile(ctx::CompilerContext, x::UInt64; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt64(reinterpret(Int64, x))) -_compile(ctx::CompilerContext, x::UInt32; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt32(reinterpret(Int32, x))) -_compile(ctx::CompilerContext, x::Bool; kw...) = BinaryenConst(ctx.mod, BinaryenLiteralInt32(x)) -_compile(ctx::CompilerContext, x::Ptr{BinaryenExpression}; kw...) = x +_compile(ctx::CompilerContext, x::Float64; kw...) = InstOperands(WC.f64_const(x), []) +_compile(ctx::CompilerContext, x::Float32; kw...) = InstOperands(WC.f32_const(x), []) +_compile(ctx::CompilerContext, x::Int64; kw...) = InstOperands(WC.i64_const(x), []) +_compile(ctx::CompilerContext, x::Int32; kw...) = InstOperands(WC.i32_const(x), []) +_compile(ctx::CompilerContext, x::UInt8; kw...) = InstOperands(WC.i32_const(x), []) +_compile(ctx::CompilerContext, x::Int8; kw...) = InstOperands(WC.i32_const(x), []) +_compile(ctx::CompilerContext, x::UInt64; kw...) = InstOperands(WC.i64_const(reinterpret(Int64, x)), []) +_compile(ctx::CompilerContext, x::UInt32; kw...) = InstOperands(WC.i32_const(reinterpret(Int32, x)), []) +_compile(ctx::CompilerContext, x::Bool; kw...) = InstOperands(WC.i32_const(x), []) +# _compile(ctx::CompilerContext, x::Ptr{BinaryenExpression}; kw...) = x _compile(ctx::CompilerContext, x::GlobalRef; kw...) = getglobal(ctx, x.mod, x.name) _compile(ctx::CompilerContext, x::QuoteNode; kw...) = _compile(ctx, x.value) # _compile(ctx::CompilerContext, x::String; globals = false, kw...) = globals ? diff --git a/src/compile_block.jl b/src/compile_block.jl index f6d3aa2..d7a5026 100644 --- a/src/compile_block.jl +++ b/src/compile_block.jl @@ -2,9 +2,9 @@ const CCCallInfo = Core.Compiler.CallInfo else const CCCallInfo = Any -end - -function update!(ctx::CompilerContext, x, localtype = nothing) +end + +function update!(ctx::CompilerContext, x, localtype=nothing) # TODO: check the type of x, compare that with the wasm type of localtype, and if they differ, # convert one to the other. Hopefully, this is mainly Int32 -> Int64. push!(ctx.body, x) @@ -13,26 +13,26 @@ function update!(ctx::CompilerContext, x, localtype = nothing) ctx.localidx += 1 end # BinaryenExpressionPrint(x) - s = _debug_binaryen_get(ctx, x) - debug(:offline) && _debug_binaryen(ctx, x) + # s = _debug_binaryen_get(ctx, x) + # debug(:offline) && _debug_binaryen(ctx, x) return nothing end -function setlocal!(ctx, idx, x; set = true, drop = false) +function setlocal!(ctx, idx, x; set=true, drop=false) T = ssatype(ctx, idx) if T != Union{} && set # && T != Nothing && set # && T != Any ctx.varmap[idx] = ctx.localidx - x = BinaryenLocalSet(ctx.mod, ctx.localidx, x) + x = WasmCompiler.InstOperands(WasmCompiler.local_set(ctx.localidx), [x]) update!(ctx, x, T) else if drop - x = BinaryenDrop(ctx.mod, x) + x = WasmCompiler.InstOperands(WasmCompiler.drop(), [x]) end debug(:offline) && _debug_binaryen(ctx, x) push!(ctx.body, x) end end -function binaryfun(ctx, idx, bfuns, a, b; adjustsizes = true) +function binaryfun(ctx, idx, bfuns, a, b; adjustsizes=true) ctx.varmap[idx] = ctx.localidx Ta = roottype(ctx, a) Tb = roottype(ctx, b) @@ -40,25 +40,21 @@ function binaryfun(ctx, idx, bfuns, a, b; adjustsizes = true) _b = _compile(ctx, b) if adjustsizes && sizeof(Ta) !== sizeof(Tb) # try to make the sizes the same, at least for Integers if sizeof(Ta) == 4 - _b = BinaryenUnary(ctx.mod, BinaryenWrapInt64(), _b) + _b = InstOperands(WC.i32_wrap_i64(), [_b]) elseif sizeof(Ta) == 8 - _b = BinaryenUnary(ctx.mod, Tb <: Signed ? BinaryenExtendSInt32() : BinaryenExtendUInt32(), _b) + _b = InstOperands(Tb <: Signed ? WC.i64_extend_i32_s() : WC.i64_extend_i32_u(), [_b]) end end - x = BinaryenBinary(ctx.mod, - sizeof(Ta) < 8 && length(bfuns) > 1 ? bfuns[2]() : bfuns[1](), - _a, - _b) + x = InstOperands(sizeof(Ta) < 8 && length(bfuns) > 1 ? bfuns[2]() : bfuns[1](), [_a, _b]) setlocal!(ctx, idx, x) end function unaryfun(ctx, idx, bfuns, a) - x = BinaryenUnary(ctx.mod, - sizeof(roottype(ctx, a)) < 8 && length(bfuns) > 1 ? bfuns[2]() : bfuns[1](), - _compile(ctx, a)) + x = WasmCompiler.InstOperands(sizeof(roottype(ctx, a)) < 8 && length(bfuns) > 1 ? bfuns[2]() : bfuns[1](), + [_compile(ctx, a)]) setlocal!(ctx, idx, x) end -function binaryenfun(ctx, idx, bfun, args...; passall = false) - x = bfun(ctx.mod, (passall ? a : _compile(ctx, a) for a in args)...) +function binaryenfun(ctx, idx, bfun, args...; passall=false) + x = bfun(ctx.mod, (passall ? a : _compile(ctx, a) for a in args)...) setlocal!(ctx, idx, x) end @@ -67,23 +63,29 @@ end function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) idxs = cfg.blocks[idx].stmts ci = ctx.ci - ctx.body = BinaryenExpressionRef[] + ctx.body = WC.InstOperands[] + + cond_val = nothing + for idx in idxs node = ci.code[idx] debug(:inline) && @show idx node ssatype(ctx, idx) debug(:offline) && _debug_line(ctx, idx, node) - if node isa Union{Core.GotoNode, Core.GotoIfNot, Core.PhiNode, Nothing} + if node isa Core.GotoIfNot + cond_val = _compile(ctx, node.cond) + + elseif node isa Union{Core.GotoNode,Core.PhiNode,Nothing} # do nothing elseif node == Core.ReturnNode() - update!(ctx, BinaryenUnreachable(ctx.mod)) + update!(ctx, InstOperands(WC.unreachable(), [])) elseif node isa Core.ReturnNode - val = node.val isa GlobalRef ? Core.eval(node.val.mod, node.val.name) : + val = node.val isa GlobalRef ? Core.eval(node.val.mod, node.val.name) : node.val isa Core.Const ? node.val.val : node.val - update!(ctx, BinaryenReturn(ctx.mod, roottype(ctx, val) == Nothing ? C_NULL : _compile(ctx, val))) + update!(ctx, InstOperands(WC.return_(), roottype(ctx, val) == Nothing ? [] : [_compile(ctx, val)])) elseif node isa Core.PiNode fromT = roottype(ctx, node.val) @@ -94,498 +96,491 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) end setlocal!(ctx, idx, x) - ## Intrinsics ## + ## Intrinsics ## elseif matchgr(node, :neg_int) do a - T = roottype(ctx, a) - binaryfun(ctx, idx, (BinaryenSubInt64, BinaryenSubInt32), T(0), a) - end + T = roottype(ctx, a) + binaryfun(ctx, idx, (WC.i64_sub, WC.i32_sub), T(0), a) + end elseif matchgr(node, :neg_float) do a - unaryfun(ctx, idx, (BinaryenNegFloat64, BinaryenNegFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_neg, WC.f32_neg), a) + end elseif matchgr(node, :add_int) do a, b - binaryfun(ctx, idx, (BinaryenAddInt64, BinaryenAddInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_add, WC.i32_add), a, b) + end elseif matchgr(node, :sub_int) do a, b - binaryfun(ctx, idx, (BinaryenSubInt64, BinaryenSubInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_sub, WC.i32_sub), a, b) + end elseif matchgr(node, :mul_int) do a, b - binaryfun(ctx, idx, (BinaryenMulInt64, BinaryenMulInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_mul, WC.i32_mul), a, b) + end elseif matchgr(node, :sdiv_int) do a, b - binaryfun(ctx, idx, (BinaryenDivSInt64, BinaryenDivSInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_div_s, WC.i32_div_s), a, b) + end elseif matchgr(node, :checked_sdiv_int) do a, b - binaryfun(ctx, idx, (BinaryenDivSInt64, BinaryenDivSInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_div_s, WC.i32_div_s), a, b) + end elseif matchgr(node, :udiv_int) do a, b - binaryfun(ctx, idx, (BinaryenDivUInt64, BinaryenDivUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_div_u, WC.i32_div_u), a, b) + end elseif matchgr(node, :checked_udiv_int) do a, b - binaryfun(ctx, idx, (BinaryenDivUInt64, BinaryenDivUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_div_u, WC.i32_div_u), a, b) + end elseif matchgr(node, :srem_int) do a, b - binaryfun(ctx, idx, (BinaryenRemSInt64, BinaryenRemSInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_rem_s, WC.i32_rem_s), a, b) + end elseif matchgr(node, :checked_srem_int) do a, b # LIES - it isn't checked - binaryfun(ctx, idx, (BinaryenRemSInt64, BinaryenRemSInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_rem_s, WC.i32_rem_s), a, b) + end elseif matchgr(node, :urem_int) do a, b - binaryfun(ctx, idx, (BinaryenRemUInt64, BinaryenRemUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_rem_u, WC.i32_rem_u), a, b) + end elseif matchgr(node, :checked_urem_int) do a, b - binaryfun(ctx, idx, (BinaryenRemUInt64, BinaryenRemUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_rem_u, WC.i32_rem_u), a, b) + end elseif matchgr(node, :add_float) do a, b - binaryfun(ctx, idx, (BinaryenAddFloat64, BinaryenAddFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_add, WC.f32_add), a, b) + end elseif matchgr(node, :add_float_fast) do a, b - binaryfun(ctx, idx, (BinaryenAddFloat64, BinaryenAddFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_add, WC.f32_add), a, b) + end elseif matchgr(node, :sub_float) do a, b - binaryfun(ctx, idx, (BinaryenSubFloat64, BinaryenSubFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_sub, WC.f32_sub), a, b) + end elseif matchgr(node, :mul_float) do a, b - binaryfun(ctx, idx, (BinaryenMulFloat64, BinaryenMulFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_mul, WC.f32_mul), a, b) + end elseif matchgr(node, :mul_float_fast) do a, b - binaryfun(ctx, idx, (BinaryenMulFloat64, BinaryenMulFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_mul, WC.f32_mul), a, b) + end elseif matchgr(node, :muladd_float) do a, b, c - ab = BinaryenBinary(ctx.mod, - sizeof(roottype(ctx, a)) < 8 ? BinaryenMulFloat32() : BinaryenMulFloat64(), - _compile(ctx, a), - _compile(ctx, b)) - binaryfun(ctx, idx, (BinaryenAddFloat64, BinaryenAddFloat32), c, Pass(ab), adjustsizes = false) - end + ab = InstOperands(sizeof(roottype(ctx, a)) < 8 ? WC.f32_mul() : WC.f64_mul(), + [_compile(ctx, a), _compile(ctx, b)]) + binaryfun(ctx, idx, (WC.f64_add, WC.f32_add), c, Pass(ab), adjustsizes=false) + end elseif matchgr(node, :fma_float) do a, b, c - ab = BinaryenBinary(ctx.mod, - sizeof(roottype(ctx, a)) < 8 ? BinaryenMulFloat32() : BinaryenMulFloat64(), - _compile(ctx, a), - _compile(ctx, b)) - binaryfun(ctx, idx, (BinaryenAddFloat64, BinaryenAddFloat32), Pass(ab), c) - end + ab = InstOperands(sizeof(roottype(ctx, a)) < 8 ? WC.f32_mul() : WC.f64_mul(), + [_compile(ctx, a), _compile(ctx, b)]) + binaryfun(ctx, idx, (WC.f64_add, WC.f32_add), Pass(ab), c) + end elseif matchgr(node, :div_float) do a, b - binaryfun(ctx, idx, (BinaryenDivFloat64, BinaryenDivFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_div, WC.f32_div), a, b) + end elseif matchgr(node, :div_float_fast) do a, b - binaryfun(ctx, idx, (BinaryenDivFloat64, BinaryenDivFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_div, WC.f32_div), a, b) + end elseif matchgr(node, :eq_int) do a, b - binaryfun(ctx, idx, (BinaryenEqInt64, BinaryenEqInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_eq, WC.i32_eq), a, b) + end elseif matchgr(node, :(===)) do a, b - if roottype(ctx, a) <: Integer - binaryfun(ctx, idx, (BinaryenEqInt64, BinaryenEqInt32), a, b) - elseif roottype(ctx, a) <: AbstractFloat - binaryfun(ctx, idx, (BinaryenEqFloat64, BinaryenEqFloat32), a, b) + if roottype(ctx, a) <: Integer + binaryfun(ctx, idx, (WC.i64_eq, WC.i32_eq), a, b) + elseif roottype(ctx, a) <: AbstractFloat + binaryfun(ctx, idx, (WC.f64_eq, WC.f32_eq), a, b) # elseif roottype(ctx, a) <: Union{String, Symbol} # x = BinaryenStringEq(ctx.mod, BinaryenStringEqEqual(), _compile(ctx, a), _compile(ctx, b)) # setlocal!(ctx, idx, x) - else - x = BinaryenRefEq(ctx.mod, _compile(ctx, a), _compile(ctx, b)) - setlocal!(ctx, idx, x) - end + else + x = BinaryenRefEq(ctx.mod, _compile(ctx, a), _compile(ctx, b)) + setlocal!(ctx, idx, x) end + end elseif matchgr(node, :ne_int) do a, b - binaryfun(ctx, idx, (BinaryenNeInt64, BinaryenNeInt32), a, b) - end + binaryfun(ctx, idx, (BinaryenNeInt64, BinaryenNeInt32), a, b) + end elseif matchgr(node, :slt_int) do a, b - binaryfun(ctx, idx, (BinaryenLtSInt64, BinaryenLtSInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_lt_s, WC.i32_lt_s), a, b) + end elseif matchgr(node, :ult_int) do a, b - binaryfun(ctx, idx, (BinaryenLtUInt64, BinaryenLtUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_lt_u, WC.i32_lt_u), a, b) + end elseif matchgr(node, :sle_int) do a, b - binaryfun(ctx, idx, (BinaryenLeSInt64, BinaryenLeSInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_le_s, WC.i32_le_s), a, b) + end elseif matchgr(node, :ule_int) do a, b - binaryfun(ctx, idx, (BinaryenLeUInt64, BinaryenLeUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_le_u, WC.i32_le_u), a, b) + end elseif matchgr(node, :eq_float) do a, b - binaryfun(ctx, idx, (BinaryenEqFloat64, BinaryenEqFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_eq, WC.f32_eq), a, b) + end elseif matchgr(node, :fpiseq) do a, b - binaryfun(ctx, idx, (BinaryenEqFloat64, BinaryenEqFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_eq, WC.f32_eq), a, b) + end elseif matchgr(node, :ne_float) do a, b - binaryfun(ctx, idx, (BinaryenNeFloat64, BinaryenNeFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_ne, WC.f32_ne), a, b) + end elseif matchgr(node, :lt_float) do a, b - binaryfun(ctx, idx, (BinaryenLtFloat64, BinaryenLtFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_lt, WC.f32_lt), a, b) + end elseif matchgr(node, :le_float) do a, b - binaryfun(ctx, idx, (BinaryenLeFloat64, BinaryenLeFloat32), a, b) - end + binaryfun(ctx, idx, (WC.f64_le, WC.f32_le), a, b) + end elseif matchgr(node, :and_int) do a, b - binaryfun(ctx, idx, (BinaryenAndInt64, BinaryenAndInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_and, WC.i32_and), a, b) + end elseif matchgr(node, :or_int) do a, b - binaryfun(ctx, idx, (BinaryenOrInt64, BinaryenOrInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_or, WC.i32_or), a, b) + end elseif matchgr(node, :xor_int) do a, b - binaryfun(ctx, idx, (BinaryenXorInt64, BinaryenXorInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_xor, WC.i32_xor), a, b) + end elseif matchgr(node, :shl_int) do a, b - binaryfun(ctx, idx, (BinaryenShlInt64, BinaryenShlInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_shl, WC.i32_shl), a, b) + end elseif matchgr(node, :lshr_int) do a, b - binaryfun(ctx, idx, (BinaryenShrUInt64, BinaryenShrUInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_shr_u, WC.i32_shr_u), a, b) + end elseif matchgr(node, :ashr_int) do a, b - binaryfun(ctx, idx, (BinaryenShrSInt64, BinaryenShrSInt32), a, b) - end - - elseif matchgr(node, :ctpop_int) do a, b - binaryfun(ctx, idx, (BinaryenPopcntInt64, BinaryenPopcntInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_shr_s, WC.i32_shr_s), a, b) + end elseif matchgr(node, :ctpop_int) do a, b - binaryfun(ctx, idx, (BinaryenPopcntInt64, BinaryenPopcntInt32), a, b) - end + binaryfun(ctx, idx, (WC.i64_popcnt, WC.i32_popcnt), a, b) + end elseif matchgr(node, :copysign_float) do a, b - binaryfun(ctx, idx, (BinaryenCopySignInt64, BinaryenCopySignInt32), a, b) - end + binaryfun(ctx, idx, (WC.f64_copysign, WC.f32_copysign), a, b) + end elseif matchgr(node, :not_int) do a - Ta = roottype(ctx, a) - if sizeof(Ta) == 8 - if ssatype(ctx, idx) <: Bool - x = BinaryenUnary(ctx.mod, BinaryenEqZInt64(), _compile(ctx, a)) - else - x = BinaryenBinary(ctx.mod, BinaryenXorInt64(), _compile(ctx, a), _compile(ctx, Int64(-1))) - end + Ta = roottype(ctx, a) + if sizeof(Ta) == 8 + if ssatype(ctx, idx) <: Bool + x = InstOperands(WC.i64_eqz(), [_compile(ctx, a)]) else - if ssatype(ctx, idx) <: Bool - x = BinaryenUnary(ctx.mod, BinaryenEqZInt32(), _compile(ctx, a)) - else - x = BinaryenBinary(ctx.mod, BinaryenXorInt32(), _compile(ctx, a), _compile(ctx, Int32(-1))) - end + x = InstOperands(WC.i64_xor(), [_compile(ctx, a), _compile(ctx, Int64(-1))]) + end + else + if ssatype(ctx, idx) <: Bool + x = InstOperands(WC.i32_eqz(), [_compile(ctx, a)]) + else + x = InstOperands(WC.i32_xor(), [_compile(ctx, a), _compile(ctx, Int32(-1))]) end - setlocal!(ctx, idx, x) end + setlocal!(ctx, idx, x) + end elseif matchgr(node, :ctlz_int) do a - unaryfun(ctx, idx, (BinaryenClzInt64, BinaryenClzInt32), a) + unaryfun(ctx, idx, (WC.i64_clz, WC.i32_clz), a) end elseif matchgr(node, :cttz_int) do a - unaryfun(ctx, idx, (BinaryenCtzInt64, BinaryenCtzInt32), a) + unaryfun(ctx, idx, (WC.i64_ctz, WC.i32_ctz), a) end - ## I'm not sure these are right + ## I'm not sure these are right elseif matchgr(node, :sext_int) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - sizeof(t[1]) == 8 && sizeof(t[2]) == 1 ? unaryfun(ctx, idx, (BinaryenExtendS8Int64,), b) : - sizeof(t[1]) == 8 && sizeof(t[2]) == 2 ? unaryfun(ctx, idx, (BinaryenExtendS16Int64,), b) : - sizeof(t[1]) == 8 && sizeof(t[2]) == 4 ? unaryfun(ctx, idx, (BinaryenExtendSInt32,), b) : - sizeof(t[1]) == 4 && sizeof(t[2]) == 1 ? unaryfun(ctx, idx, (BinaryenExtendS8Int32,), b) : - sizeof(t[1]) == 4 && sizeof(t[2]) == 2 ? unaryfun(ctx, idx, (BinaryenExtendS16Int32,), b) : - error("Unsupported `sext_int` types $t") - end + t = (roottype(ctx, a), roottype(ctx, b)) + sizeof(t[1]) == 8 && sizeof(t[2]) == 1 ? unaryfun(ctx, idx, (WasmCompiler.i64_extend8_s,), b) : + sizeof(t[1]) == 8 && sizeof(t[2]) == 2 ? unaryfun(ctx, idx, (WasmCompiler.i64_extend16_s,), b) : + sizeof(t[1]) == 8 && sizeof(t[2]) == 4 ? unaryfun(ctx, idx, (WasmCompiler.i64_extend_i32_s,), b) : + sizeof(t[1]) == 4 && sizeof(t[2]) == 1 ? unaryfun(ctx, idx, (WasmCompiler.i32_extend8_s,), b) : + sizeof(t[1]) == 4 && sizeof(t[2]) == 2 ? unaryfun(ctx, idx, (WasmCompiler.i32_extend16_s,), b) : + error("Unsupported `sext_int` types $t") + end elseif matchgr(node, :zext_int) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - sizeof(t[1]) == 4 && sizeof(t[2]) <= 4 ? b : - sizeof(t[1]) == 8 && sizeof(t[2]) <= 4 ? unaryfun(ctx, idx, (BinaryenExtendUInt32,), b) : - error("Unsupported `zext_int` types $t") - end + t = (roottype(ctx, a), roottype(ctx, b)) + sizeof(t[1]) == 4 && sizeof(t[2]) <= 4 ? b : + sizeof(t[1]) == 8 && sizeof(t[2]) <= 4 ? unaryfun(ctx, idx, (WC.i64_extend_i32_u,), b) : + error("Unsupported `zext_int` types $t") + end elseif matchgr(node, :trunc_int) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (Int32, Int64) ? unaryfun(ctx, idx, (BinaryenWrapInt64,), b) : - t == (Int32, UInt64) ? unaryfun(ctx, idx, (BinaryenWrapInt64,), b) : - t == (UInt32, UInt64) ? unaryfun(ctx, idx, (BinaryenWrapInt64,), b) : - t == (UInt32, Int64) ? unaryfun(ctx, idx, (BinaryenWrapInt64,), b) : - t == (UInt8, UInt64) ? unaryfun(ctx, idx, (BinaryenWrapInt64,), b) : - t == (UInt8, Int64) ? unaryfun(ctx, idx, (BinaryenWrapInt64,), b) : - error("Unsupported `trunc_int` types $t") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (Int32, Int64) ? unaryfun(ctx, idx, (WC.i32_wrap_i64,), b) : + t == (Int32, UInt64) ? unaryfun(ctx, idx, (WC.i32_wrap_i64,), b) : + t == (UInt32, UInt64) ? unaryfun(ctx, idx, (WC.i32_wrap_i64,), b) : + t == (UInt32, Int64) ? unaryfun(ctx, idx, (WC.i32_wrap_i64,), b) : + t == (UInt8, UInt64) ? unaryfun(ctx, idx, (WC.i32_wrap_i64,), b) : + t == (UInt8, Int64) ? unaryfun(ctx, idx, (WC.i32_wrap_i64,), b) : + error("Unsupported `trunc_int` types $t") + end elseif matchgr(node, :flipsign_int) do a, b - Ta = eltype(roottype(ctx, a)) - Tb = eltype(roottype(ctx, b)) - # check the sign of b - if sizeof(Ta) == 8 - isnegative = BinaryenUnary(ctx.mod, BinaryenWrapInt64(), BinaryenBinary(ctx.mod, BinaryenShrUInt64(), _compile(ctx, b), _compile(ctx, UInt64(63)))) - x = BinaryenIf(ctx.mod, isnegative, - BinaryenBinary(ctx.mod, BinaryenMulInt64(), _compile(ctx, a), _compile(ctx, Int64(-1))), - _compile(ctx, a)) - else - isnegative = BinaryenBinary(ctx.mod, BinaryenShrUInt32(), _compile(ctx, b), _compile(ctx, UInt32(31))) - x = BinaryenIf(ctx.mod, isnegative, - BinaryenBinary(ctx.mod, BinaryenMulInt32(), _compile(ctx, a), _compile(ctx, Int32(-1))), - _compile(ctx, a)) - end - setlocal!(ctx, idx, x) + Ta = eltype(roottype(ctx, a)) + Tb = eltype(roottype(ctx, b)) + # check the sign of b + if sizeof(Ta) == 8 + isnegative = InstOperands(WC.i32_wrap_i64(), InstOperands(WC.i64_shr_u(), [_compile(ctx, b), _compile(ctx, UInt64(63))])) + x = BinaryenIf(ctx.mod, isnegative, + BinaryenBinary(ctx.mod, BinaryenMulInt64(), _compile(ctx, a), _compile(ctx, Int64(-1))), + _compile(ctx, a)) + else + isnegative = BinaryenBinary(ctx.mod, BinaryenShrUInt32(), _compile(ctx, b), _compile(ctx, UInt32(31))) + x = BinaryenIf(ctx.mod, isnegative, + BinaryenBinary(ctx.mod, BinaryenMulInt32(), _compile(ctx, a), _compile(ctx, Int32(-1))), + _compile(ctx, a)) end + setlocal!(ctx, idx, x) + end elseif matchgr(node, :fptoui) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (UInt64, Float32) ? unaryfun(ctx, idx, (BinaryenTruncSatUFloat32ToInt64,), b) : - t == (UInt64, Float64) ? unaryfun(ctx, idx, (BinaryenTruncSatUFloat64ToInt64,), b) : - t == (UInt32, Float32) ? unaryfun(ctx, idx, (BinaryenTruncSatUFloat32ToInt32,), b) : - t == (UInt32, Float64) ? unaryfun(ctx, idx, (BinaryenTruncSatUFloat64ToInt32,), b) : - error("Unsupported `fptoui` types") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (UInt64, Float32) ? unaryfun(ctx, idx, (WC.i64_trunc_sat_f32_u,), b) : + t == (UInt64, Float64) ? unaryfun(ctx, idx, (WC.i64_trunc_sat_f64_u,), b) : + t == (UInt32, Float32) ? unaryfun(ctx, idx, (WC.i32_trunc_sat_f32_u,), b) : + t == (UInt32, Float64) ? unaryfun(ctx, idx, (WC.i32_trunc_sat_f64_u,), b) : + error("Unsupported `fptoui` types") + end elseif matchgr(node, :fptosi) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (Int64, Float32) ? unaryfun(ctx, idx, (BinaryenTruncSatSFloat32ToInt64,), b) : - t == (Int64, Float64) ? unaryfun(ctx, idx, (BinaryenTruncSatSFloat64ToInt64,), b) : - t == (Int32, Float32) ? unaryfun(ctx, idx, (BinaryenTruncSatSFloat32ToInt32,), b) : - t == (Int32, Float64) ? unaryfun(ctx, idx, (BinaryenTruncSatSFloat64ToInt32,), b) : - error("Unsupported `fptosi` types") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (Int64, Float32) ? unaryfun(ctx, idx, (WC.i64_trunc_sat_f32_s,), b) : + t == (Int64, Float64) ? unaryfun(ctx, idx, (WC.i64_trunc_sat_f64_s,), b) : + t == (Int32, Float32) ? unaryfun(ctx, idx, (WC.i32_trunc_sat_f32_s,), b) : + t == (Int32, Float64) ? unaryfun(ctx, idx, (WC.i32_trunc_sat_f64_s,), b) : + error("Unsupported `fptosi` types") + end elseif matchgr(node, :uitofp) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (Float64, UInt32) ? unaryfun(ctx, idx, (BinaryenConvertUInt32ToFloat64,), b) : - t == (Float64, UInt64) ? unaryfun(ctx, idx, (BinaryenConvertUInt64ToFloat64,), b) : - t == (Float32, UInt32) ? unaryfun(ctx, idx, (BinaryenConvertUInt32ToFloat32,), b) : - t == (Float32, UInt64) ? unaryfun(ctx, idx, (BinaryenConvertUInt64ToFloat32,), b) : - error("Unsupported `uitofp` types") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (Float64, UInt32) ? unaryfun(ctx, idx, (WC.f64_convert_i32_u,), b) : + t == (Float64, UInt64) ? unaryfun(ctx, idx, (WC.f64_convert_i64_u,), b) : + t == (Float32, UInt32) ? unaryfun(ctx, idx, (WC.f32_convert_i32_u,), b) : + t == (Float32, UInt64) ? unaryfun(ctx, idx, (WC.f32_convert_i64_u,), b) : + error("Unsupported `uitofp` types") + end elseif matchgr(node, :sitofp) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (Float64, Int32) ? unaryfun(ctx, idx, (BinaryenConvertSInt32ToFloat64,), b) : - t == (Float64, Int64) ? unaryfun(ctx, idx, (BinaryenConvertSInt64ToFloat64,), b) : - t == (Float32, Int32) ? unaryfun(ctx, idx, (BinaryenConvertSInt32ToFloat32,), b) : - t == (Float32, Int64) ? unaryfun(ctx, idx, (BinaryenConvertSInt64ToFloat32,), b) : - error("Unsupported `sitofp` types") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (Float64, Int32) ? unaryfun(ctx, idx, (WC.f64_convert_i32_s,), b) : + t == (Float64, Int64) ? unaryfun(ctx, idx, (WC.f64_convert_i64_s,), b) : + t == (Float32, Int32) ? unaryfun(ctx, idx, (WC.f32_convert_i32_s,), b) : + t == (Float32, Int64) ? unaryfun(ctx, idx, (WC.f32_convert_i64_s,), b) : + error("Unsupported `sitofp` types") + end elseif matchgr(node, :fptrunc) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (Float32, Float64) ? unaryfun(ctx, idx, (BinaryenDemoteFloat64,), b) : - error("Unsupported `fptrunc` types") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (Float32, Float64) ? unaryfun(ctx, idx, (WC.f32_demote_f64,), b) : + error("Unsupported `fptrunc` types") + end elseif matchgr(node, :fpext) do a, b - t = (roottype(ctx, a), roottype(ctx, b)) - t == (Float64, Float32) ? unaryfun(ctx, idx, (BinaryenPromoteFloat32,), b) : - error("Unsupported `fpext` types") - end + t = (roottype(ctx, a), roottype(ctx, b)) + t == (Float64, Float32) ? unaryfun(ctx, idx, (WC.f64_promote_f32,), b) : + error("Unsupported `fpext` types") + end elseif matchgr(node, :ifelse) do cond, a, b - binaryenfun(ctx, idx, BinaryenIf, cond, a, b) - end + x = InstOperands(WC.select(), [_compile(ctx, a), _compile(ctx, b), _compile(ctx, cond)]) + setlocal!(ctx, idx, x) + end elseif matchgr(node, :abs_float) do a - unaryfun(ctx, idx, (BinaryenAbsFloat64, BinaryenAbsFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_abs, WC.f32_abs), a) + end elseif matchgr(node, :ceil_llvm) do a - unaryfun(ctx, idx, (BinaryenCeilFloat64, BinaryenCeilFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_ceil, WC.f32_ceil), a) + end elseif matchgr(node, :floor_llvm) do a - unaryfun(ctx, idx, (BinaryenFloorFloat64, BinaryenFloorFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_floor, WC.f32_floor), a) + end elseif matchgr(node, :trunc_llvm) do a - unaryfun(ctx, idx, (BinaryenTruncFloat64, BinaryenTruncFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_trunc, WC.f32_trunc), a) + end elseif matchgr(node, :rint_llvm) do a - unaryfun(ctx, idx, (BinaryenNearestFloat64, BinaryenNearestFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_nearest, WC.f32_nearest), a) + end elseif matchgr(node, :sqrt_llvm) do a - unaryfun(ctx, idx, (BinaryenSqrtFloat64, BinaryenSqrtFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_sqrt, WC.f32_sqrt), a) + end elseif matchgr(node, :sqrt_llvm_fast) do a - unaryfun(ctx, idx, (BinaryenSqrtFloat64, BinaryenSqrtFloat32), a) - end + unaryfun(ctx, idx, (WC.f64_sqrt, WC.f32_sqrt), a) + end elseif matchgr(node, :have_fma) do a - setlocal!(ctx, idx, _compile(ctx, Int32(0))) - end + setlocal!(ctx, idx, _compile(ctx, Int32(0))) + end elseif matchgr(node, :bitcast) do t, val - T = roottype(ctx, t) - Tval = roottype(ctx, val) - T == Float64 && Tval <: Integer ? unaryfun(ctx, idx, (BinaryenReinterpretInt64,), val) : - T == Float32 && Tval <: Integer ? unaryfun(ctx, idx, (BinaryenReinterpretInt32,), val) : - T <: Integer && sizeof(T) == 8 && Tval <: AbstractFloat ? unaryfun(ctx, idx, (BinaryenReinterpretFloat64,), val) : - T <: Integer && sizeof(T) == 4 && Tval <: AbstractFloat ? unaryfun(ctx, idx, (BinaryenReinterpretFloat32,), val) : - T <: Integer && Tval <: Integer ? setlocal!(ctx, idx, _compile(ctx, val)) : - # T <: Integer && Tval :: Char ? unaryfun(ctx, idx, (BinaryenReinterpretInt32,), val) : - error("Unsupported `bitcast` types, ($T, $Tval)") - end + T = roottype(ctx, t) + Tval = roottype(ctx, val) + T == Float64 && Tval <: Integer ? unaryfun(ctx, idx, (BinaryenReinterpretInt64,), val) : + T == Float32 && Tval <: Integer ? unaryfun(ctx, idx, (BinaryenReinterpretInt32,), val) : + T <: Integer && sizeof(T) == 8 && Tval <: AbstractFloat ? unaryfun(ctx, idx, (BinaryenReinterpretFloat64,), val) : + T <: Integer && sizeof(T) == 4 && Tval <: AbstractFloat ? unaryfun(ctx, idx, (BinaryenReinterpretFloat32,), val) : + T <: Integer && Tval <: Integer ? setlocal!(ctx, idx, _compile(ctx, val)) : + # T <: Integer && Tval :: Char ? unaryfun(ctx, idx, (BinaryenReinterpretInt32,), val) : + error("Unsupported `bitcast` types, ($T, $Tval)") + end - ## TODO - # ADD_I(cglobal, 2) \ + ## TODO + # ADD_I(cglobal, 2) \ - ## Builtins / key functions ## + ## Builtins / key functions ## elseif matchforeigncall(node, :jl_string_to_array) do args - # This is just a pass-through because strings are already Vector{UInt8} - setlocal!(ctx, idx, _compile(ctx, args[5])) - end + # This is just a pass-through because strings are already Vector{UInt8} + setlocal!(ctx, idx, _compile(ctx, args[5])) + end elseif matchforeigncall(node, :jl_array_to_string) do args - # This is just a pass-through because strings are already Vector{UInt8} - setlocal!(ctx, idx, _compile(ctx, args[5])) - end + # This is just a pass-through because strings are already Vector{UInt8} + setlocal!(ctx, idx, _compile(ctx, args[5])) + end elseif matchforeigncall(node, :_jl_symbol_to_array) do args - # This is just a pass-through because Symbols are already Vector{UInt8} - setlocal!(ctx, idx, _compile(ctx, args[5])) - end + # This is just a pass-through because Symbols are already Vector{UInt8} + setlocal!(ctx, idx, _compile(ctx, args[5])) + end elseif matchgr(node, :arrayref) do bool, arraywrapper, i - buffer = getbuffer(ctx, arraywrapper) - # signed = eT <: Signed && sizeof(eT) < 4 - signed = false - ## subtract one from i for zero-based indexing in WASM - i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), _compile(ctx, I32(i)), _compile(ctx, Int32(-1))) - binaryenfun(ctx, idx, BinaryenArrayGet, buffer, i, gettype(ctx, roottype(ctx, arraywrapper)), Pass(signed)) - end + buffer = getbuffer(ctx, arraywrapper) + # signed = eT <: Signed && sizeof(eT) < 4 + signed = false + ## subtract one from i for zero-based indexing in WASM + i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), _compile(ctx, I32(i)), _compile(ctx, Int32(-1))) + binaryenfun(ctx, idx, BinaryenArrayGet, buffer, i, gettype(ctx, roottype(ctx, arraywrapper)), Pass(signed)) + end elseif matchgr(node, :arrayset) do bool, arraywrapper, val, i - buffer = getbuffer(ctx, arraywrapper) - i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), _compile(ctx, I32(i)), _compile(ctx, Int32(-1))) - x = _compile(ctx, val) - aT = eltype(roottype(ctx, arraywrapper)) - if aT == Any # Box if needed - x = box(ctx, x, roottype(ctx, val)) - end - x = BinaryenArraySet(ctx.mod, buffer, i, x) - update!(ctx, x) - end + buffer = getbuffer(ctx, arraywrapper) + i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), _compile(ctx, I32(i)), _compile(ctx, Int32(-1))) + x = _compile(ctx, val) + aT = eltype(roottype(ctx, arraywrapper)) + if aT == Any # Box if needed + x = box(ctx, x, roottype(ctx, val)) + end + x = BinaryenArraySet(ctx.mod, buffer, i, x) + update!(ctx, x) + end elseif matchgr(node, :arraylen) do arraywrapper - x = BinaryenStructGet(ctx.mod, 1, _compile(ctx, arraywrapper), C_NULL, false) - if sizeof(Int) == 8 # extend to Int64 - unaryfun(ctx, idx, (BinaryenExtendUInt32,), x) - else - setlocal!(ctx, idx, x) - end + x = BinaryenStructGet(ctx.mod, 1, _compile(ctx, arraywrapper), C_NULL, false) + if sizeof(Int) == 8 # extend to Int64 + unaryfun(ctx, idx, (BinaryenExtendUInt32,), x) + else + setlocal!(ctx, idx, x) end + end elseif matchgr(node, :arraysize) do arraywrapper, n - x = BinaryenStructGet(ctx.mod, 1, _compile(ctx, arraywrapper), C_NULL, false) - if sizeof(Int) == 8 # extend to Int64 - unaryfun(ctx, idx, (BinaryenExtendUInt32,), x) - else - setlocal!(ctx, idx, x) - end + x = BinaryenStructGet(ctx.mod, 1, _compile(ctx, arraywrapper), C_NULL, false) + if sizeof(Int) == 8 # extend to Int64 + unaryfun(ctx, idx, (BinaryenExtendUInt32,), x) + else + setlocal!(ctx, idx, x) end + end elseif matchforeigncall(node, :jl_alloc_array_1d) do args - elT = eltype(args[1]) - size = _compile(ctx, I32(args[6])) - arraytype = BinaryenTypeGetHeapType(gettype(ctx, Buffer{elT})) - buffer = BinaryenArrayNew(ctx.mod, arraytype, size, _compile(ctx, default(elT))) - wrappertype = BinaryenTypeGetHeapType(gettype(ctx, FakeArrayWrapper{elT})) - binaryenfun(ctx, idx, BinaryenStructNew, [buffer, size], UInt32(2), wrappertype; passall = true) - end + elT = eltype(args[1]) + size = _compile(ctx, I32(args[6])) + arraytype = BinaryenTypeGetHeapType(gettype(ctx, Buffer{elT})) + buffer = BinaryenArrayNew(ctx.mod, arraytype, size, _compile(ctx, default(elT))) + wrappertype = BinaryenTypeGetHeapType(gettype(ctx, FakeArrayWrapper{elT})) + binaryenfun(ctx, idx, BinaryenStructNew, [buffer, size], UInt32(2), wrappertype; passall=true) + end elseif matchforeigncall(node, :jl_array_grow_end) do args - arraywrapper = args[5] - elT = eltype(roottype(ctx, args[5])) - arraytype = gettype(ctx, Buffer{elT}) - arrayheaptype = BinaryenTypeGetHeapType(arraytype) - arraywrappertype = gettype(ctx, Vector{elT}) - _arraywrapper = _compile(ctx, arraywrapper) - buffer = getbuffer(ctx, args[5]) - bufferlen = BinaryenArrayLen(ctx.mod, buffer) - extralen = _compile(ctx, I32(args[6])) - arraylen = BinaryenStructGet(ctx.mod, 1, _arraywrapper, C_NULL, false) - newlen = BinaryenBinary(ctx.mod, BinaryenAddInt32(), arraylen, extralen) - newbufferlen = BinaryenBinary(ctx.mod, BinaryenMulInt32(), newlen, _compile(ctx, I32(2))) - neednewbuffer = BinaryenBinary(ctx.mod, BinaryenLeUInt32(), arraylen, newlen) - newbufferget = BinaryenLocalGet(ctx.mod, ctx.localidx, arraytype) - newbufferblock = [ - BinaryenLocalSet(ctx.mod, ctx.localidx, BinaryenArrayNew(ctx.mod, arrayheaptype, newbufferlen, _compile(ctx, default(elT)))), - BinaryenArrayCopy(ctx.mod, newbufferget, _compile(ctx, I32(0)), buffer, _compile(ctx, I32(0)), _compile(ctx, arraylen)), - BinaryenStructSet(ctx.mod, 0, _arraywrapper, newbufferget), - ] - push!(ctx.locals, arraytype) - ctx.localidx += 1 - x = BinaryenIf(ctx.mod, neednewbuffer, - BinaryenBlock(ctx.mod, "newbuff", newbufferblock, length(newbufferblock), BinaryenTypeAuto()), - C_NULL) - update!(ctx, x) - x = BinaryenStructSet(ctx.mod, 1, _arraywrapper, newlen) - update!(ctx, x) - end + arraywrapper = args[5] + elT = eltype(roottype(ctx, args[5])) + arraytype = gettype(ctx, Buffer{elT}) + arrayheaptype = BinaryenTypeGetHeapType(arraytype) + arraywrappertype = gettype(ctx, Vector{elT}) + _arraywrapper = _compile(ctx, arraywrapper) + buffer = getbuffer(ctx, args[5]) + bufferlen = BinaryenArrayLen(ctx.mod, buffer) + extralen = _compile(ctx, I32(args[6])) + arraylen = BinaryenStructGet(ctx.mod, 1, _arraywrapper, C_NULL, false) + newlen = BinaryenBinary(ctx.mod, BinaryenAddInt32(), arraylen, extralen) + newbufferlen = BinaryenBinary(ctx.mod, BinaryenMulInt32(), newlen, _compile(ctx, I32(2))) + neednewbuffer = BinaryenBinary(ctx.mod, BinaryenLeUInt32(), arraylen, newlen) + newbufferget = BinaryenLocalGet(ctx.mod, ctx.localidx, arraytype) + newbufferblock = [ + BinaryenLocalSet(ctx.mod, ctx.localidx, BinaryenArrayNew(ctx.mod, arrayheaptype, newbufferlen, _compile(ctx, default(elT)))), + BinaryenArrayCopy(ctx.mod, newbufferget, _compile(ctx, I32(0)), buffer, _compile(ctx, I32(0)), _compile(ctx, arraylen)), + BinaryenStructSet(ctx.mod, 0, _arraywrapper, newbufferget), + ] + push!(ctx.locals, arraytype) + ctx.localidx += 1 + x = BinaryenIf(ctx.mod, neednewbuffer, + BinaryenBlock(ctx.mod, "newbuff", newbufferblock, length(newbufferblock), BinaryenTypeAuto()), + C_NULL) + update!(ctx, x) + x = BinaryenStructSet(ctx.mod, 1, _arraywrapper, newlen) + update!(ctx, x) + end elseif matchforeigncall(node, :jl_array_del_end) do args - arraywrapper = _compile(ctx, args[5]) - i = _compile(ctx, I32(args[6])) - arraylen = BinaryenStructGet(ctx.mod, 1, arraywrapper, C_NULL, false) - newlen = BinaryenBinary(ctx.mod, BinaryenSubInt32(), arraylen, i) - x = BinaryenStructSet(ctx.mod, 1, arraywrapper, newlen) - update!(ctx, x) - end + arraywrapper = _compile(ctx, args[5]) + i = _compile(ctx, I32(args[6])) + arraylen = BinaryenStructGet(ctx.mod, 1, arraywrapper, C_NULL, false) + newlen = BinaryenBinary(ctx.mod, BinaryenSubInt32(), arraylen, i) + x = BinaryenStructSet(ctx.mod, 1, arraywrapper, newlen) + update!(ctx, x) + end elseif matchforeigncall(node, :_jl_array_copy) do args - srcbuffer = getbuffer(ctx, args[5]) - destbuffer = getbuffer(ctx, args[6]) - n = args[7] - x = BinaryenArrayCopy(ctx.mod, destbuffer, _compile(ctx, I32(0)), srcbuffer, _compile(ctx, I32(0)), _compile(ctx, n)) - update!(ctx, x) - end + srcbuffer = getbuffer(ctx, args[5]) + destbuffer = getbuffer(ctx, args[6]) + n = args[7] + x = BinaryenArrayCopy(ctx.mod, destbuffer, _compile(ctx, I32(0)), srcbuffer, _compile(ctx, I32(0)), _compile(ctx, n)) + update!(ctx, x) + end elseif matchforeigncall(node, :_jl_array_copyto) do args - destbuffer = getbuffer(ctx, args[5]) - doffs = _compile(ctx, args[6]) - srcbuffer = getbuffer(ctx, args[7]) - soffs = _compile(ctx, args[8]) - n = _compile(ctx, args[9]) - x = BinaryenArrayCopy(ctx.mod, destbuffer, doffs, srcbuffer, soffs, n) - update!(ctx, x) - setlocal!(ctx, idx, _compile(ctx, args[5])) - end + destbuffer = getbuffer(ctx, args[5]) + doffs = _compile(ctx, args[6]) + srcbuffer = getbuffer(ctx, args[7]) + soffs = _compile(ctx, args[8]) + n = _compile(ctx, args[9]) + x = BinaryenArrayCopy(ctx.mod, destbuffer, doffs, srcbuffer, soffs, n) + update!(ctx, x) + setlocal!(ctx, idx, _compile(ctx, args[5])) + end elseif matchforeigncall(node, :jl_object_id) do args - setlocal!(ctx, idx, _compile(ctx, objectid(_compile(ctx, args[5])))) - end + setlocal!(ctx, idx, _compile(ctx, objectid(_compile(ctx, args[5])))) + end elseif matchgr(node, :getfield) || matchcall(node, getfield) x = node.args[2] @@ -595,61 +590,61 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) # unsigned = eltype(T) <: Unsigned unsigned = true ## subtract one from i for zero-based indexing in WASM - i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), - _compile(ctx, I32(index)), - _compile(ctx, Int32(-1))) + i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), + _compile(ctx, I32(index)), + _compile(ctx, Int32(-1))) binaryenfun(ctx, idx, BinaryenArrayGet, _compile(ctx, x), Pass(i), gettype(ctx, eltype(T)), Pass(!unsigned)) elseif roottype(ctx, index) <: Integer # if length(node.args) == 3 # 2-arg version - eT = fieldtypeskept(T)[index] - # unsigned = eT <: Unsigned - unsigned = true - binaryenfun(ctx, idx, BinaryenStructGet, UInt32(index - 1), _compile(ctx, x), gettype(ctx, eT), !unsigned, passall = true) + eT = fieldtypeskept(T)[index] + # unsigned = eT <: Unsigned + unsigned = true + binaryenfun(ctx, idx, BinaryenStructGet, UInt32(index - 1), _compile(ctx, x), gettype(ctx, eT), !unsigned, passall=true) # else # 3-arg version # end else field = index nT = T <: Type ? DataType : T # handle Types index = UInt32(findfirst(x -> x == field.value, fieldskept(nT)) - 1) - eT = Base.datatype_fieldtypes(nT)[index + 1] + eT = Base.datatype_fieldtypes(nT)[index+1] # unsigned = eT <: Unsigned unsigned = true - binaryenfun(ctx, idx, BinaryenStructGet, index, _compile(ctx, x), gettype(ctx, eT), !unsigned, passall = true) + binaryenfun(ctx, idx, BinaryenStructGet, index, _compile(ctx, x), gettype(ctx, eT), !unsigned, passall=true) end - - ## 3-arg version of getfield for integer field access + + ## 3-arg version of getfield for integer field access elseif matchgr(node, :getfield) do x, index, bool - T = roottype(ctx, x) - if T <: NTuple - # unsigned = eltype(T) <: Unsigned - unsigned = true - ## subtract one from i for zero-based indexing in WASM - i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), - _compile(ctx, I32(index)), - _compile(ctx, Int32(-1))) - - binaryenfun(ctx, idx, BinaryenArrayGet, _compile(ctx, x), Pass(i), gettype(ctx, eltype(T)), Pass(!unsigned)) - else - eT = Base.datatype_fieldtypes(T)[_compile(ctx, index)] - # unsigned = eT <: Unsigned - unsigned = true - binaryenfun(ctx, idx, BinaryenStructGet, UInt32(index - 1), _compile(ctx, x), gettype(ctx, eT), !unsigned, passall = true) - end + T = roottype(ctx, x) + if T <: NTuple + # unsigned = eltype(T) <: Unsigned + unsigned = true + ## subtract one from i for zero-based indexing in WASM + i = BinaryenBinary(ctx.mod, BinaryenAddInt32(), + _compile(ctx, I32(index)), + _compile(ctx, Int32(-1))) + + binaryenfun(ctx, idx, BinaryenArrayGet, _compile(ctx, x), Pass(i), gettype(ctx, eltype(T)), Pass(!unsigned)) + else + eT = Base.datatype_fieldtypes(T)[_compile(ctx, index)] + # unsigned = eT <: Unsigned + unsigned = true + binaryenfun(ctx, idx, BinaryenStructGet, UInt32(index - 1), _compile(ctx, x), gettype(ctx, eT), !unsigned, passall=true) end + end elseif matchgr(node, :setfield!) do x, field, value - if field isa QuoteNode && field.value isa Symbol - T = roottype(ctx, x) - index = UInt32(findfirst(x -> x == field.value, fieldskept(T)) - 1) - elseif field isa Integer - index = UInt32(field) - else - error("setfield! indexing with $field is not supported in $node.") - end - x = BinaryenStructSet(ctx.mod, index, _compile(ctx, x), _compile(ctx, value)) - update!(ctx, x) + if field isa QuoteNode && field.value isa Symbol + T = roottype(ctx, x) + index = UInt32(findfirst(x -> x == field.value, fieldskept(T)) - 1) + elseif field isa Integer + index = UInt32(field) + else + error("setfield! indexing with $field is not supported in $node.") end - + x = BinaryenStructSet(ctx.mod, index, _compile(ctx, x), _compile(ctx, value)) + update!(ctx, x) + end + elseif node isa Expr && (node.head == :new || (node.head == :call && node.args[1] isa GlobalRef && node.args[1].name == :tuple)) nargs = UInt32(length(node.args) - 1) @@ -663,10 +658,10 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) if jtype <: NTuple values = [_compile(ctx, v) for v in node.args[2:end]] N = Int32(length(node.args) - 1) - binaryenfun(ctx, idx, BinaryenArrayNewFixed, type, values, N; passall = true) + binaryenfun(ctx, idx, BinaryenArrayNewFixed, type, values, N; passall=true) else x = BinaryenStructNew(ctx.mod, args, nargs, type) - binaryenfun(ctx, idx, BinaryenStructNew, args, nargs, type; passall = true) + binaryenfun(ctx, idx, BinaryenStructNew, args, nargs, type; passall=true) # for (i,name) in enumerate(fieldnames(jtype)) # BinaryenModuleSetFieldName(ctx.mod, type, i - 1, string(name)) # end @@ -679,17 +674,17 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) arraytype = BinaryenTypeGetHeapType(gettype(ctx, typeof(x))) values = [_compile(ctx, v) for v in node.args[2:end]] N = _compile(ctx, Int32(length(node.args) - 1)) - binaryenfun(ctx, idx, BinaryenArrayNewFixed, arraytype, values, N; passall = true) + binaryenfun(ctx, idx, BinaryenArrayNewFixed, arraytype, values, N; passall=true) else nargs = UInt32(length(node.args) - 1) args = [_compile(ctx, x) for x in node.args[2:end]] jtype = node.args[1] type = BinaryenTypeGetHeapType(gettype(ctx, jtype)) x = BinaryenStructNew(ctx.mod, args, nargs, type) - binaryenfun(ctx, idx, BinaryenStructNew, args, nargs, type; passall = true) + binaryenfun(ctx, idx, BinaryenStructNew, args, nargs, type; passall=true) end - elseif node isa Expr && node.head == :call && + elseif node isa Expr && node.head == :call && ((node.args[1] isa GlobalRef && node.args[1].name == :llvmcall) || node.args[1] == Core.Intrinsics.llvmcall) jscode = node.args[2] @@ -720,20 +715,20 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) setlocal!(ctx, idx, x) end - #= - `invoke` is one of the toughest parts of compilation. - Cases that must be handled include: - * Variable arguments: We pass these as the last argument as a tuple. - * Callable struct / closure: We pass these as the first argument if it's not a toplevel function. - If it's top level, then the struct is stored as a global variable. - * Keyword arguments: These come in as the first argument after the function/callable struct argument. - Notes: - * The first argument is the function itself. - Use that for callable structs. We remove it if it's not callable. - * If an argument isn't used (including types or other non-data arguments), - it is not included in the argument list. - This might be weird for top-level definitions, so it's not done there (but might cause issues). - =# + #= + `invoke` is one of the toughest parts of compilation. + Cases that must be handled include: + * Variable arguments: We pass these as the last argument as a tuple. + * Callable struct / closure: We pass these as the first argument if it's not a toplevel function. + If it's top level, then the struct is stored as a global variable. + * Keyword arguments: These come in as the first argument after the function/callable struct argument. + Notes: + * The first argument is the function itself. + Use that for callable structs. We remove it if it's not callable. + * If an argument isn't used (including types or other non-data arguments), + it is not included in the argument list. + This might be weird for top-level definitions, so it's not done there (but might cause issues). + =# elseif node isa Expr && node.head == :invoke T = node.args[1].specTypes.parameters[1] if isa(DomainError, T) || @@ -755,9 +750,9 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) # mi = Core.Compiler.specialize_method(match; preexisting=true) mi = Core.Compiler.specialize_method(match) sig = mi.specTypes - newci = Base.code_typed_by_type(mi.specTypes, interp = StaticInterpreter())[1][1] + newci = Base.code_typed_by_type(mi.specTypes, interp=StaticInterpreter())[1][1] n2 = node.args[2] - newfun = n2 isa QuoteNode ? n2.value : + newfun = n2 isa QuoteNode ? n2.value : n2 isa GlobalRef ? Core.eval(n2.mod, n2.name) : n2 newctx = CompilerContext(ctx, newci) @@ -766,21 +761,21 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) newsig = newci.parent.specTypes n = length(node.args) if newci.parent.def.isva # varargs - na = length(newci.slottypes) - (callable ? 1 : 2) - jargs = [node.args[i] for i in argstart:argstart+na-1 if argused(newctx, i-1)] # up to the last arg which is a vararg + na = length(newci.slottypes) - (callable ? 1 : 2) + jargs = [node.args[i] for i in argstart:argstart+na-1 if argused(newctx, i - 1)] # up to the last arg which is a vararg args = [_compile(ctx, x) for x in jargs] nva = length(newci.slottypes[end].parameters) push!(args, _compile(ctx, tuple((node.args[i] for i in n-nva+1:n)...))) np = newsig.parameters - newsig = Tuple{np[1:end-nva]..., Tuple{np[end-nva+1:end]...}} + newsig = Tuple{np[1:end-nva]...,Tuple{np[end-nva+1:end]...}} else - jargs = [node.args[i] for i in argstart:n if argused(newctx, i-1)] + jargs = [node.args[i] for i in argstart:n if argused(newctx, i - 1)] args = [_compile(ctx, x) for x in jargs] end if haskey(ctx.names, newsig) name = ctx.names[newsig] else - name = validname(string("julia_", node.args[1].def.name, newsig.parameters[2:end]...))[1:min(end,255)] + name = validname(string("julia_", node.args[1].def.name, newsig.parameters[2:end]...))[1:min(end, 255)] ctx.sigs[name] = newsig ctx.names[newsig] = name newci.parent.specTypes = newsig @@ -804,44 +799,44 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) x = BinaryenRefTest(ctx.mod, val, wT) setlocal!(ctx, idx, x) - # DECLARE_BUILTIN(applicable); - # DECLARE_BUILTIN(_apply_iterate); - # DECLARE_BUILTIN(_apply_pure); - # DECLARE_BUILTIN(apply_type); - # DECLARE_BUILTIN(_call_in_world); - # DECLARE_BUILTIN(_call_in_world_total); - # DECLARE_BUILTIN(_call_latest); - # DECLARE_BUILTIN(replacefield); - # DECLARE_BUILTIN(const_arrayref); - # DECLARE_BUILTIN(_expr); - # DECLARE_BUILTIN(fieldtype); - # DECLARE_BUILTIN(is); - # DECLARE_BUILTIN(isa); - # DECLARE_BUILTIN(isdefined); - # DECLARE_BUILTIN(issubtype); - # DECLARE_BUILTIN(modifyfield); - # DECLARE_BUILTIN(nfields); - # DECLARE_BUILTIN(sizeof); - # DECLARE_BUILTIN(svec); - # DECLARE_BUILTIN(swapfield); - # DECLARE_BUILTIN(throw); - # DECLARE_BUILTIN(typeassert); - # DECLARE_BUILTIN(_typebody); - # DECLARE_BUILTIN(typeof); - # DECLARE_BUILTIN(_typevar); - # DECLARE_BUILTIN(donotdelete); - # DECLARE_BUILTIN(compilerbarrier); - # DECLARE_BUILTIN(getglobal); - # DECLARE_BUILTIN(setglobal); - # DECLARE_BUILTIN(finalizer); - # DECLARE_BUILTIN(_compute_sparams); - # DECLARE_BUILTIN(_svec_ref); - - ## Other ## + # DECLARE_BUILTIN(applicable); + # DECLARE_BUILTIN(_apply_iterate); + # DECLARE_BUILTIN(_apply_pure); + # DECLARE_BUILTIN(apply_type); + # DECLARE_BUILTIN(_call_in_world); + # DECLARE_BUILTIN(_call_in_world_total); + # DECLARE_BUILTIN(_call_latest); + # DECLARE_BUILTIN(replacefield); + # DECLARE_BUILTIN(const_arrayref); + # DECLARE_BUILTIN(_expr); + # DECLARE_BUILTIN(fieldtype); + # DECLARE_BUILTIN(is); + # DECLARE_BUILTIN(isa); + # DECLARE_BUILTIN(isdefined); + # DECLARE_BUILTIN(issubtype); + # DECLARE_BUILTIN(modifyfield); + # DECLARE_BUILTIN(nfields); + # DECLARE_BUILTIN(sizeof); + # DECLARE_BUILTIN(svec); + # DECLARE_BUILTIN(swapfield); + # DECLARE_BUILTIN(throw); + # DECLARE_BUILTIN(typeassert); + # DECLARE_BUILTIN(_typebody); + # DECLARE_BUILTIN(typeof); + # DECLARE_BUILTIN(_typevar); + # DECLARE_BUILTIN(donotdelete); + # DECLARE_BUILTIN(compilerbarrier); + # DECLARE_BUILTIN(getglobal); + # DECLARE_BUILTIN(setglobal); + # DECLARE_BUILTIN(finalizer); + # DECLARE_BUILTIN(_compute_sparams); + # DECLARE_BUILTIN(_svec_ref); + + ## Other ## elseif node isa GlobalRef setlocal!(ctx, idx, getglobal(ctx, node.mod, node.name)) - + elseif node isa Expr # ignore other expressions for now # println("----------------------------------------------------------------") @@ -855,10 +850,15 @@ function compile_block(ctx::CompilerContext, cfg::Core.Compiler.CFG, phis, idx) end if haskey(phis, idx) for (i, var) in phis[idx] - push!(ctx.body, BinaryenLocalSet(ctx.mod, ctx.varmap[i], _compile(ctx, var))) + push!(ctx.body, InstOperands(WC.local_set(ctx.varmap[i]), [_compile(ctx, var)])) end end - body = BinaryenBlock(ctx.mod, "body", ctx.body, length(ctx.body), BinaryenTypeAuto()) - return body + + # put cond val on the stack + isnothing(cond_val) || push!(ctx.body, cond_val) + + # body = BinaryenBlock(ctx.mod, "body", ctx.body, length(ctx.body), BinaryenTypeAuto()) + expr = WC.flatten(ctx.body) + return expr end diff --git a/src/compiler.jl b/src/compiler.jl index b4a7519..9521ad4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -29,9 +29,11 @@ function compile(funs::Tuple...; filepath = "foo.wasm", jspath = filepath * ".js dummyci = code_typed(() -> nothing, Tuple{})[1].first ctx = CompilerContext(dummyci; experimental) debug(:offline) && _debug_ci(ctx) + # BinaryenModuleSetFeatures(ctx.mod, BinaryenFeatureReferenceTypes() | BinaryenFeatureGC() | (experimental ? BinaryenFeatureStrings() : 0)) - BinaryenModuleSetFeatures(ctx.mod, BinaryenFeatureAll()) - BinaryenModuleAutoDrop(ctx.mod) + # BinaryenModuleSetFeatures(ctx.mod, BinaryenFeatureAll()) + # BinaryenModuleAutoDrop(ctx.mod) # TODO + # Create CodeInfo's, and fill in names first for (i, funtpl) in enumerate(funs) tt = length(funtpl) > 1 ? Base.to_tuple_type(funtpl[2:end]) : Tuple{} @@ -57,21 +59,26 @@ function compile(funs::Tuple...; filepath = "foo.wasm", jspath = filepath * ".js debug(:offline) && _debug_ci(newctx, ctx) compile_method(newctx, name, exported = true) end - BinaryenModuleAutoDrop(ctx.mod) - debug(:offline) && _debug_module(ctx) - debug(:inline) && BinaryenModulePrint(ctx.mod) - validate && BinaryenModuleValidate(ctx.mod) + # BinaryenModuleAutoDrop(ctx.mod) + # debug(:offline) && _debug_module(ctx) + # debug(:inline) && BinaryenModulePrint(ctx.mod) + display(WC.Wat(ctx.mod, false)) + validate && WC.validate(ctx.mod) # BinaryenSetShrinkLevel(0) # BinaryenSetOptimizeLevel(2) - optimize && BinaryenModuleOptimize(ctx.mod) - - out = BinaryenModuleAllocateAndWrite(ctx.mod, C_NULL) - write(filepath, unsafe_wrap(Vector{UInt8}, Ptr{UInt8}(out.binary), (out.binaryBytes,))) - Libc.free(out.binary) - out = BinaryenModuleAllocateAndWriteText(ctx.mod) - write(filepath * ".wat", unsafe_string(out)) - Libc.free(out) - BinaryenModuleDispose(ctx.mod) + OPT_LEVEL = 1 # 2 # not sure about lvl 2 yet + optimize && (ctx.mod = WC.optimize(ctx.mod)) + display(WC.Wat(ctx.mod, false)) + + # out = BinaryenModuleAllocateAndWrite(ctx.mod, C_NULL) + # write(filepath, unsafe_wrap(Vector{UInt8}, Ptr{UInt8}(out.binary), (out.binaryBytes,))) + # Libc.free(out.binary) + # out = BinaryenModuleAllocateAndWriteText(ctx.mod) + # write(filepath * ".wat", unsafe_string(out)) + # Libc.free(out) + # BinaryenModuleDispose(ctx.mod) + open(io -> WC.wwrite(io, ctx.mod), filepath, "w") + jstext = "var jsexports = { js: {} };\n" imports = unique(values(ctx.imports)) jstext *= join(["jsexports['js']['$v'] = $v;" for v in imports], "\n") @@ -96,19 +103,18 @@ function compile_method(ctx::CompilerContext, funname; sig = ctx.ci.parent.specT jparams = jparams[2:end] end - bparams = BinaryenTypeCreate(jparams, length(jparams)) - rettype = gettype(ctx, ctx.ci.rettype == Nothing ? Union{} : ctx.ci.rettype) - body = compile_method_body(ctx) - debug(:inline) && println("---------------------------------------") - debug(:inline) && @show ctx.ci.parent.def.name - debug(:inline) && @show ctx.ci.parent.def - debug(:inline) && @show ctx.ci - debug(:inline) && @show ctx.ci.parent.def.name - debug(:inline) && BinaryenExpressionPrint(body) - if BinaryenGetFunction(ctx.mod, funname) == C_NULL - BinaryenAddFunction(ctx.mod, funname, bparams, rettype, ctx.locals, length(ctx.locals), body) + ircode, _ = Base.code_ircode_by_type(sig) |> only + + body = compile_method_body(ctx, ircode) + functype = WasmCompiler.FuncType(jparams, ctx.ci.rettype == Nothing ? [] : [gettype(ctx, ctx.ci.rettype)]) + if !any(f -> f.name == funname, ctx.mod.funcs) + push!( + ctx.mod.funcs, + WC.Func(funname, functype, [functype.params..., ctx.locals...], body) + ) if exported - BinaryenAddFunctionExport(ctx.mod, funname, funname) + num_exports = count(exp -> exp isa WC.FuncExport, ctx.mod.exports) + push!(ctx.mod.exports, WC.FuncExport(funname, num_exports + length(ctx.mod.funcs))) end end return nothing @@ -116,12 +122,12 @@ end import Core.Compiler: block_for_inst, compute_basic_blocks -function compile_method_body(ctx::CompilerContext) +function compile_method_body(ctx::CompilerContext, ircode) ci = ctx.ci code = ci.code ctx.localidx += nargs(ctx) - cfg = Core.Compiler.compute_basic_blocks(code) - relooper = RelooperCreate(ctx.mod) + # cfg = Core.Compiler.compute_basic_blocks(code) + cfg = ircode.cfg # Find and collect phis phis = Dict{Int, Any}() @@ -141,24 +147,14 @@ function compile_method_body(ctx::CompilerContext) ctx.localidx += 1 end end - # Create blocks - rblocks = [RelooperAddBlock(relooper, compile_block(ctx, cfg, phis, idx)) for idx in eachindex(cfg.blocks)] - # Create branches - for (idx, block) in enumerate(cfg.blocks) - terminator = code[last(block.stmts)] - if isa(terminator, Core.ReturnNode) - # return never has any successors, so no branches needed - elseif isa(terminator, Core.GotoNode) - toblock = block_for_inst(cfg, terminator.label) - RelooperAddBranch(rblocks[idx], rblocks[toblock], C_NULL, C_NULL) - elseif isa(terminator, Core.GotoIfNot) - toblock = block_for_inst(cfg, terminator.dest) - RelooperAddBranch(rblocks[idx], rblocks[idx + 1], _compile(ctx, terminator.cond), C_NULL) - RelooperAddBranch(rblocks[idx], rblocks[toblock], C_NULL, C_NULL) - elseif idx < length(cfg.blocks) - RelooperAddBranch(rblocks[idx], rblocks[idx + 1], C_NULL, C_NULL) - end - end - RelooperRenderAndDispose(relooper, rblocks[1], 0) + + # Emit code for each block + rblocks = [ + compile_block(ctx, cfg, phis, idx) + for idx in eachindex(cfg.blocks) + ] + + # Handle control flow + WasmCompiler.reloop(rblocks, ircode) end diff --git a/src/types.jl b/src/types.jl index 4efed6d..2e53c50 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,3 +1,7 @@ +using WasmCompiler: + i32, i64, f32, f64, ValType, + Inst, Module, InstOperands + export Externref """ @@ -17,38 +21,36 @@ struct Box{T} x::T end -wtypes() = Dict{Any, BinaryenType}( - Int64 => BinaryenTypeInt64(), - Int32 => BinaryenTypeInt32(), - UInt64 => BinaryenTypeInt64(), - UInt32 => BinaryenTypeInt32(), - UInt8 => BinaryenTypeInt32(), - Bool => BinaryenTypeInt32(), - Float64 => BinaryenTypeFloat64(), - Float32 => BinaryenTypeFloat32(), +wtypes() = Dict{Any, ValType}( + Int64 => WC.i64, + Int32 => WC.i32, + UInt64 => WC.i64, + UInt32 => WC.i32, + UInt8 => WC.i32, + Bool => WC.i32, + Float64 => WC.f64, + Float32 => WC.f32, # Symbol => BinaryenTypeStringref(), # String => BinaryenTypeStringref(), - Externref => BinaryenTypeExternref(), - Any => BinaryenTypeEqref(), - Union{} => BinaryenTypeNone(), - Core.TypeofBottom => BinaryenTypeNone(), + Externref => WC.ExternRef(false), + Any => WC.EqRef(false), ) const basictypes = [Int64, Int32, UInt64, UInt32, UInt8, Bool, Float64, Float32] mutable struct CompilerContext ## module-level context - mod::BinaryenModuleRef + mod::Module names::Dict{DataType, String} # function signature to name sigs::Dict{String, DataType} # name to function signature imports::Dict{String, Any} - wtypes::Dict{Any, BinaryenType} + wtypes::Dict{Any, ValType} globals::IdDict{Any, Any} objects::IdDict{Any, Any} ## function-level context ci::Core.CodeInfo - body::Vector{BinaryenExpressionRef} - locals::Vector{BinaryenType} + body::Vector{InstOperands} + locals::Vector{ValType} localidx::Int varmap::Dict{Int, Int} toplevel::Bool @@ -77,10 +79,12 @@ const wat = raw""" ) """ -CompilerContext(ci::Core.CodeInfo; experimental = false) = - CompilerContext(BinaryenModuleParse(experimental ? experimentalwat : wat), Dict{DataType, String}(), Dict{String, DataType}(), Dict{String, Any}(), wtypes(), IdDict{Any, Any}(), IdDict{Any, Any}(), - ci, BinaryenExpressionRef[], BinaryenType[], 0, Dict{Int, Int}(), true, nothing, Dict{Symbol, Any}()) -CompilerContext(ctx::CompilerContext, ci::Core.CodeInfo; toplevel = false) = +CompilerContext(ci::Core.CodeInfo; experimental = false) = begin + @assert !experimental + CompilerContext(Module(), Dict{DataType, String}(), Dict{String, DataType}(), Dict{String, Any}(), wtypes(), IdDict{Any, Any}(), IdDict{Any, Any}(), + ci, [], [], 1, Dict{Int, Int}(), true, nothing, Dict{Symbol, Any}()) +end +CompilerContext(ctx::CompilerContext, ci::Core.CodeInfo; toplevel = false) = CompilerContext(ctx.mod, ctx.names, ctx.sigs, ctx.imports, ctx.wtypes, ctx.globals, ctx.objects, - ci, BinaryenExpressionRef[], BinaryenType[], 0, Dict{Int, Int}(), toplevel, nothing, Dict{Symbol, Any}()) + ci, [], [], 1, Dict{Int, Int}(), toplevel, nothing, Dict{Symbol, Any}()) diff --git a/src/utils.jl b/src/utils.jl index 1eb67e1..8962b9d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,6 @@ # Return the argument that function matches with the n'th slottype. # Unused arguments are skipped. -# If the 1st and 4th arguments are unused, argmap(3) == 2 and argmap(6) == 4. +# If the 1st and 4th arguments are unused, argmap(3) == 3 and argmap(6) == 5. function argmap(ctx, n) used = argsused(ctx) result = sum(used[1:n]) @@ -136,8 +136,13 @@ function gettype(ctx, type) end # exit() if type <: Type - return BinaryenTypeInt32() + return i32 end + + if type <: Function + return i32 + end + tb = TypeBuilderCreate(1) builtheaptypes = Array{BinaryenHeapType}(undef, 1) @@ -160,6 +165,7 @@ function gettype(ctx, type) # BinaryenModuleSetTypeName(ctx.mod, builtheaptypes[1], string(type)) # end # BinaryenExpressionPrint( BinaryenLocalSet(ctx.mod, 100, BinaryenLocalGet(ctx.mod, 99, newtype))) + @show type ctx.wtypes[type] = newtype return newtype end