Skip to content

Commit ecd8c2d

Browse files
niaowaykevl
authored andcommitted
transform (coroutines): fix memory corruption for tail calls that reference stack allocations
This change fixes a bug in which `alloca` memory lifetimes would not extend past the suspend of an asynchronous tail call. This would typically manifest as memory corruption, and could happen with or without normal suspending calls within the function.
1 parent a116fd0 commit ecd8c2d

File tree

3 files changed

+188
-12
lines changed

3 files changed

+188
-12
lines changed

transform/coroutines.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -600,11 +600,11 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
600600
continue
601601
}
602602

603-
if len(fn.normalCalls) == 0 {
604-
// No suspend points. Lower without turning it into a coroutine.
603+
if len(fn.normalCalls) == 0 && fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
604+
// No suspend points or stack allocations. Lower without turning it into a coroutine.
605605
c.lowerFuncFast(fn)
606606
} else {
607-
// There are suspend points, so it is necessary to turn this into a coroutine.
607+
// There are suspend points or stack allocations, so it is necessary to turn this into a coroutine.
608608
c.lowerFuncCoro(fn)
609609
}
610610
}
@@ -827,6 +827,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
827827
}
828828

829829
// Lower returns.
830+
var postTail llvm.BasicBlock
830831
for _, ret := range fn.returns {
831832
// Get terminator instruction.
832833
terminator := ret.block.LastInstruction()
@@ -886,10 +887,37 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
886887
call.EraseFromParentAsInstruction()
887888
}
888889

889-
// Replace terminator with branch to cleanup.
890+
// Replace terminator with a branch to the exit.
891+
var exit llvm.BasicBlock
892+
if ret.kind == returnNormal || ret.kind == returnVoid || fn.fn.FirstBasicBlock().FirstInstruction().IsAAllocaInst().IsNil() {
893+
// Exit through the cleanup path.
894+
exit = cleanup
895+
} else {
896+
if postTail.IsNil() {
897+
// Create a path with a suspend that never reawakens.
898+
postTail = c.ctx.AddBasicBlock(fn.fn, "post.tail")
899+
c.builder.SetInsertPointAtEnd(postTail)
900+
// %coro.save = call token @llvm.coro.save(i8* %coro.state)
901+
save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save")
902+
// %call.suspend = llvm.coro.suspend(token %coro.save, i1 false)
903+
// switch i8 %call.suspend, label %suspend [i8 0, label %wakeup
904+
// i8 1, label %cleanup]
905+
suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend")
906+
sw := c.builder.CreateSwitch(suspendValue, suspend, 2)
907+
unreachableBlock := c.ctx.AddBasicBlock(fn.fn, "unreachable")
908+
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), unreachableBlock)
909+
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup)
910+
c.builder.SetInsertPointAtEnd(unreachableBlock)
911+
c.builder.CreateUnreachable()
912+
}
913+
914+
// Exit through a permanent suspend.
915+
exit = postTail
916+
}
917+
890918
terminator.EraseFromParentAsInstruction()
891919
c.builder.SetInsertPointAtEnd(ret.block)
892-
c.builder.CreateBr(cleanup)
920+
c.builder.CreateBr(exit)
893921
}
894922

895923
// Lower regular calls.

transform/testdata/coroutines.ll

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,43 @@ entry:
8686
}
8787

