Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 82 additions & 94 deletions transform/func-lowering.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package transform
import (
"sort"
"strconv"
"strings"

"github.com/tinygo-org/tinygo/compiler/llvmutil"
"tinygo.org/x/go-llvm"
Expand Down Expand Up @@ -55,17 +56,30 @@ func LowerFuncValues(mod llvm.Module) {
funcValueWithSignaturePtr := llvm.PointerType(mod.GetTypeByName("runtime.funcValueWithSignature"), 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sidenote: this will eventually need to be changed. I discovered that types may be renamed when LLVM modules are merged (when there are duplicates). This is a big problem in the interface lowering pass and I'm not yet sure how to fix that.

signatures := map[string]*funcSignatureInfo{}
for global := mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) {
if global.Type() != funcValueWithSignaturePtr {
var sig, funcVal llvm.Value
switch {
case global.Type() == funcValueWithSignaturePtr:
sig = llvm.ConstExtractValue(global.Initializer(), []uint32{1})
funcVal = global
case strings.HasPrefix(global.Name(), "reflect/types.type:func:{"):
sig = global
default:
continue
}
sig := llvm.ConstExtractValue(global.Initializer(), []uint32{1})

name := sig.Name()
var funcValueWithSignatures []llvm.Value
if funcVal.IsNil() {
funcValueWithSignatures = []llvm.Value{}
} else {
funcValueWithSignatures = []llvm.Value{funcVal}
}
if info, ok := signatures[name]; ok {
info.funcValueWithSignatures = append(info.funcValueWithSignatures, global)
info.funcValueWithSignatures = append(info.funcValueWithSignatures, funcValueWithSignatures...)
} else {
signatures[name] = &funcSignatureInfo{
sig: sig,
funcValueWithSignatures: []llvm.Value{global},
funcValueWithSignatures: funcValueWithSignatures,
}
}
}
Expand Down Expand Up @@ -123,95 +137,64 @@ func LowerFuncValues(mod llvm.Module) {
panic("expected all call uses to be runtime.getFuncPtr")
}
funcID := getFuncPtrCall.Operand(1)
switch len(functions) {
case 0:
// There are no functions used in a func value that implement
// this signature. The only possible value is a nil value.
for _, inttoptr := range getUses(getFuncPtrCall) {
if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
nilptr := llvm.ConstPointerNull(inttoptr.Type())
inttoptr.ReplaceAllUsesWith(nilptr)
inttoptr.EraseFromParentAsInstruction()

// There are functions used in a func value that
// implement this signature.
// What we'll do is transform the following:
// rawPtr := runtime.getFuncPtr(func.ptr)
// if rawPtr == nil {
// runtime.nilPanic()
// }
// result := rawPtr(...args, func.context)
// into this:
// if false {
// runtime.nilPanic()
// }
// var result // Phi
// switch fn.id {
// case 0:
// runtime.nilPanic()
// case 1:
// result = call first implementation...
// case 2:
// result = call second implementation...
// default:
// unreachable
// }

// Remove some casts, checks, and the old call which we're going
// to replace.
for _, callIntPtr := range getUses(getFuncPtrCall) {
if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" {
// Special case for goroutine starts.
addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
i8ptrType := llvm.PointerType(ctx.Int8Type(), 0)
calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "")
start := mod.NamedFunction("internal/task.start")
builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "")
return llvm.Value{} // void so no return value
}, functions)
callIntPtr.EraseFromParentAsInstruction()
continue
}
getFuncPtrCall.EraseFromParentAsInstruction()
case 1:
// There is exactly one function with this signature that is
// used in a func value. The func value itself can be either nil
// or this one function.
builder.SetInsertPointBefore(getFuncPtrCall)
zero := llvm.ConstInt(uintptrType, 0, false)
isnil := builder.CreateICmp(llvm.IntEQ, funcID, zero, "")
funcPtrNil := llvm.ConstPointerNull(functions[0].funcPtr.Type())
funcPtr := builder.CreateSelect(isnil, funcPtrNil, functions[0].funcPtr, "")
for _, inttoptr := range getUses(getFuncPtrCall) {
if inttoptr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
inttoptr.ReplaceAllUsesWith(funcPtr)
inttoptr.EraseFromParentAsInstruction()
if callIntPtr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
getFuncPtrCall.EraseFromParentAsInstruction()
default:
// There are multiple functions used in a func value that
// implement this signature.
// What we'll do is transform the following:
// rawPtr := runtime.getFuncPtr(func.ptr)
// if rawPtr == nil {
// runtime.nilPanic()
// }
// result := rawPtr(...args, func.context)
// into this:
// if false {
// runtime.nilPanic()
// }
// var result // Phi
// switch fn.id {
// case 0:
// runtime.nilPanic()
// case 1:
// result = call first implementation...
// case 2:
// result = call second implementation...
// default:
// unreachable
// }

// Remove some casts, checks, and the old call which we're going
// to replace.
for _, callIntPtr := range getUses(getFuncPtrCall) {
if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" {
// Special case for goroutine starts.
addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
i8ptrType := llvm.PointerType(ctx.Int8Type(), 0)
calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "")
start := mod.NamedFunction("internal/task.start")
builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "")
return llvm.Value{} // void so no return value
for _, ptrUse := range getUses(callIntPtr) {
if !ptrUse.IsAICmpInst().IsNil() {
ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false))
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
return builder.CreateCall(funcPtr, params, "")
}, functions)
callIntPtr.EraseFromParentAsInstruction()
continue
}
if callIntPtr.IsAIntToPtrInst().IsNil() {
panic("expected inttoptr")
}
for _, ptrUse := range getUses(callIntPtr) {
if !ptrUse.IsAICmpInst().IsNil() {
ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false))
} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
return builder.CreateCall(funcPtr, params, "")
}, functions)
} else {
panic("unexpected getFuncPtrCall")
}
ptrUse.EraseFromParentAsInstruction()
} else {
panic("unexpected getFuncPtrCall")
}
callIntPtr.EraseFromParentAsInstruction()
ptrUse.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
callIntPtr.EraseFromParentAsInstruction()
}
getFuncPtrCall.EraseFromParentAsInstruction()
}
}
}
Expand Down Expand Up @@ -270,13 +253,18 @@ func addFuncLoweringSwitch(mod llvm.Module, builder llvm.Builder, funcID, call l
phiBlocks[i] = bb
phiValues[i] = result
}
// Create the PHI node so that the call result flows into the
// next block (after the split). This is only necessary when the
// call produced a value.
if call.Type().TypeKind() != llvm.VoidTypeKind {
builder.SetInsertPointBefore(nextBlock.FirstInstruction())
phi := builder.CreatePHI(call.Type(), "")
phi.AddIncoming(phiValues, phiBlocks)
call.ReplaceAllUsesWith(phi)
if len(functions) > 0 {
// Create the PHI node so that the call result flows into the
// next block (after the split). This is only necessary when the
// call produced a value.
builder.SetInsertPointBefore(nextBlock.FirstInstruction())
phi := builder.CreatePHI(call.Type(), "")
phi.AddIncoming(phiValues, phiBlocks)
call.ReplaceAllUsesWith(phi)
} else {
// This is always a nil panic, so replace the call result with undef.
call.ReplaceAllUsesWith(llvm.Undef(call.Type()))
}
}
}
24 changes: 10 additions & 14 deletions transform/testdata/func-lowering.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ target triple = "wasm32-unknown-unknown-wasm"
%runtime.typecodeID = type { %runtime.typecodeID*, i32 }
%runtime.funcValueWithSignature = type { i32, %runtime.typecodeID* }

