Skip to content

Commit bb5f753

Browse files
niaowaykevl
authored andcommitted
transform (coroutines): remove map iteration from coroutine lowering pass
The coroutine lowering pass had issues where it iterated over maps, sometimes resulting in non-deterministic output. This change removes many of the maps and ensures that the transformations are deterministic.
1 parent 3862d6e commit bb5f753

File tree

1 file changed

+101
-53
lines changed

1 file changed

+101
-53
lines changed

transform/coroutines.go

Lines changed: 101 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,22 @@ type asyncFunc struct {
115115
// callers is a set of all functions which call this async function.
116116
callers map[llvm.Value]struct{}
117117

118-
// returns is a map of terminal basic blocks to their return kinds.
119-
returns map[llvm.BasicBlock]returnKind
118+
// returns is a list of returns in the function, along with metadata.
119+
returns []asyncReturn
120120

121-
// calls is the set of all calls in the asyncFunc.
122-
// normalCalls is the set of all intermideate suspending calls in the asyncFunc.
123-
// tailCalls is the set of all tail calls in the asyncFunc.
124-
calls, normalCalls, tailCalls map[llvm.Value]struct{}
121+
// calls is a list of all calls in the asyncFunc.
122+
// normalCalls is a list of all intermideate suspending calls in the asyncFunc.
123+
// tailCalls is a list of all tail calls in the asyncFunc.
124+
calls, normalCalls, tailCalls []llvm.Value
125+
}
126+
127+
// asyncReturn is a metadata container for a return from an asynchronous function.
128+
type asyncReturn struct {
129+
// block is the basic block terminated by the return.
130+
block llvm.BasicBlock
131+
132+
// kind is the kind of the return.
133+
kind returnKind
125134
}
126135

127136
// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
@@ -135,6 +144,8 @@ type coroutineLoweringPass struct {
135144
// The map keys are function pointers.
136145
asyncFuncs map[llvm.Value]*asyncFunc
137146

147+
asyncFuncsOrdered []*asyncFunc
148+
138149
// calls is a slice of all of the async calls in the module.
139150
calls []llvm.Value
140151

@@ -159,14 +170,15 @@ type coroutineLoweringPass struct {
159170
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
160171
func (c *coroutineLoweringPass) findAsyncFuncs() {
161172
asyncs := map[llvm.Value]*asyncFunc{}
173+
asyncsOrdered := []llvm.Value{}
162174
calls := []llvm.Value{}
163175

164176
// Use a breadth-first search to find all async functions.
165177
worklist := []llvm.Value{c.pause}
166178
for len(worklist) > 0 {
167179
// Pop a function off the worklist.
168-
fn := worklist[len(worklist)-1]
169-
worklist = worklist[:len(worklist)-1]
180+
fn := worklist[0]
181+
worklist = worklist[1:]
170182

171183
// Get task pointer argument.
172184
task := fn.LastParam()
@@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
204216
// Mark the caller as async.
205217
// Use nil as a temporary value. It will be replaced later.
206218
asyncs[caller] = nil
219+
asyncsOrdered = append(asyncsOrdered, caller)
207220

208221
// Put the caller on the worklist.
209222
worklist = append(worklist, caller)
@@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
216229
}
217230
}
218231

232+
// Flip the order of the async functions so that the top ones are lowered first.
233+
for i := 0; i < len(asyncsOrdered)/2; i++ {
234+
asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
235+
}
236+
237+
// Map the elements of asyncsOrdered to *asyncFunc.
238+
asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
239+
for i, v := range asyncsOrdered {
240+
asyncFuncsOrdered[i] = asyncs[v]
241+
}
242+
219243
c.asyncFuncs = asyncs
244+
c.asyncFuncsOrdered = asyncFuncsOrdered
220245
c.calls = calls
221246
}
222247

@@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
386411

387412
// analyzeFuncReturns analyzes and classifies the returns of a function.
388413
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
389-
returns := map[llvm.BasicBlock]returnKind{}
414+
returns := []asyncReturn{}
390415
if fn.fn == c.pause {
391416
// Skip pause.
392417
fn.returns = returns
@@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
410435
case !c.isAsyncCall(prev):
411436
// This is not any form of asynchronous tail call.
412437
if isVoid {
413-
returns[bb] = returnVoid
438+
returns = append(returns, asyncReturn{
439+
block: bb,
440+
kind: returnVoid,
441+
})
414442
} else {
415-
returns[bb] = returnNormal
443+
returns = append(returns, asyncReturn{
444+
block: bb,
445+
kind: returnNormal,
446+
})
416447
}
417448
case isVoid:
418449
if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
419450
// This is a tail call to a void-returning function from a function with a void return.
420-
returns[bb] = returnVoidTail
451+
returns = append(returns, asyncReturn{
452+
block: bb,
453+
kind: returnVoidTail,
454+
})
421455
} else {
422456
// This is a tail call to a value-returning function from a function with a void return.
423457
// The returned value will be ditched.
424-
returns[bb] = returnDitchedTail
458+
returns = append(returns, asyncReturn{
459+
block: bb,
460+
kind: returnDitchedTail,
461+
})
425462
}
426463
case last.Operand(0) == prev:
427464
// This is a regular tail call. The return of the callee is returned to the parent.
428-
returns[bb] = returnTail
465+
returns = append(returns, asyncReturn{
466+
block: bb,
467+
kind: returnTail,
468+
})
429469
case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
430470
// This is a tail call that returns a previous value after waiting on a void function.
431-
returns[bb] = returnDelayedValue
471+
returns = append(returns, asyncReturn{
472+
block: bb,
473+
kind: returnDelayedValue,
474+
})
432475
default:
433476
// This is a tail call that returns a value that is available before the function call.
434-
returns[bb] = returnAlternateTail
477+
returns = append(returns, asyncReturn{
478+
block: bb,
479+
kind: returnAlternateTail,
480+
})
435481
}
436482
case llvm.Unreachable:
437483
prev := llvm.PrevInstruction(last)
@@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
442488
}
443489

