Skip to content

Commit 2fc4d46

Browse files
authored
Merge pull request swiftlang#21192 from eeckstein/fix-loopunroll
LoopUnroll: handle more variations of the comparison builtin
2 parents 1db29e7 + d42f654 commit 2fc4d46

File tree

2 files changed: +190 -23 lines changed

2 files changed

+190
-23
lines changed

lib/SILOptimizer/LoopTransforms/LoopUnroll.cpp

Lines changed: 41 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -99,27 +99,42 @@ static Optional<uint64_t> getMaxLoopTripCount(SILLoop *Loop,
9999
return None;
100100

101101
// Match an add 1 recurrence.
102-
SILPhiArgument *RecArg;
103-
IntegerLiteralInst *End;
104-
SILValue RecNext;
102+
103+
auto *Cmp = dyn_cast<BuiltinInst>(CondBr->getCondition());
104+
if (!Cmp)
105+
return None;
105106

106107
unsigned Adjust = 0;
108+
SILBasicBlock *Exit = CondBr->getTrueBB();
107109

108-
if (!match(CondBr->getCondition(),
109-
m_BuiltinInst(BuiltinValueKind::ICMP_EQ, m_SILValue(RecNext),
110-
m_IntegerLiteralInst(End))) &&
111-
!match(CondBr->getCondition(),
112-
m_BuiltinInst(BuiltinValueKind::ICMP_SGE, m_SILValue(RecNext),
113-
m_IntegerLiteralInst(End)))) {
114-
if (!match(CondBr->getCondition(),
115-
m_BuiltinInst(BuiltinValueKind::ICMP_SGT, m_SILValue(RecNext),
116-
m_IntegerLiteralInst(End))))
117-
return None;
118-
// Otherwise, we have a greater than comparison.
119-
else
110+
switch (Cmp->getBuiltinInfo().ID) {
111+
case BuiltinValueKind::ICMP_EQ:
112+
case BuiltinValueKind::ICMP_SGE:
113+
break;
114+
case BuiltinValueKind::ICMP_SGT:
120115
Adjust = 1;
116+
break;
117+
case BuiltinValueKind::ICMP_SLE:
118+
Exit = CondBr->getFalseBB();
119+
Adjust = 1;
120+
break;
121+
case BuiltinValueKind::ICMP_NE:
122+
case BuiltinValueKind::ICMP_SLT:
123+
Exit = CondBr->getFalseBB();
124+
break;
125+
default:
126+
return None;
121127
}
122128

129+
if (Loop->contains(Exit))
130+
return None;
131+
132+
auto *End = dyn_cast<IntegerLiteralInst>(Cmp->getArguments()[1]);
133+
if (!End)
134+
return None;
135+
136+
SILValue RecNext = Cmp->getArguments()[0];
137+
SILPhiArgument *RecArg;
123138
if (!match(RecNext,
124139
m_TupleExtractInst(m_ApplyInst(BuiltinValueKind::SAddOver,
125140
m_SILPhiArgument(RecArg), m_One()),
@@ -194,7 +209,7 @@ static bool canAndShouldUnrollLoop(SILLoop *Loop, uint64_t TripCount) {
194209
/// iterations header or if this is the last iteration remove the backedge to
195210
/// the header.
196211
static void redirectTerminator(SILBasicBlock *Latch, unsigned CurLoopIter,
197-
unsigned LastLoopIter, SILBasicBlock *OrigHeader,
212+
unsigned LastLoopIter, SILBasicBlock *CurrentHeader,
198213
SILBasicBlock *NextIterationsHeader) {
199214

200215
auto *CurrentTerminator = Latch->getTerminator();
@@ -245,22 +260,25 @@ static void redirectTerminator(SILBasicBlock *Latch, unsigned CurLoopIter,
245260
// On the last iteration change the conditional exit to an unconditional
246261
// one.
247262
if (CurLoopIter == LastLoopIter) {
248-
if (CondBr->getTrueBB() != OrigHeader)
249-
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getTrueBB(),
250-
CondBr->getTrueArgs());
251-
else
263+
if (CondBr->getTrueBB() == CurrentHeader) {
252264
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getFalseBB(),
253265
CondBr->getFalseArgs());
266+
} else {
267+
assert(CondBr->getFalseBB() == CurrentHeader);
268+
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getTrueBB(),
269+
CondBr->getTrueArgs());
270+
}
254271
CondBr->eraseFromParent();
255272
return;
256273
}
257274

258275
// Otherwise, branch to the next iteration's header.
259-
if (CondBr->getTrueBB() == OrigHeader) {
276+
if (CondBr->getTrueBB() == CurrentHeader) {
260277
SILBuilder(CondBr).createCondBranch(
261278
CondBr->getLoc(), CondBr->getCondition(), NextIterationsHeader,
262279
CondBr->getTrueArgs(), CondBr->getFalseBB(), CondBr->getFalseArgs());
263280
} else {
281+
assert(CondBr->getFalseBB() == CurrentHeader);
264282
SILBuilder(CondBr).createCondBranch(
265283
CondBr->getLoc(), CondBr->getCondition(), CondBr->getTrueBB(),
266284
CondBr->getTrueArgs(), NextIterationsHeader, CondBr->getFalseArgs());
@@ -401,11 +419,11 @@ static bool tryToUnrollLoop(SILLoop *Loop) {
401419
++Iteration) {
402420
auto *CurrentLatch = Latches[Iteration];
403421
auto LastIteration = End - 1;
404-
auto *OriginalHeader = Headers[0];
422+
auto *CurrentHeader = Headers[Iteration];
405423
auto *NextIterationsHeader =
406424
Iteration == LastIteration ? nullptr : Headers[Iteration + 1];
407425

408-
redirectTerminator(CurrentLatch, Iteration, LastIteration, OriginalHeader,
426+
redirectTerminator(CurrentLatch, Iteration, LastIteration, CurrentHeader,
409427
NextIterationsHeader);
410428
}
411429

test/SILOptimizer/loop_unroll.sil

Lines changed: 149 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,155 @@ bb3:
143143
return %8 : $()
144144
}
145145

146+
// CHECK-LABEL: sil @loop_unroll_5
147+
// CHECK: bb5:
148+
// CHECK-NEXT: br bb4
149+
// CHECK-NEXT: }
150+
151+
sil @loop_unroll_5 : $@convention(thin) () -> () {
152+
bb0:
153+
%0 = integer_literal $Builtin.Int64, 0
154+
%1 = integer_literal $Builtin.Int64, 1
155+
%2 = integer_literal $Builtin.Int64, 2
156+
%3 = integer_literal $Builtin.Int1, 1
157+
br bb1(%0 : $Builtin.Int64)
158+
159+
bb1(%4 : $Builtin.Int64):
160+
%5 = builtin "sadd_with_overflow_Int64"(%4 : $Builtin.Int64, %1 : $Builtin.Int64, %3 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
161+
%6 = tuple_extract %5 : $(Builtin.Int64, Builtin.Int1), 0
162+
%7 = builtin "cmp_slt_Int64"(%6 : $Builtin.Int64, %2 : $Builtin.Int64) : $Builtin.Int1
163+
cond_br %7, bb2, bb3
164+
165+
bb2:
166+
br bb1(%6 : $Builtin.Int64)
167+
168+
bb3:
169+
%8 = tuple()
170+
return %8 : $()
171+
}
172+
173+
// CHECK-LABEL: sil @loop_unroll_6
174+
// CHECK: bb5:
175+
// CHECK-NEXT: br bb4
176+
// CHECK-NEXT: }
177+
178+
sil @loop_unroll_6 : $@convention(thin) () -> () {
179+
bb0:
180+
%0 = integer_literal $Builtin.Int64, 0
181+
%1 = integer_literal $Builtin.Int64, 1
182+
%2 = integer_literal $Builtin.Int64, 1
183+
%3 = integer_literal $Builtin.Int1, 1
184+
br bb1(%0 : $Builtin.Int64)
185+
186+
bb1(%4 : $Builtin.Int64):
187+
%5 = builtin "sadd_with_overflow_Int64"(%4 : $Builtin.Int64, %1 : $Builtin.Int64, %3 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
188+
%6 = tuple_extract %5 : $(Builtin.Int64, Builtin.Int1), 0
189+
%7 = builtin "cmp_sle_Int64"(%6 : $Builtin.Int64, %2 : $Builtin.Int64) : $Builtin.Int1
190+
cond_br %7, bb2, bb3
191+
192+
bb2:
193+
br bb1(%6 : $Builtin.Int64)
194+
195+
bb3:
196+
%8 = tuple()
197+
return %8 : $()
198+
}
199+
200+
// CHECK-LABEL: sil @loop_unroll_7
201+
// CHECK: bb5:
202+
// CHECK-NEXT: br bb4
203+
// CHECK-NEXT: }
204+
205+
sil @loop_unroll_7 : $@convention(thin) () -> () {
206+
bb0:
207+
%0 = integer_literal $Builtin.Int64, 0
208+
%1 = integer_literal $Builtin.Int64, 1
209+
%2 = integer_literal $Builtin.Int64, 2
210+
%3 = integer_literal $Builtin.Int1, 1
211+
br bb1(%0 : $Builtin.Int64)
212+
213+
bb1(%4 : $Builtin.Int64):
214+
%5 = builtin "sadd_with_overflow_Int64"(%4 : $Builtin.Int64, %1 : $Builtin.Int64, %3 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
215+
%6 = tuple_extract %5 : $(Builtin.Int64, Builtin.Int1), 0
216+
%7 = builtin "cmp_ne_Int64"(%6 : $Builtin.Int64, %2 : $Builtin.Int64) : $Builtin.Int1
217+
cond_br %7, bb2, bb3
218+
219+
bb2:
220+
br bb1(%6 : $Builtin.Int64)
221+
222+
bb3:
223+
%8 = tuple()
224+
return %8 : $()
225+
}
226+
227+
// CHECK-LABEL: sil @unroll_with_exit_block_arg_1
228+
// CHECK: bb1({{.*}}):
229+
// CHECK: cond_br {{.*}}, bb3{{.*}}, bb2({{.*}})
230+
// CHECK: bb2({{.*}}):
231+
// CHECK: return
232+
// CHECK: bb3({{.*}}):
233+
// CHECK: cond_br {{.*}}, bb4{{.*}}, bb2({{.*}})
234+
// CHECK: bb4({{.*}}):
235+
// CHECK: br bb2({{.*}})
236+
// CHECK: }
237+
sil @unroll_with_exit_block_arg_1 : $@convention(thin) () -> () {
238+
bb0:
239+
%27 = integer_literal $Builtin.Int64, 1
240+
%28 = integer_literal $Builtin.Int64, 4
241+
%56 = integer_literal $Builtin.Int1, -1
242+
br bb4(%27 : $Builtin.Int64, %28 : $Builtin.Int64)
243+
244+
bb4(%58 : $Builtin.Int64, %59 : $Builtin.Int64):
245+
%60 = builtin "sadd_with_overflow_Int64"(%58 : $Builtin.Int64, %27 : $Builtin.Int64, %56 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
246+
%61 = tuple_extract %60 : $(Builtin.Int64, Builtin.Int1), 0
247+
%64 = builtin "smul_with_overflow_Int64"(%59 : $Builtin.Int64, %28 : $Builtin.Int64, %56 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
248+
%65 = tuple_extract %64 : $(Builtin.Int64, Builtin.Int1), 0
249+
%70 = builtin "cmp_slt_Int64"(%61 : $Builtin.Int64, %28 : $Builtin.Int64) : $Builtin.Int1
250+
cond_br %70, bb4(%61 : $Builtin.Int64, %65 : $Builtin.Int64), bb6(%61 : $Builtin.Int64)
251+
252+
bb6(%72 : $Builtin.Int64):
253+
%401 = tuple ()
254+
return %401 : $()
255+
}
256+
257+
// CHECK-LABEL: sil @unroll_with_exit_block_arg_2
258+
// CHECK: bb1({{.*}}):
259+
// CHECK: cond_br {{.*}}, bb2, bb3({{.*}})
260+
// CHECK: bb2:
261+
// CHECK: br bb4{{.*}}
262+
// CHECK: bb3({{.*}}):
263+
// CHECK: return
264+
// CHECK: bb4({{.*}}):
265+
// CHECK: cond_br {{.*}}, bb5, bb3({{.*}})
266+
// CHECK: bb5:
267+
// CHECK: br bb6{{.*}}
268+
// CHECK: bb6({{.*}}):
269+
// CHECK: br bb3({{.*}})
270+
// CHECK: }
271+
sil @unroll_with_exit_block_arg_2 : $@convention(thin) () -> () {
272+
bb0:
273+
%27 = integer_literal $Builtin.Int64, 1
274+
%28 = integer_literal $Builtin.Int64, 4
275+
%56 = integer_literal $Builtin.Int1, -1
276+
br bb4(%27 : $Builtin.Int64, %28 : $Builtin.Int64)
277+
278+
bb4(%58 : $Builtin.Int64, %59 : $Builtin.Int64):
279+
%60 = builtin "sadd_with_overflow_Int64"(%58 : $Builtin.Int64, %27 : $Builtin.Int64, %56 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
280+
%61 = tuple_extract %60 : $(Builtin.Int64, Builtin.Int1), 0
281+
%64 = builtin "smul_with_overflow_Int64"(%59 : $Builtin.Int64, %28 : $Builtin.Int64, %56 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1)
282+
%65 = tuple_extract %64 : $(Builtin.Int64, Builtin.Int1), 0
283+
%70 = builtin "cmp_slt_Int64"(%61 : $Builtin.Int64, %28 : $Builtin.Int64) : $Builtin.Int1
284+
cond_br %70, bb5, bb6(%61 : $Builtin.Int64)
285+
286+
bb5:
287+
br bb4(%61 : $Builtin.Int64, %65 : $Builtin.Int64)
288+
289+
bb6(%72 : $Builtin.Int64):
290+
%401 = tuple ()
291+
return %401 : $()
292+
}
293+
294+
146295
class B {}
147296

148297
// CHECK-LABEL: sil @unroll_with_stack_allocation

0 commit comments

Comments
 (0)