Skip to content

Commit dbe96c8

Browse files
committed
[MLIR][SCF] Combine nested ifs with yields
This patch extends the existing combine nested if combination canonicalization to also handle ifs which yield values Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121923
1 parent 5ab421f commit dbe96c8

File tree

2 files changed

+153
-8
lines changed

2 files changed

+153
-8
lines changed

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,27 +1746,65 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
17461746

17471747
LogicalResult matchAndRewrite(IfOp op,
17481748
PatternRewriter &rewriter) const override {
1749-
// Both `if` ops must not yield results and have only `then` block.
1750-
if (op->getNumResults() != 0 || op.elseBlock())
1751-
return failure();
1752-
17531749
auto nestedOps = op.thenBlock()->without_terminator();
17541750
// Nested `if` must be the only op in block.
17551751
if (!llvm::hasSingleElement(nestedOps))
17561752
return failure();
17571753

1754+
// If there is an else block, it can only yield
1755+
if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
1756+
return failure();
1757+
17581758
auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
1759-
if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
1759+
if (!nestedIf)
1760+
return failure();
1761+
1762+
if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
17601763
return failure();
17611764

1765+
SmallVector<Value> thenYield(op.thenYield().getOperands());
1766+
SmallVector<Value> elseYield;
1767+
if (op.elseBlock())
1768+
llvm::append_range(elseYield, op.elseYield().getOperands());
1769+
1770+
// If the outer scf.if yields a value produced by the inner scf.if,
1771+
// only permit combining if the value yielded when the condition
1772+
// is false in the outer scf.if is the same value yielded when the
1773+
// inner scf.if condition is false.
1774+
// Note that the array access to elseYield will not go out of bounds
1775+
// since it must have the same length as thenYield, since they both
1776+
// come from the same scf.if.
1777+
for (auto tup : llvm::enumerate(thenYield)) {
1778+
if (tup.value().getDefiningOp() == nestedIf) {
1779+
auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
1780+
if (nestedIf.elseYield().getOperand(nestedIdx) !=
1781+
elseYield[tup.index()]) {
1782+
return failure();
1783+
}
1784+
// If the correctness test passes, we will yield
1785+
// corresponding value from the inner scf.if
1786+
thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
1787+
}
1788+
}
1789+
17621790
Location loc = op.getLoc();
17631791
Value newCondition = rewriter.create<arith::AndIOp>(
17641792
loc, op.getCondition(), nestedIf.getCondition());
1765-
auto newIf = rewriter.create<IfOp>(loc, newCondition);
1793+
auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
17661794
Block *newIfBlock = newIf.thenBlock();
1767-
rewriter.eraseOp(newIfBlock->getTerminator());
1795+
if (newIfBlock)
1796+
rewriter.eraseOp(newIfBlock->getTerminator());
1797+
else
1798+
newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
17681799
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
1769-
rewriter.eraseOp(op);
1800+
rewriter.setInsertionPointToEnd(newIf.thenBlock());
1801+
rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
1802+
if (!elseYield.empty()) {
1803+
rewriter.createBlock(&newIf.getElseRegion());
1804+
rewriter.setInsertionPointToEnd(newIf.elseBlock());
1805+
rewriter.create<YieldOp>(loc, elseYield);
1806+
}
1807+
rewriter.replaceOp(op, newIf.getResults());
17701808
return success();
17711809
}
17721810
};

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,113 @@ func @merge_nested_if(%arg0: i1, %arg1: i1) {
491491

492492
// -----
493493

494+
// CHECK-LABEL: @merge_yielding_nested_if
495+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
496+
func @merge_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
497+
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
498+
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
499+
// CHECK: %[[PRE2:.*]] = "test.op2"() : () -> i32
500+
// CHECK: %[[PRE3:.*]] = "test.op3"() : () -> i8
501+
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
502+
// CHECK: %[[RES:.*]]:2 = scf.if %[[COND]] -> (f32, i32)
503+
// CHECK: %[[IN0:.*]] = "test.inop"() : () -> i32
504+
// CHECK: %[[IN1:.*]] = "test.inop1"() : () -> f32
505+
// CHECK: scf.yield %[[IN1]], %[[IN0]] : f32, i32
506+
// CHECK: } else {
507+
// CHECK: scf.yield %[[PRE1]], %[[PRE2]] : f32, i32
508+
// CHECK: }
509+
// CHECK: return %[[PRE0]], %[[RES]]#0, %[[RES]]#1, %[[PRE3]] : i32, f32, i32, i8
510+
%0 = "test.op"() : () -> (i32)
511+
%1 = "test.op1"() : () -> (f32)
512+
%2 = "test.op2"() : () -> (i32)
513+
%3 = "test.op3"() : () -> (i8)
514+
%r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
515+
%a:2 = scf.if %arg1 -> (i32, f32) {
516+
%i = "test.inop"() : () -> (i32)
517+
%i1 = "test.inop1"() : () -> (f32)
518+
scf.yield %i, %i1 : i32, f32
519+
} else {
520+
scf.yield %2, %1 : i32, f32
521+
}
522+
scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
523+
} else {
524+
scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
525+
}
526+
return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
527+
}
528+
529+
// CHECK-LABEL: @merge_yielding_nested_if_nv1
530+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
531+
func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
532+
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
533+
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> f32
534+
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
535+
// CHECK: scf.if %[[COND]]
536+
// CHECK: %[[IN0:.*]] = "test.inop"() : () -> i32
537+
// CHECK: %[[IN1:.*]] = "test.inop1"() : () -> f32
538+
// CHECK: }
539+
%0 = "test.op"() : () -> (i32)
540+
%1 = "test.op1"() : () -> (f32)
541+
scf.if %arg0 {
542+
%a:2 = scf.if %arg1 -> (i32, f32) {
543+
%i = "test.inop"() : () -> (i32)
544+
%i1 = "test.inop1"() : () -> (f32)
545+
scf.yield %i, %i1 : i32, f32
546+
} else {
547+
scf.yield %0, %1 : i32, f32
548+
}
549+
}
550+
return
551+
}
552+
553+
// CHECK-LABEL: @merge_yielding_nested_if_nv2
554+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
555+
func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
556+
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
557+
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
558+
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
559+
// CHECK: scf.if %[[COND]]
560+
// CHECK: "test.run"() : () -> ()
561+
// CHECK: }
562+
// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
563+
// CHECK: return %[[RES]]
564+
%0 = "test.op"() : () -> (i32)
565+
%1 = "test.op1"() : () -> (i32)
566+
%r = scf.if %arg0 -> i32 {
567+
scf.if %arg1 {
568+
"test.run"() : () -> ()
569+
}
570+
scf.yield %0 : i32
571+
} else {
572+
scf.yield %1 : i32
573+
}
574+
return %r : i32
575+
}
576+
577+
// CHECK-LABEL: @merge_fail_yielding_nested_if
578+
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
579+
func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
580+
// CHECK-NOT: andi
581+
%0 = "test.op"() : () -> (i32)
582+
%1 = "test.op1"() : () -> (f32)
583+
%2 = "test.op2"() : () -> (i32)
584+
%3 = "test.op3"() : () -> (i8)
585+
%r:4 = scf.if %arg0 -> (i32, f32, i32, i8) {
586+
%a:2 = scf.if %arg1 -> (i32, f32) {
587+
%i = "test.inop"() : () -> (i32)
588+
%i1 = "test.inop1"() : () -> (f32)
589+
scf.yield %i, %i1 : i32, f32
590+
} else {
591+
scf.yield %0, %1 : i32, f32
592+
}
593+
scf.yield %0, %a#1, %a#0, %3 : i32, f32, i32, i8
594+
} else {
595+
scf.yield %0, %1, %2, %3 : i32, f32, i32, i8
596+
}
597+
return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
598+
}
599+
// -----
600+
494601
// CHECK-LABEL: func @if_condition_swap
495602
// CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) {
496603
// CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index

0 commit comments

Comments
 (0)