Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,20 @@ type builder struct {
deferFuncs map[*ir.Function]int
deferInvokeFuncs map[string]int
deferClosureFuncs map[*ir.Function]int
deferExprFuncs map[interface{}]deferExpr
selectRecvBuf map[*ssa.Select]llvm.Value
deferBuiltinFuncs map[interface{}]deferBuiltin
}

type deferExpr struct {
signature *types.Signature
callback int
funcValueType llvm.Type
}

type deferBuiltin struct {
funcName string
callback int
}

type phiNode struct {
Expand Down
161 changes: 159 additions & 2 deletions compiler/defer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ package compiler
// frames.

import (
"fmt"
"github.com/tinygo-org/tinygo/compiler/llvmutil"
"github.com/tinygo-org/tinygo/ir"
"go/types"
"golang.org/x/tools/go/ssa"
"tinygo.org/x/go-llvm"
)
Expand All @@ -28,6 +30,8 @@ func (b *builder) deferInitFunc() {
b.deferFuncs = make(map[*ir.Function]int)
b.deferInvokeFuncs = make(map[string]int)
b.deferClosureFuncs = make(map[*ir.Function]int)
b.deferExprFuncs = make(map[interface{}]deferExpr)
b.deferBuiltinFuncs = make(map[interface{}]deferBuiltin)

// Create defer list pointer.
deferType := llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)
Expand Down Expand Up @@ -151,9 +155,87 @@ func (b *builder) createDefer(instr *ssa.Defer) {
values = append(values, context)
valueTypes = append(valueTypes, context.Type())

} else if builtin, ok := instr.Call.Value.(*ssa.Builtin); ok {
var funcName string
switch builtin.Name() {
case "close":
funcName = "chanClose"
default:
b.addError(instr.Pos(), fmt.Sprint("TODO: Implement defer for ", builtin.Name()))
return
}

if _, ok := b.deferBuiltinFuncs[instr.Call.Value]; !ok {
b.deferBuiltinFuncs[instr.Call.Value] = deferBuiltin {
funcName,
len(b.allDeferFuncs),
}
b.allDeferFuncs = append(b.allDeferFuncs, instr.Call.Value)
}
callback := llvm.ConstInt(b.uintptrType, uint64(b.deferBuiltinFuncs[instr.Call.Value].callback), false)

// Collect all values to be put in the struct (starting with
// runtime._defer fields).
values = []llvm.Value{callback, next}
for _, param := range instr.Call.Args {
llvmParam := b.getValue(param)
values = append(values, llvmParam)
valueTypes = append(valueTypes, llvmParam.Type())
}

} else {
b.addError(instr.Pos(), "todo: defer on uncommon function call type")
return
var funcValue llvm.Value
var sig *types.Signature

switch expr := instr.Call.Value.(type) {
case *ssa.Extract:
value := b.getValue(expr.Tuple)
funcValue = b.CreateExtractValue(value, expr.Index, "")
sig = expr.Tuple.(*ssa.Call).Call.Value.(*ssa.Function).Signature.Results().At(expr.Index).Type().Underlying().(*types.Signature)
case *ssa.Call:
funcValue = b.getValue(expr)
sig = expr.Call.Value.Type().Underlying().(*types.Signature).Results().At(0).Type().Underlying().(*types.Signature)
case *ssa.UnOp:
funcValue = b.getValue(expr)
switch ty := expr.X.Type().(type) {
case *types.Pointer:
sig = ty.Elem().Underlying().(*types.Signature)
default:
sig = ty.Underlying().(*types.Signature).Results().At(0).Type().Underlying().(*types.Signature)
}
}

if funcValue.IsNil() == false && sig != nil {
//funcSig, context := b.decodeFuncValue(funcValue, sig)
if _, ok := b.deferExprFuncs[instr.Call.Value]; !ok {
b.deferExprFuncs[instr.Call.Value] = deferExpr{
funcValueType: funcValue.Type(),
signature: sig,
callback: len(b.allDeferFuncs),
}
b.allDeferFuncs = append(b.allDeferFuncs, instr.Call.Value)
}

callback := llvm.ConstInt(b.uintptrType, uint64(b.deferExprFuncs[instr.Call.Value].callback), false)

// Collect all values to be put in the struct (starting with
// runtime._defer fields, followed by all parameters including the
// context pointer).
values = []llvm.Value{callback, next}
for _, param := range instr.Call.Args {
llvmParam := b.getValue(param)
values = append(values, llvmParam)
valueTypes = append(valueTypes, llvmParam.Type())
}

//Pass funcValue through defer frame
values = append(values, funcValue)
valueTypes = append(valueTypes, funcValue.Type())

} else {
b.addError(instr.Pos(), "todo: defer on uncommon function call type")
return
}
}

// Make a struct out of the collected values to put in the defer frame.
Expand Down Expand Up @@ -339,7 +421,82 @@ func (b *builder) createRunDefers() {

// Call deferred function.
b.createCall(fn.LLVMFn, forwardParams, "")
case *ssa.Extract, *ssa.Call, *ssa.UnOp:
expr := b.deferExprFuncs[callback]

// Get the real defer struct type and cast to it.
valueTypes := []llvm.Type{b.uintptrType, llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)}

//Get signature from call results
params := expr.signature.Params()
for i := 0; i < params.Len(); i++ {
valueTypes = append(valueTypes, b.getLLVMType(params.At(i).Type()))
}

valueTypes = append(valueTypes, expr.funcValueType)
deferFrameType := b.ctx.StructType(valueTypes, false)
deferFramePtr := b.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")

// Extract the params from the struct.
var forwardParams []llvm.Value
zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
funcPtrIndex := len(valueTypes)-1
for i := 2; i < funcPtrIndex; i++ {
gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)}, "")
forwardParam := b.CreateLoad(gep, "param")
forwardParams = append(forwardParams, forwardParam)
}