8888
; Normal function which should not be transformed.
89-
define void @doNothing(i8*, i8*) {
89+
define void @doNothing(i8*, i8* %parentHandle) {
9090
entry:
9191
ret void
9292
}
9393

94+
; Regression test: ensure that a tail call does not destroy the frame while it is still in use.
95+
; Previously, the tail-call lowering transform would branch to the cleanup block after usePtr.
96+
; This caused the lifetime of %a to be incorrectly reduced, and allowed the coroutine lowering transform to keep %a on the stack.
97+
; After a suspend %a would be used, resulting in memory corruption.
98+
define i8 @coroutineTailRegression(i8*, i8* %parentHandle) {
99+
entry:
100+
%a = alloca i8
101+
store i8 5, i8* %a
102+
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
103+
ret i8 %val
104+
}
105+
106+
; Regression test: ensure that stack allocations alive during a suspend end up on the heap.
107+
; This used to not be transformed to a coroutine, keeping %a on the stack.
108+
; After a suspend %a would be used, resulting in memory corruption.
109+
define i8 @allocaTailRegression(i8*, i8* %parentHandle) {
110+
entry:
111+
%a = alloca i8
112+
call void @sleep(i64 1000000, i8* undef, i8* null)
113+
store i8 5, i8* %a
114+
%val = call i8 @usePtr(i8* %a, i8* undef, i8* null)
115+
ret i8 %val
116+
}
117+
118+
; usePtr uses a pointer after a suspend.
119+
define i8 @usePtr(i8*, i8*, i8* %parentHandle) {
120+
entry:
121+
call void @sleep(i64 1000000, i8* undef, i8* null)
122+
%val = load i8, i8* %0
123+
ret i8 %val
124+
}
125+
94126
; Goroutine that sleeps and does nothing.
95127
; Should be a void tail call.
96128
define void @sleepGoroutine(i8*, i8* %parentHandle) {

transform/testdata/coroutines.out.ll

Lines changed: 122 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ entry:
4545
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
4646
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
4747
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
48-
store i32 %0, i32* %ret.ptr.bitcast
48+
store i32 %0, i32* %ret.ptr.bitcast, align 4
4949
call void @sleep(i64 %1, i8* undef, i8* %parentHandle)
5050
ret i32 undef
5151
}
@@ -84,7 +84,7 @@ entry:
8484
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
8585
%ret.ptr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
8686
%ret.ptr.bitcast = bitcast i8* %ret.ptr to i32*
87-
store i32 %0, i32* %ret.ptr.bitcast
87+
store i32 %0, i32* %ret.ptr.bitcast, align 4
8888
%ret.alternate = call i8* @runtime.alloc(i32 4, i8* undef, i8* undef)
8989
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %ret.alternate, i8* undef, i8* undef)
9090
%4 = call i32 @delayedValue(i32 %1, i64 %2, i8* undef, i8* %parentHandle)
@@ -93,7 +93,7 @@ entry:
9393

9494
define i1 @coroutine(i32 %0, i64 %1, i8* %2, i8* %parentHandle) {
9595
entry:
96-
%call.return = alloca i32
96+
%call.return = alloca i32, align 4
9797
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
9898
%coro.size = call i32 @llvm.coro.size.i32()
9999
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
@@ -116,10 +116,10 @@ entry:
116116
]
117117

118118
wakeup: ; preds = %entry
119-
%4 = load i32, i32* %call.return
119+
%4 = load i32, i32* %call.return, align 4
120120
call void @llvm.lifetime.end.p0i8(i64 4, i8* %call.return.bitcast)
121121
%5 = icmp eq i32 %4, 0
122-
store i1 %5, i1* %task.retPtr.bitcast
122+
store i1 %5, i1* %task.retPtr.bitcast, align 1
123123
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current2, i8* %task.state.parent, i8* undef, i8* undef)
124124
br label %cleanup
125125

@@ -133,11 +133,127 @@ cleanup: ; preds = %entry, %wakeup
133133
br label %suspend
134134
}
135135

