Skip to content

Commit 2188a47

Browse files
committed
Fix invalid SPIRV from continue inside switch inside for loop (#9585, #10198)
For a for-loop with a switch containing a `continue`, the IR has this shape:

    loop(target=%body, break=%after, continue=%incr)
    %body:
        switch(x, break=%post_switch, ...)
          case 0: unconditionalBranch(%incr)         <- 'continue'
          case 1: unconditionalBranch(%post_switch)  <- 'break'
    %post_switch:
        unconditionalBranch(%incr)                   <- normal post-switch flow
    %incr:
        i++; unconditionalBranch(%body)              <- back-edge
    %after:
        ...                                          <- loop break

The `continue` inside the switch branches to %incr (the loop's continueBlock). Because continueBlock (%incr) != targetBlock (%body) in a for-loop, this is a branch that exits the switch region to reach an exit block of the enclosing loop region -- a multi-level branch that must be transformed for valid SPIRV.

The bug: populateExitBlocks() stored continueBlock but did not add it to region->exitBlocks, so mapExitBlockToRegion never mapped %incr to the loop region. gatherInfo() therefore never flagged the branch from case 0 to %incr as a multi-level branch, needsContinueElimination stayed false, and the raw IR branch reached the SPIRV emitter -- producing unstructured control flow.

The fix: add continueBlock to region->exitBlocks whenever continueBlock != targetBlock (i.e. for for-loops). This makes %incr visible in mapExitBlockToRegion, gatherInfo() detects the multi-level branch, and eliminateContinueBlocksInFunc() is called to wrap the loop body in an inner breakable region. After that transformation the `continue` becomes a break from the inner region -- a valid SPIRV structured exit to its merge block -- and the outer loop handles the actual iteration.

Also fix a stack leak: after processing a loop region, pop continueBlock from the global exitBlocks stack unconditionally. For for-loops it is already in info.exitBlocks and removed by the existing loop; for while-loops (where continueBlock == targetBlock and is not in info.exitBlocks) it was pushed to the global stack but never popped, which could affect sibling constructs.

Fixes #9585, #10198.
1 parent 99e2a4f commit 2188a47

File tree

2 files changed

+95
-2
lines changed

2 files changed

+95
-2
lines changed

source/slang/slang-ir-eliminate-multilevel-break.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,33 @@ struct EliminateMultiLevelBreakContext
7171
// If this is a loop, store the continue block.
7272
// We add it to the exitBlocks stack separately in collectBreakableRegionBlocks
7373
// so that nested constructs treat it as an exit point.
74-
if (as<IRLoop>(headerInst))
74+
if (auto loop = as<IRLoop>(headerInst))
75+
{
7576
continueBlock = getContinueBlock();
77+
SLANG_ASSERT(continueBlock);
78+
79+
// For a for-loop with a switch statement inside, the IR looks like:
80+
//
81+
// loop(target=%body, break=%after, continue=%incr)
82+
// %body: ...
83+
// switch(x, break=%post_switch, ...)
84+
// case 0: unconditionalBranch(%incr) <- 'continue'
85+
// case 1: unconditionalBranch(%post_switch) <- 'break'
86+
// %post_switch: ...
87+
// unconditionalBranch(%incr) <- normal loop flow
88+
// %incr: i++; unconditionalBranch(%body) <- back-edge
89+
// %after: ... <- loop break
90+
//
91+
// The 'continue' inside the switch branches to %incr (continueBlock).
92+
// This is a multi-level branch: it exits the switch region to reach
93+
// an exit block of the enclosing loop region.
94+
//
95+
// In order to handle the multi-level branch properly, the continueBlock
96+
// needs to be added to exitBlocks.
97+
//
98+
if (continueBlock != loop->getTargetBlock())
99+
exitBlocks.add(continueBlock);
100+
}
76101
}
77102

78103
void replaceBreakBlock(IRBuilder* builder, IRBlock* block)
@@ -177,9 +202,11 @@ struct EliminateMultiLevelBreakContext
177202
}
178203
}
179204

180-
// Pop the exit blocks.
205+
// Pop the exit blocks that were pushed at the top of this function.
181206
for (auto exitBlock : info.exitBlocks)
182207
exitBlocks.remove(exitBlock);
208+
if (info.continueBlock)
209+
exitBlocks.remove(info.continueBlock);
183210
}
184211

185212
void gatherInfo(IRGlobalValueWithCode* func)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//TEST:SIMPLE(filecheck=SPIRV): -target spirv-asm -stage compute -entry computeMain
2+
3+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -output-using-type
4+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type
5+
6+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -output-using-type -Xslang -DWHILE_LOOP
7+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -Xslang -DWHILE_LOOP
8+
9+
// Disable dx11/FXC: "error X3708: 'continue' cannot be used in a switch"
10+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-dx11 -output-using-type
11+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-dx11 -output-using-type -Xslang -DWHILE_LOOP
12+
13+
// Test `continue` statement inside of switch inside of a loop.
14+
15+
// SPIRV: OpEntryPoint
16+
17+
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
18+
RWStructuredBuffer<int> outputBuffer;
19+
20+
[[noinline]]
21+
int test(int value)
22+
{
23+
int result = 0;
24+
#if defined(WHILE_LOOP)
25+
uint i = 0;
26+
while(i < 3)
27+
{
28+
i++;
29+
#else
30+
for (int i = 0; i < 3; i++)
31+
{
32+
#endif
33+
switch (value)
34+
{
35+
case 0:
36+
result += 1;
37+
continue; // triggers the bug: continue from switch inside for-loop
38+
case 1:
39+
result += 2;
40+
break;
41+
default:
42+
result += 3;
43+
break;
44+
}
45+
result += 10;
46+
}
47+
return result;
48+
}
49+
50+
[shader("compute")]
51+
[numthreads(1, 1, 1)]
52+
void computeMain()
53+
{
54+
// CHECK: 3
55+
// test(0): case 0 hits, result += 1, continue (skips +10) -> 3 iterations -> result = 3
56+
outputBuffer[0] = test(0);
57+
58+
// CHECK: 36
59+
// test(1): case 1 hits, result += 2, break, then +10 -> 3 iterations -> result = 36
60+
outputBuffer[1] = test(1);
61+
62+
// CHECK: 39
63+
// test(2): default hits, result += 3, break, then +10 -> 3 iterations -> result = 39
64+
outputBuffer[2] = test(2);
65+
}
66+

0 commit comments

Comments
 (0)