444490
// This is an asyncnhronous tail call to function that does not return.
445-
returns[bb] = returnDeadTail
491+
returns = append(returns, asyncReturn{
492+
block: bb,
493+
kind: returnDeadTail,
494+
})
446495
}
447496
}
448497

@@ -451,46 +500,45 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
451500

452501
// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
453502
func (c *coroutineLoweringPass) returnAnalysisPass() {
454-
for _, async := range c.asyncFuncs {
503+
for _, async := range c.asyncFuncsOrdered {
455504
c.analyzeFuncReturns(async)
456505
}
457506
}
458507

459508
// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
460509
func (c *coroutineLoweringPass) categorizeCalls() {
461510
// Sort calls into their respective callers.
462-
for _, async := range c.asyncFuncs {
463-
async.calls = map[llvm.Value]struct{}{}
464-
}
465511
for _, call := range c.calls {
466-
c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{}
512+
caller := c.asyncFuncs[call.InstructionParent().Parent()]
513+
caller.calls = append(caller.calls, call)
467514
}
468515

469516
// Seperate regular and tail calls.
470-
for _, async := range c.asyncFuncs {
471-
// Find all tail calls (of any kind).
517+
for _, async := range c.asyncFuncsOrdered {
518+
// Search returns for tail calls.
472519
tails := map[llvm.Value]struct{}{}
473-
for ret, kind := range async.returns {
474-
switch kind {
520+
for _, ret := range async.returns {
521+
switch ret.kind {
475522
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
476523
// This is a tail return. The previous instruction is a tail call.
477-
tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{}
524+
tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
478525
}
479526
}
480527

481-
// Find all regular calls.
482-
regulars := map[llvm.Value]struct{}{}
483-
for call := range async.calls {
528+
// Seperate tail calls and regular calls.
529+
normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
530+
for _, call := range async.calls {
484531
if _, ok := tails[call]; ok {
485532
// This is a tail call.
486-
continue
533+
tailCalls = append(tailCalls, call)
534+
} else {
535+
// This is a regular call.
536+
normalCalls = append(normalCalls, call)
487537
}
488-
489-
regulars[call] = struct{}{}
490538
}
491539

492-
async.tailCalls = tails
493-
async.normalCalls = regulars
540+
async.normalCalls = normalCalls
541+
async.tailCalls = tailCalls
494542
}
495543
}
496544

@@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
513561
}
514562

515563
func (async *asyncFunc) hasValueStoreReturn() bool {
516-
for _, kind := range async.returns {
517-
switch kind {
564+
for _, ret := range async.returns {
565+
switch ret.kind {
518566
case returnNormal, returnAlternateTail, returnDelayedValue:
519567
return true
520568
}
@@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
550598
}
551599

552600
// Lower returns.
553-
for ret, kind := range fn.returns {
601+
for _, ret := range fn.returns {
554602
// Get terminator.
555-
terminator := ret.LastInstruction()
603+
terminator := ret.block.LastInstruction()
556604

557605
// Get tail call if applicable.
558606
var call llvm.Value
559-
switch kind {
607+
switch ret.kind {
560608
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
561609
call = llvm.PrevInstruction(terminator)
562610
}
563611

564-
switch kind {
612+
switch ret.kind {
565613
case returnNormal:
566614
c.builder.SetInsertPointBefore(terminator)
567615

@@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
718766
c.builder.CreateBr(suspend)
719767

720768
// Restore old state before tail calls.
721-
for call := range fn.tailCalls {
722-
if fn.returns[call.InstructionParent()] == returnDeadTail {
769+
for _, call := range fn.tailCalls {
770+
if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
723771
// Callee never returns, so the state restore is ineffectual.
724772
continue
725773
}
@@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
729777
}
730778

731779
// Lower returns.
732-
for ret, kind := range fn.returns {
780+
for _, ret := range fn.returns {
733781
// Get terminator instruction.
734-
terminator := ret.LastInstruction()
782+
terminator := ret.block.LastInstruction()
735783

736784
// Get tail call if applicable.
737785
var call llvm.Value
738-
switch kind {
786+
switch ret.kind {
739787
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
740788
call = llvm.PrevInstruction(terminator)
741789
}
742790

743-
switch kind {
791+
switch ret.kind {
744792
case returnNormal:
745793
c.builder.SetInsertPointBefore(terminator)
746794

@@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
760808
c.builder.SetInsertPointBefore(call)
761809

762810
// Store return value.
763-
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
811+
c.builder.CreateStore(terminator.Operand(0), retPtr)
764812

765813
// Heap-allocate a return buffer for the discarded return.
766814
alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
@@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
775823
c.builder.SetInsertPointBefore(call)
776824

777825
// Store return value.
778-
c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
826+
c.builder.CreateStore(terminator.Operand(0), retPtr)
779827
}
780828

781829
// Delete call if it is a pause, because it has already been lowered.
@@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
785833

786834
// Replace terminator with branch to cleanup.
787835
terminator.EraseFromParentAsInstruction()
788-
c.builder.SetInsertPointAtEnd(ret)
836+
c.builder.SetInsertPointAtEnd(ret.block)
789837
c.builder.CreateBr(cleanup)
790838
}
791839

792840
// Lower regular calls.
793-
for call := range fn.normalCalls {
841+
for _, call := range fn.normalCalls {
794842
// Lower return value of call.
795843
c.lowerCallReturn(fn, call)
796844

@@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
882930
} else {
883931
// Check for any undead returns.
884932
var undead bool
885-
for _, kind := range c.asyncFuncs[fn].returns {
886-
if kind != returnDeadTail {
933+
for _, ret := range c.asyncFuncs[fn].returns {
934+
if ret.kind != returnDeadTail {
887935
// This return results in a value being eventually stored.
888936
undead = true
889937
break

0 commit comments

Comments
 (0)