136-
define void @doNothing(i8* %0, i8* %1) {
136+
define void @doNothing(i8* %0, i8* %parentHandle) {
137137
entry:
138138
ret void
139139
}
140140

141+
define i8 @coroutineTailRegression(i8* %0, i8* %parentHandle) {
142+
entry:
143+
%a = alloca i8, align 1
144+
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
145+
%coro.size = call i32 @llvm.coro.size.i32()
146+
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
147+
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
148+
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
149+
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
150+
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
151+
store i8 5, i8* %a, align 1
152+
%coro.state.restore = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
153+
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
154+
%val = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
155+
br label %post.tail
156+
157+
suspend: ; preds = %post.tail, %cleanup
158+
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
159+
ret i8 undef
160+
161+
cleanup: ; preds = %post.tail
162+
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
163+
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
164+
br label %suspend
165+
166+
post.tail: ; preds = %entry
167+
%coro.save = call token @llvm.coro.save(i8* %coro.state)
168+
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
169+
switch i8 %call.suspend, label %suspend [
170+
i8 0, label %unreachable
171+
i8 1, label %cleanup
172+
]
173+
174+
unreachable: ; preds = %post.tail
175+
unreachable
176+
}
177+
178+
define i8 @allocaTailRegression(i8* %0, i8* %parentHandle) {
179+
entry:
180+
%a = alloca i8, align 1
181+
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
182+
%coro.size = call i32 @llvm.coro.size.i32()
183+
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
184+
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
185+
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
186+
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
187+
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
188+
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
189+
%coro.save1 = call token @llvm.coro.save(i8* %coro.state)
190+
%call.suspend2 = call i8 @llvm.coro.suspend(token %coro.save1, i1 false)
191+
switch i8 %call.suspend2, label %suspend [
192+
i8 0, label %wakeup
193+
i8 1, label %cleanup
194+
]
195+
196+
wakeup: ; preds = %entry
197+
store i8 5, i8* %a, align 1
198+
%1 = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
199+
call void @"(*internal/task.Task).setReturnPtr"(%"internal/task.Task"* %task.current, i8* %task.retPtr, i8* undef, i8* undef)
200+
%2 = call i8 @usePtr(i8* %a, i8* undef, i8* %parentHandle)
201+
br label %post.tail
202+
203+
suspend: ; preds = %entry, %post.tail, %cleanup
204+
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
205+
ret i8 undef
206+
207+
cleanup: ; preds = %entry, %post.tail
208+
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
209+
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
210+
br label %suspend
211+
212+
post.tail: ; preds = %wakeup
213+
%coro.save = call token @llvm.coro.save(i8* %coro.state)
214+
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
215+
switch i8 %call.suspend, label %suspend [
216+
i8 0, label %unreachable
217+
i8 1, label %cleanup
218+
]
219+
220+
unreachable: ; preds = %post.tail
221+
unreachable
222+
}
223+
224+
define i8 @usePtr(i8* %0, i8* %1, i8* %parentHandle) {
225+
entry:
226+
%coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
227+
%coro.size = call i32 @llvm.coro.size.i32()
228+
%coro.alloc = call i8* @runtime.alloc(i32 %coro.size, i8* undef, i8* undef)
229+
%coro.state = call i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
230+
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
231+
%task.state.parent = call i8* @"(*internal/task.Task).setState"(%"internal/task.Task"* %task.current, i8* %coro.state, i8* undef, i8* undef)
232+
%task.retPtr = call i8* @"(*internal/task.Task).getReturnPtr"(%"internal/task.Task"* %task.current, i8* undef, i8* undef)
233+
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)
234+
%coro.save = call token @llvm.coro.save(i8* %coro.state)
235+
%call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
236+
switch i8 %call.suspend, label %suspend [
237+
i8 0, label %wakeup
238+
i8 1, label %cleanup
239+
]
240+
241+
wakeup: ; preds = %entry
242+
%2 = load i8, i8* %0, align 1
243+
store i8 %2, i8* %task.retPtr, align 1
244+
call void @"(*internal/task.Task).returnTo"(%"internal/task.Task"* %task.current, i8* %task.state.parent, i8* undef, i8* undef)
245+
br label %cleanup
246+
247+
suspend: ; preds = %entry, %cleanup
248+
%unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
249+
ret i8 undef
250+
251+
cleanup: ; preds = %entry, %wakeup
252+
%coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
253+
call void @runtime.free(i8* %coro.memFree, i8* undef, i8* undef)
254+
br label %suspend
255+
}
256+
141257
define void @sleepGoroutine(i8* %0, i8* %parentHandle) {
142258
%task.current = bitcast i8* %parentHandle to %"internal/task.Task"*
143259
call void @sleep(i64 1000000, i8* undef, i8* %parentHandle)

0 commit comments

Comments
 (0)