Skip to content

Commit 547113f

Browse files
authored
[mlir][ODS] Add ConstantEnumCase (llvm#78992)
Specifying an enum case of an enum attr currently requires the use of either `NativeCodeCall` or a `ConstantAttr` specifying the full C++ name of the enum case. The disadvantages of both are less readable code due to including C++ expressions and very few checks of any kind, creating C++ code that does not compile instead. This PR adds `ConstantEnumCase`, a kind of `ConstantAttr` which automatically derives the correct value representation from a given enum and the string representation of an enum case. It supports both `EnumAttrInfo`s (enums wrapping `IntegerAttr`) and `EnumAttr` (proper dialect attributes). It even supports bit-enums, allowing one to list multiple enum cases and have them be combined. If an enum case is not found, an assertion is triggered with a proper error message. Besides the tests, it was also used to simplify DRR patterns in the arith dialect.
1 parent 597f56f commit 547113f

File tree

5 files changed

+89
-29
lines changed

5 files changed

+89
-29
lines changed

mlir/include/mlir/IR/EnumAttr.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,4 +417,56 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
417417
let assemblyFormat = "$value";
418418
}
419419

420+
class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
421+
defvar cases =
422+
!filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
423+
424+
assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
425+
426+
// `!empty` check to not cause an error if the cases are empty.
427+
// The assertion catches the issue later and emits a proper error message.
428+
string value = enumAttrInfo.cppType # "::"
429+
# !if(!empty(cases), "", !head(cases).symbol);
430+
}
431+
432+
class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
433+
defvar pos = !find(case, "|");
434+
435+
// Recursive instantiation looking up the symbol before the `|` in
436+
// enum cases.
437+
string value = !if(
438+
!eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
439+
/*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
440+
# _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
441+
);
442+
}
443+
444+
class ConstantEnumCaseBase<Attr attribute,
445+
EnumAttrInfo enumAttrInfo, string case>
446+
: ConstantAttr<attribute,
447+
!if(!isa<BitEnumAttr>(enumAttrInfo),
448+
_bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
449+
_symbolToValue<enumAttrInfo, case>.value
450+
)
451+
>;
452+
453+
/// Attribute constraint matching a constant enum case. `attribute` should be
454+
/// one of `EnumAttrInfo` or `EnumAttr` and `symbol` the string representation
455+
/// of an enum case. Multiple enum values of a bit-enum can be combined using
456+
/// `|` as a separator. Note that there mustn't be any whitespace around the
457+
/// separator.
458+
/// This attribute constraint is additionally buildable, making it possible to
459+
/// use it in result patterns.
460+
///
461+
/// Examples:
462+
/// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">
463+
/// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
464+
class ConstantEnumCase<Attr attribute, string case>
465+
: ConstantEnumCaseBase<attribute,
466+
!if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
467+
!cast<EnumAttr>(attribute).enum), case> {
468+
assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
469+
"attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
470+
}
471+
420472
#endif // ENUMATTR_TD

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
2828
// flags and always reset them to default (wraparound) which is safe but can
2929
// inhibit later optimizations. Individual patterns must be reviewed for
3030
// better handling of overflow flags.
31-
def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
31+
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
3232

3333
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
3434

@@ -45,23 +45,23 @@ def AddIAddConstant :
4545
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
4646
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
4747
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
48-
(DefOverflow))>;
48+
DefOverflow)>;
4949

5050
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
5151
def AddISubConstantRHS :
5252
Pat<(Arith_AddIOp:$res
5353
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
5454
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
5555
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
56-
(DefOverflow))>;
56+
DefOverflow)>;
5757

5858
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
5959
def AddISubConstantLHS :
6060
Pat<(Arith_AddIOp:$res
6161
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
6262
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
6363
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
64-
(DefOverflow))>;
64+
DefOverflow)>;
6565

6666
def IsScalarOrSplatNegativeOne :
6767
Constraint<And<[
@@ -73,15 +73,15 @@ def AddIMulNegativeOneRhs :
7373
Pat<(Arith_AddIOp
7474
$x,
7575
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
76-
(Arith_SubIOp $x, $y, (DefOverflow)),
76+
(Arith_SubIOp $x, $y, DefOverflow),
7777
[(IsScalarOrSplatNegativeOne $c0)]>;
7878

7979
// addi(muli(x, -1), y) -> subi(y, x)
8080
def AddIMulNegativeOneLhs :
8181
Pat<(Arith_AddIOp
8282
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
8383
$y, $ovf2),
84-
(Arith_SubIOp $y, $x, (DefOverflow)),
84+
(Arith_SubIOp $y, $x, DefOverflow),
8585
[(IsScalarOrSplatNegativeOne $c0)]>;
8686

8787
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
9090
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
9191
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
9292
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
93-
(DefOverflow))>;
93+
DefOverflow)>;
9494

9595
//===----------------------------------------------------------------------===//
9696
// AddUIExtendedOp
@@ -100,7 +100,7 @@ def MulIMulIConstant :
100100
// uses. Since the 'overflow' result is unused, any replacement value will do.
101101
def AddUIExtendedToAddI:
102102
Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
103-
[(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
103+
[(Arith_AddIOp $x, $y, DefOverflow), (replaceWithValue $x)],
104104
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
105105

106106
//===----------------------------------------------------------------------===//
@@ -113,52 +113,52 @@ def SubIRHSAddConstant :
113113
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
114114
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
115115
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
116-
(DefOverflow))>;
116+
DefOverflow)>;
117117

118118
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
119119
def SubILHSAddConstant :
120120
Pat<(Arith_SubIOp:$res
121121
(ConstantLikeMatcher APIntAttr:$c1),
122122
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
123123
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
124-
(DefOverflow))>;
124+
DefOverflow)>;
125125

126126
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
127127
def SubIRHSSubConstantRHS :
128128
Pat<(Arith_SubIOp:$res
129129
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
130130
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
131131
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
132-
(DefOverflow))>;
132+
DefOverflow)>;
133133

134134
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
135135
def SubIRHSSubConstantLHS :
136136
Pat<(Arith_SubIOp:$res
137137
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
138138
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
139139
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
140-
(DefOverflow))>;
140+
DefOverflow)>;
141141

142142
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
143143
def SubILHSSubConstantRHS :
144144
Pat<(Arith_SubIOp:$res
145145
(ConstantLikeMatcher APIntAttr:$c1),
146146
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
147147
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
148-
(DefOverflow))>;
148+
DefOverflow)>;
149149

150150
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
151151
def SubILHSSubConstantLHS :
152152
Pat<(Arith_SubIOp:$res
153153
(ConstantLikeMatcher APIntAttr:$c1),
154154
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
155155
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
156-
(DefOverflow))>;
156+
DefOverflow)>;
157157

158158
// subi(subi(a, b), a) -> subi(0, b)
159159
def SubISubILHSRHSLHS :
160160
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
161-
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
161+
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
162162

163163
//===----------------------------------------------------------------------===//
164164
// MulSIExtendedOp
@@ -168,7 +168,7 @@ def SubISubILHSRHSLHS :
168168
// Since the `high` result it not used, any replacement value will do.
169169
def MulSIExtendedToMulI :
170170
Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
171-
[(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
171+
[(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
172172
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
173173

174174

@@ -182,9 +182,9 @@ def MulSIExtendedRHSOne :
182182
Pattern<(Arith_MulSIExtendedOp $x, (ConstantLikeMatcher AnyAttr:$c1)),
183183
[(replaceWithValue $x),
184184
(Arith_ExtSIOp(Arith_CmpIOp
185-
(NativeCodeCall<"arith::CmpIPredicate::slt">),
186-
$x,
187-
(Arith_ConstantOp (GetZeroAttr $x))))],
185+
ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">,
186+
$x,
187+
(Arith_ConstantOp (GetZeroAttr $x))))],
188188
[(IsScalarOrSplatOne $c1)]>;
189189

190190
//===----------------------------------------------------------------------===//
@@ -195,7 +195,7 @@ def MulSIExtendedRHSOne :
195195
// Since the `high` result it not used, any replacement value will do.
196196
def MulUIExtendedToMulI :
197197
Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
198-
[(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
198+
[(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
199199
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
200200

201201
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,6 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
6161
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
6262
}
6363

64-
static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
65-
return IntegerOverflowFlagsAttr::get(builder.getContext(),
66-
IntegerOverflowFlags::none);
67-
}
68-
6964
/// Invert an integer comparison predicate.
7065
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
7166
switch (pred) {

mlir/test/IR/enum-attr-roundtrip.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,12 @@ func.func @test_match_op_with_enum() -> () {
2626
test.op_with_enum first tag 0 : i32
2727
return
2828
}
29+
30+
// CHECK-LABEL: @test_match_op_with_bit_enum
31+
func.func @test_match_op_with_bit_enum() -> () {
32+
// CHECK: test.op_with_bit_enum <write> tag 0 : i32
33+
test.op_with_bit_enum <write> tag 0 : i32
34+
// CHECK: test.op_with_bit_enum <read, execute> tag 1 : i32
35+
test.op_with_bit_enum <execute, write> tag 0 : i32
36+
return
37+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,9 @@ def OpWithEnum : TEST_Op<"op_with_enum"> {
396396
}
397397

398398
// Define a pattern that matches and creates an enum attribute.
399-
def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
400-
"::test::TestEnum::First">:$value,
399+
def : Pat<(OpWithEnum ConstantEnumCase<TestEnumAttr, "first">:$value,
401400
ConstantAttr<I32Attr, "0">:$tag),
402-
(OpWithEnum ConstantAttr<TestEnumAttr,
403-
"::test::TestEnum::Second">,
401+
(OpWithEnum ConstantEnumCase<TestEnumAttr, "second">,
404402
ConstantAttr<I32Attr, "1">)>;
405403

406404
//===----------------------------------------------------------------------===//
@@ -430,6 +428,12 @@ def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
430428
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
431429
}
432430

431+
// Define a pattern that matches and creates a bit enum attribute.
432+
def : Pat<(OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "write|execute">,
433+
ConstantAttr<I32Attr, "0">),
434+
(OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "execute|read">,
435+
ConstantAttr<I32Attr, "1">)>;
436+
433437
//===----------------------------------------------------------------------===//
434438
// Test Regions
435439
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)