@@ -115,13 +115,22 @@ type asyncFunc struct {
115
115
// callers is a set of all functions which call this async function.
116
116
callers map [llvm.Value ]struct {}
117
117
118
- // returns is a map of terminal basic blocks to their return kinds .
119
- returns map [llvm. BasicBlock ] returnKind
118
+ // returns is a list of returns in the function, along with metadata .
119
+ returns [] asyncReturn
120
120
121
- // calls is the set of all calls in the asyncFunc.
122
- // normalCalls is the set of all intermideate suspending calls in the asyncFunc.
123
- // tailCalls is the set of all tail calls in the asyncFunc.
124
- calls , normalCalls , tailCalls map [llvm.Value ]struct {}
121
+ // calls is a list of all calls in the asyncFunc.
122
+ // normalCalls is a list of all intermideate suspending calls in the asyncFunc.
123
+ // tailCalls is a list of all tail calls in the asyncFunc.
124
+ calls , normalCalls , tailCalls []llvm.Value
125
+ }
126
+
127
+ // asyncReturn is a metadata container for a return from an asynchronous function.
128
+ type asyncReturn struct {
129
+ // block is the basic block terminated by the return.
130
+ block llvm.BasicBlock
131
+
132
+ // kind is the kind of the return.
133
+ kind returnKind
125
134
}
126
135
127
136
// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
@@ -135,6 +144,8 @@ type coroutineLoweringPass struct {
135
144
// The map keys are function pointers.
136
145
asyncFuncs map [llvm.Value ]* asyncFunc
137
146
147
+ asyncFuncsOrdered []* asyncFunc
148
+
138
149
// calls is a slice of all of the async calls in the module.
139
150
calls []llvm.Value
140
151
@@ -159,14 +170,15 @@ type coroutineLoweringPass struct {
159
170
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
160
171
func (c * coroutineLoweringPass ) findAsyncFuncs () {
161
172
asyncs := map [llvm.Value ]* asyncFunc {}
173
+ asyncsOrdered := []llvm.Value {}
162
174
calls := []llvm.Value {}
163
175
164
176
// Use a breadth-first search to find all async functions.
165
177
worklist := []llvm.Value {c .pause }
166
178
for len (worklist ) > 0 {
167
179
// Pop a function off the worklist.
168
- fn := worklist [len ( worklist ) - 1 ]
169
- worklist = worklist [: len ( worklist ) - 1 ]
180
+ fn := worklist [0 ]
181
+ worklist = worklist [1 : ]
170
182
171
183
// Get task pointer argument.
172
184
task := fn .LastParam ()
@@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
204
216
// Mark the caller as async.
205
217
// Use nil as a temporary value. It will be replaced later.
206
218
asyncs [caller ] = nil
219
+ asyncsOrdered = append (asyncsOrdered , caller )
207
220
208
221
// Put the caller on the worklist.
209
222
worklist = append (worklist , caller )
@@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
216
229
}
217
230
}
218
231
232
+ // Flip the order of the async functions so that the top ones are lowered first.
233
+ for i := 0 ; i < len (asyncsOrdered )/ 2 ; i ++ {
234
+ asyncsOrdered [i ], asyncsOrdered [len (asyncsOrdered )- (i + 1 )] = asyncsOrdered [len (asyncsOrdered )- (i + 1 )], asyncsOrdered [i ]
235
+ }
236
+
237
+ // Map the elements of asyncsOrdered to *asyncFunc.
238
+ asyncFuncsOrdered := make ([]* asyncFunc , len (asyncsOrdered ))
239
+ for i , v := range asyncsOrdered {
240
+ asyncFuncsOrdered [i ] = asyncs [v ]
241
+ }
242
+
219
243
c .asyncFuncs = asyncs
244
+ c .asyncFuncsOrdered = asyncFuncsOrdered
220
245
c .calls = calls
221
246
}
222
247
@@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
386
411
387
412
// analyzeFuncReturns analyzes and classifies the returns of a function.
388
413
func (c * coroutineLoweringPass ) analyzeFuncReturns (fn * asyncFunc ) {
389
- returns := map [llvm. BasicBlock ] returnKind {}
414
+ returns := [] asyncReturn {}
390
415
if fn .fn == c .pause {
391
416
// Skip pause.
392
417
fn .returns = returns
@@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
410
435
case ! c .isAsyncCall (prev ):
411
436
// This is not any form of asynchronous tail call.
412
437
if isVoid {
413
- returns [bb ] = returnVoid
438
+ returns = append (returns , asyncReturn {
439
+ block : bb ,
440
+ kind : returnVoid ,
441
+ })
414
442
} else {
415
- returns [bb ] = returnNormal
443
+ returns = append (returns , asyncReturn {
444
+ block : bb ,
445
+ kind : returnNormal ,
446
+ })
416
447
}
417
448
case isVoid :
418
449
if prev .CalledValue ().Type ().ElementType ().ReturnType ().TypeKind () == llvm .VoidTypeKind {
419
450
// This is a tail call to a void-returning function from a function with a void return.
420
- returns [bb ] = returnVoidTail
451
+ returns = append (returns , asyncReturn {
452
+ block : bb ,
453
+ kind : returnVoidTail ,
454
+ })
421
455
} else {
422
456
// This is a tail call to a value-returning function from a function with a void return.
423
457
// The returned value will be ditched.
424
- returns [bb ] = returnDitchedTail
458
+ returns = append (returns , asyncReturn {
459
+ block : bb ,
460
+ kind : returnDitchedTail ,
461
+ })
425
462
}
426
463
case last .Operand (0 ) == prev :
427
464
// This is a regular tail call. The return of the callee is returned to the parent.
428
- returns [bb ] = returnTail
465
+ returns = append (returns , asyncReturn {
466
+ block : bb ,
467
+ kind : returnTail ,
468
+ })
429
469
case prev .CalledValue ().Type ().ElementType ().ReturnType ().TypeKind () == llvm .VoidTypeKind :
430
470
// This is a tail call that returns a previous value after waiting on a void function.
431
- returns [bb ] = returnDelayedValue
471
+ returns = append (returns , asyncReturn {
472
+ block : bb ,
473
+ kind : returnDelayedValue ,
474
+ })
432
475
default :
433
476
// This is a tail call that returns a value that is available before the function call.
434
- returns [bb ] = returnAlternateTail
477
+ returns = append (returns , asyncReturn {
478
+ block : bb ,
479
+ kind : returnAlternateTail ,
480
+ })
435
481
}
436
482
case llvm .Unreachable :
437
483
prev := llvm .PrevInstruction (last )
@@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
442
488
}
443
489
444
490
// This is an asyncnhronous tail call to function that does not return.
445
- returns [bb ] = returnDeadTail
491
+ returns = append (returns , asyncReturn {
492
+ block : bb ,
493
+ kind : returnDeadTail ,
494
+ })
446
495
}
447
496
}
448
497
@@ -451,46 +500,45 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
451
500
452
501
// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
453
502
func (c * coroutineLoweringPass ) returnAnalysisPass () {
454
- for _ , async := range c .asyncFuncs {
503
+ for _ , async := range c .asyncFuncsOrdered {
455
504
c .analyzeFuncReturns (async )
456
505
}
457
506
}
458
507
459
508
// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
460
509
func (c * coroutineLoweringPass ) categorizeCalls () {
461
510
// Sort calls into their respective callers.
462
- for _ , async := range c .asyncFuncs {
463
- async .calls = map [llvm.Value ]struct {}{}
464
- }
465
511
for _ , call := range c .calls {
466
- c .asyncFuncs [call .InstructionParent ().Parent ()].calls [call ] = struct {}{}
512
+ caller := c .asyncFuncs [call .InstructionParent ().Parent ()]
513
+ caller .calls = append (caller .calls , call )
467
514
}
468
515
469
516
// Seperate regular and tail calls.
470
- for _ , async := range c .asyncFuncs {
471
- // Find all tail calls (of any kind) .
517
+ for _ , async := range c .asyncFuncsOrdered {
518
+ // Search returns for tail calls.
472
519
tails := map [llvm.Value ]struct {}{}
473
- for ret , kind := range async .returns {
474
- switch kind {
520
+ for _ , ret := range async .returns {
521
+ switch ret . kind {
475
522
case returnVoidTail , returnTail , returnDeadTail , returnAlternateTail , returnDitchedTail , returnDelayedValue :
476
523
// This is a tail return. The previous instruction is a tail call.
477
- tails [llvm .PrevInstruction (ret .LastInstruction ())] = struct {}{}
524
+ tails [llvm .PrevInstruction (ret .block . LastInstruction ())] = struct {}{}
478
525
}
479
526
}
480
527
481
- // Find all regular calls.
482
- regulars := map [ llvm.Value ] struct {} {}
483
- for call := range async .calls {
528
+ // Seperate tail calls and regular calls.
529
+ normalCalls , tailCalls := [] llvm.Value {}, []llvm. Value {}
530
+ for _ , call := range async .calls {
484
531
if _ , ok := tails [call ]; ok {
485
532
// This is a tail call.
486
- continue
533
+ tailCalls = append (tailCalls , call )
534
+ } else {
535
+ // This is a regular call.
536
+ normalCalls = append (normalCalls , call )
487
537
}
488
-
489
- regulars [call ] = struct {}{}
490
538
}
491
539
492
- async .tailCalls = tails
493
- async .normalCalls = regulars
540
+ async .normalCalls = normalCalls
541
+ async .tailCalls = tailCalls
494
542
}
495
543
}
496
544
@@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
513
561
}
514
562
515
563
func (async * asyncFunc ) hasValueStoreReturn () bool {
516
- for _ , kind := range async .returns {
517
- switch kind {
564
+ for _ , ret := range async .returns {
565
+ switch ret . kind {
518
566
case returnNormal , returnAlternateTail , returnDelayedValue :
519
567
return true
520
568
}
@@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
550
598
}
551
599
552
600
// Lower returns.
553
- for ret , kind := range fn .returns {
601
+ for _ , ret := range fn .returns {
554
602
// Get terminator.
555
- terminator := ret .LastInstruction ()
603
+ terminator := ret .block . LastInstruction ()
556
604
557
605
// Get tail call if applicable.
558
606
var call llvm.Value
559
- switch kind {
607
+ switch ret . kind {
560
608
case returnVoidTail , returnTail , returnDeadTail , returnAlternateTail , returnDitchedTail , returnDelayedValue :
561
609
call = llvm .PrevInstruction (terminator )
562
610
}
563
611
564
- switch kind {
612
+ switch ret . kind {
565
613
case returnNormal :
566
614
c .builder .SetInsertPointBefore (terminator )
567
615
@@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
718
766
c .builder .CreateBr (suspend )
719
767
720
768
// Restore old state before tail calls.
721
- for call := range fn .tailCalls {
722
- if fn . returns [ call . InstructionParent ()] == returnDeadTail {
769
+ for _ , call := range fn .tailCalls {
770
+ if ! llvm . NextInstruction ( call ). IsAUnreachableInst (). IsNil () {
723
771
// Callee never returns, so the state restore is ineffectual.
724
772
continue
725
773
}
@@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
729
777
}
730
778
731
779
// Lower returns.
732
- for ret , kind := range fn .returns {
780
+ for _ , ret := range fn .returns {
733
781
// Get terminator instruction.
734
- terminator := ret .LastInstruction ()
782
+ terminator := ret .block . LastInstruction ()
735
783
736
784
// Get tail call if applicable.
737
785
var call llvm.Value
738
- switch kind {
786
+ switch ret . kind {
739
787
case returnVoidTail , returnTail , returnDeadTail , returnAlternateTail , returnDitchedTail , returnDelayedValue :
740
788
call = llvm .PrevInstruction (terminator )
741
789
}
742
790
743
- switch kind {
791
+ switch ret . kind {
744
792
case returnNormal :
745
793
c .builder .SetInsertPointBefore (terminator )
746
794
@@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
760
808
c .builder .SetInsertPointBefore (call )
761
809
762
810
// Store return value.
763
- c .builder .CreateStore (ret . LastInstruction () .Operand (0 ), retPtr )
811
+ c .builder .CreateStore (terminator .Operand (0 ), retPtr )
764
812
765
813
// Heap-allocate a return buffer for the discarded return.
766
814
alternateBuf := c .heapAlloc (call .Type (), "ret.alternate" )
@@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
775
823
c .builder .SetInsertPointBefore (call )
776
824
777
825
// Store return value.
778
- c .builder .CreateStore (ret . LastInstruction () .Operand (0 ), retPtr )
826
+ c .builder .CreateStore (terminator .Operand (0 ), retPtr )
779
827
}
780
828
781
829
// Delete call if it is a pause, because it has already been lowered.
@@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
785
833
786
834
// Replace terminator with branch to cleanup.
787
835
terminator .EraseFromParentAsInstruction ()
788
- c .builder .SetInsertPointAtEnd (ret )
836
+ c .builder .SetInsertPointAtEnd (ret . block )
789
837
c .builder .CreateBr (cleanup )
790
838
}
791
839
792
840
// Lower regular calls.
793
- for call := range fn .normalCalls {
841
+ for _ , call := range fn .normalCalls {
794
842
// Lower return value of call.
795
843
c .lowerCallReturn (fn , call )
796
844
@@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
882
930
} else {
883
931
// Check for any undead returns.
884
932
var undead bool
885
- for _ , kind := range c .asyncFuncs [fn ].returns {
886
- if kind != returnDeadTail {
933
+ for _ , ret := range c .asyncFuncs [fn ].returns {
934
+ if ret . kind != returnDeadTail {
887
935
// This return results in a value being eventually stored.
888
936
undead = true
889
937
break
0 commit comments