Skip to content

Commit 1bb9f4e

Browse files
committed
[MLIR] Create folders for extsi/extui
Create folders/canonicalizers for extsi/extui. Specifically, extui(extui(x)) -> extui(x) extsi(extsi(x)) -> extsi(x) extsi(extui(x)) -> extui(x) Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D116515
1 parent 89af17c commit 1bb9f4e

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
818818
}];
819819

820820
let hasFolder = 1;
821+
let hasCanonicalizer = 1;
821822
let verifier = [{ return verifyExtOp<IntegerType>(*this); }];
822823
}
823824

mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,12 @@ def IndexCastOfExtSI :
128128
def BitcastOfBitcast :
129129
Pat<(Arith_BitcastOp (Arith_BitcastOp $x)), (replaceWithValue $x)>;
130130

131+
//===----------------------------------------------------------------------===//
132+
// ExtSIOp
133+
//===----------------------------------------------------------------------===//
134+
135+
// extsi(extui(x iN : iM) : iL) -> extui(x : iL)
136+
def ExtSIOfExtUI :
137+
Pat<(Arith_ExtSIOp (Arith_ExtUIOp $x)), (Arith_ExtUIOp $x)>;
138+
131139
#endif // ARITHMETIC_PATTERNS

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,11 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
788788
return IntegerAttr::get(
789789
getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
790790

791+
if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
792+
getInMutable().assign(lhs.getIn());
793+
return getResult();
794+
}
795+
791796
return {};
792797
}
793798

@@ -804,13 +809,23 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
804809
return IntegerAttr::get(
805810
getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
806811

812+
if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
813+
getInMutable().assign(lhs.getIn());
814+
return getResult();
815+
}
816+
807817
return {};
808818
}
809819

810820
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
811821
return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
812822
}
813823

824+
void arith::ExtSIOp::getCanonicalizationPatterns(
825+
OwningRewritePatternList &patterns, MLIRContext *context) {
826+
patterns.insert<ExtSIOfExtUI>(context);
827+
}
828+
814829
//===----------------------------------------------------------------------===//
815830
// ExtFOp
816831
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arithmetic/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,35 @@ func @cmpOfExtUI(%arg0: i1) -> i1 {
7070

7171
// -----
7272

73+
// CHECK-LABEL: @extSIOfExtUI
74+
// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
75+
// CHECK: return %[[res]]
76+
func @extSIOfExtUI(%arg0: i1) -> i64 {
77+
%ext1 = arith.extui %arg0 : i1 to i8
78+
%ext2 = arith.extsi %ext1 : i8 to i64
79+
return %ext2 : i64
80+
}
81+
82+
// CHECK-LABEL: @extUIOfExtUI
83+
// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
84+
// CHECK: return %[[res]]
85+
func @extUIOfExtUI(%arg0: i1) -> i64 {
86+
%ext1 = arith.extui %arg0 : i1 to i8
87+
%ext2 = arith.extui %ext1 : i8 to i64
88+
return %ext2 : i64
89+
}
90+
91+
// CHECK-LABEL: @extSIOfExtSI
92+
// CHECK: %[[res:.+]] = arith.extsi %arg0 : i1 to i64
93+
// CHECK: return %[[res]]
94+
func @extSIOfExtSI(%arg0: i1) -> i64 {
95+
%ext1 = arith.extsi %arg0 : i1 to i8
96+
%ext2 = arith.extsi %ext1 : i8 to i64
97+
return %ext2 : i64
98+
}
99+
100+
// -----
101+
73102
// CHECK-LABEL: @indexCastOfSignExtend
74103
// CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
75104
// CHECK: return %[[res]]

0 commit comments

Comments
 (0)