@@ -1335,24 +1335,28 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator
13351335 endBound = endBounds[1 ];
13361336 }
13371337
1338- LoopKind kind = LoopKind::Serial;
1339- if (forall.getParallelUnit () == ParallelUnit::CPUVector && !ignoreVectorize) {
1340- kind = LoopKind::Vectorized;
1341- }
1342- else if (forall.getParallelUnit () != ParallelUnit::NotParallel
1343- && forall.getOutputRaceStrategy () != OutputRaceStrategy::ParallelReduction && !ignoreVectorize) {
1344- kind = LoopKind::Runtime;
1338+ Stmt loop = Block::make (strideGuard, declareCoordinate, boundsGuard, body);
1339+ if (iterator.isBranchless () && iterator.isCompact () &&
1340+ (iterator.getParent ().isRoot () || iterator.getParent ().isUnique ())) {
1341+ loop = Block::make (VarDecl::make (iterator.getPosVar (), startBound), loop);
1342+ } else {
1343+ LoopKind kind = LoopKind::Serial;
1344+ if (forall.getParallelUnit () == ParallelUnit::CPUVector && !ignoreVectorize) {
1345+ kind = LoopKind::Vectorized;
1346+ }
1347+ else if (forall.getParallelUnit () != ParallelUnit::NotParallel &&
1348+ forall.getOutputRaceStrategy () != OutputRaceStrategy::ParallelReduction &&
1349+ !ignoreVectorize) {
1350+ kind = LoopKind::Runtime;
1351+ }
1352+
1353+ loop = For::make (iterator.getPosVar (), startBound, endBound, 1 , loop, kind,
1354+ ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit (),
1355+ ignoreVectorize ? 0 : forall.getUnrollFactor ());
13451356 }
13461357
13471358 // Loop with preamble and postamble
1348- return Block::blanks (
1349- boundsCompute,
1350- For::make (iterator.getPosVar (), startBound, endBound, 1 ,
1351- Block::make (strideGuard, declareCoordinate, boundsGuard, body),
1352- kind,
1353- ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit (), ignoreVectorize ? 0 : forall.getUnrollFactor ()),
1354- posAppend);
1355-
1359+ return Block::blanks (boundsCompute, loop, posAppend);
13561360}
13571361
13581362Stmt LowererImplImperative::lowerForallFusedPosition (Forall forall, Iterator iterator,
0 commit comments