//Last one is funcValue
gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(funcPtrIndex), false)}, "")
fun := b.CreateLoad(gep, "param.func")

//Get funcValueWithSignature and context
funcPtr, context := b.decodeFuncValue(fun, expr.signature)

//Pass context
forwardParams = append(forwardParams, context)

// Parent coroutine handle.
forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))

// Call deferred function.
b.createCall(funcPtr, forwardParams, "")
case *ssa.Builtin:
db := b.deferBuiltinFuncs[callback]
fullName := "runtime." + db.funcName
fn := b.mod.NamedFunction(fullName)

//Get parameter types
valueTypes := []llvm.Type{b.uintptrType, llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)}

//Get signature from call results
params := callback.Type().Underlying().(*types.Signature).Params()
for i := 0; i < params.Len(); i++ {
valueTypes = append(valueTypes, b.getLLVMType(params.At(i).Type()))
}

deferFrameType := b.ctx.StructType(valueTypes, false)
deferFramePtr := b.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")

// Extract the params from the struct.
forwardParams := []llvm.Value{}
zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
for i := 0; i < params.Len(); i++ {
gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(i+2), false)}, "gep")
forwardParam := b.CreateLoad(gep, "param")
forwardParams = append(forwardParams, forwardParam)
}

// Add the context parameter. We know it is ignored by the receiving
// function, but we have to pass one anyway.
forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))

// Parent coroutine handle.
forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))

// Call real function.
b.createCall(fn, forwardParams, "")
default:
panic("unknown deferred function type")
}
Expand Down
49 changes: 49 additions & 0 deletions testdata/calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ func main() {
// defers in loop
testDeferLoop()

//defer func variable call
testDeferFuncVar()

//More complicated func variable call
testMultiFuncVar()

// Take a bound method and use it as a function pointer.
// This function pointer needs a context pointer.
testBound(thing.String)
Expand All @@ -64,6 +70,9 @@ func main() {

// regression testing
regression1033()

//Test deferred builtins
testDeferBuiltin()
}

func runFunc(f func(int), arg int) {
Expand Down Expand Up @@ -91,6 +100,8 @@ func testDefer() {
defer t.Print("bar")

println("deferring...")
d := dumb{}
defer d.Value(0)
}

func testDeferLoop() {
Expand All @@ -99,6 +110,30 @@ func testDeferLoop() {
}
}

func testDeferFuncVar() {
dummy, f := deferFunc()
dummy++
defer f(1)
}

func testMultiFuncVar() {
f := multiFuncDefer()
defer f(1)
}

func testDeferBuiltin() {
i := make(chan int)
defer close(i)
}

type dumb struct {

}

func (*dumb) Value(key interface{}) interface{} {
return nil
}

func deferred(msg string, i int) {
println(msg, i)
}
Expand All @@ -108,6 +143,20 @@ func exportedDefer() {
println("...exported defer")
}

func deferFunc() (int, func(int)) {
return 0, func(i int){println("...extracted defer func ", i)}
}

func multiFuncDefer() func(int) {
i := 0

if i > 0 {
return func(i int){println("Should not have gotten here. i = ", i)}
}

return func(i int){println("Called the correct function. i = ", i)}
}

func testBound(f func() string) {
println("bound method:", f())
}
Expand Down