@"reflect/types.type:func:{basic:int8}{}" = external constant %runtime.typecodeID
@"reflect/types.type:func:{basic:uint8}{}" = external constant %runtime.typecodeID
@"reflect/types.type:func:{basic:int}{}" = external constant %runtime.typecodeID
@"funcInt8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @funcInt8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}" }
@"reflect/types.type:func:{}{basic:uint32}" = external constant %runtime.typecodeID
@"func1Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func1Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" }
@"func2Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func2Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" }
@"main$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i32, i8*, i8*)* @"main$1" to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int}{}" }
Expand All @@ -23,29 +22,26 @@ declare void @"main$1"(i32, i8*, i8*)

declare void @"main$2"(i32, i8*, i8*)

declare void @funcInt8(i8, i8*, i8*)

declare void @func1Uint8(i8, i8*, i8*)

declare void @func2Uint8(i8, i8*, i8*)

; Call a function of which only one function with this signature is used as a
; function value. This means that lowering it to IR is trivial: simply check
; whether the func value is nil, and if not, call that one function directly.
define void @runFunc1(i8*, i32, i8, i8* %context, i8* %parentHandle) {
; There are no functions with this signature used in a func value.
; This means that this should unconditionally nil panic.
define i32 @runFuncNone(i8*, i32, i8* %context, i8* %parentHandle) {
entry:
%3 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}", i8* undef, i8* null)
%4 = inttoptr i32 %3 to void (i8, i8*, i8*)*
%5 = icmp eq void (i8, i8*, i8*)* %4, null
br i1 %5, label %fpcall.nil, label %fpcall.next
%2 = call i32 @runtime.getFuncPtr(i8* %0, i32 %1, %runtime.typecodeID* @"reflect/types.type:func:{}{basic:uint32}", i8* undef, i8* null)
%3 = inttoptr i32 %2 to i32 (i8*, i8*)*
%4 = icmp eq i32 (i8*, i8*)* %3, null
br i1 %4, label %fpcall.nil, label %fpcall.next

fpcall.nil:
call void @runtime.nilPanic(i8* undef, i8* null)
unreachable

fpcall.next:
call void %4(i8 %2, i8* %0, i8* undef)
ret void
%5 = call i32 %3(i8* %0, i8* undef)
ret i32 %5
}

; There are two functions with this signature used in a func value. That means
Expand Down
27 changes: 16 additions & 11 deletions transform/testdata/func-lowering.out.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ target triple = "wasm32-unknown-unknown-wasm"
%runtime.typecodeID = type { %runtime.typecodeID*, i32 }
%runtime.funcValueWithSignature = type { i32, %runtime.typecodeID* }

@"reflect/types.type:func:{basic:int8}{}" = external constant %runtime.typecodeID
@"reflect/types.type:func:{basic:uint8}{}" = external constant %runtime.typecodeID
@"reflect/types.type:func:{basic:int}{}" = external constant %runtime.typecodeID
@"funcInt8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @funcInt8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int8}{}" }
@"reflect/types.type:func:{}{basic:uint32}" = external constant %runtime.typecodeID
@"func1Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func1Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" }
@"func2Uint8$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i8, i8*, i8*)* @func2Uint8 to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:uint8}{}" }
@"main$withSignature" = constant %runtime.funcValueWithSignature { i32 ptrtoint (void (i32, i8*, i8*)* @"main$1" to i32), %runtime.typecodeID* @"reflect/types.type:func:{basic:int}{}" }
Expand All @@ -23,26 +22,32 @@ declare void @"main$1"(i32, i8*, i8*)

declare void @"main$2"(i32, i8*, i8*)

declare void @funcInt8(i8, i8*, i8*)

declare void @func1Uint8(i8, i8*, i8*)

declare void @func2Uint8(i8, i8*, i8*)

define void @runFunc1(i8* %0, i32 %1, i8 %2, i8* %context, i8* %parentHandle) {
define i32 @runFuncNone(i8* %0, i32 %1, i8* %context, i8* %parentHandle) {
entry:
%3 = icmp eq i32 %1, 0
%4 = select i1 %3, void (i8, i8*, i8*)* null, void (i8, i8*, i8*)* @funcInt8
%5 = icmp eq void (i8, i8*, i8*)* %4, null
br i1 %5, label %fpcall.nil, label %fpcall.next
br i1 false, label %fpcall.nil, label %fpcall.next

fpcall.nil: ; preds = %entry
call void @runtime.nilPanic(i8* undef, i8* null)
unreachable

fpcall.next: ; preds = %entry
call void %4(i8 %2, i8* %0, i8* undef)
ret void
switch i32 %1, label %func.default [
i32 0, label %func.nil
]

func.nil: ; preds = %fpcall.next
call void @runtime.nilPanic(i8* undef, i8* null)
unreachable

func.next: ; No predecessors!
ret i32 undef

func.default: ; preds = %fpcall.next
unreachable
}

define void @runFunc2(i8* %0, i32 %1, i8 %2, i8* %context, i8* %parentHandle) {
Expand Down