Skip to content

Commit 0a55131

Browse files
feat: elide attribute type with new StaticallyTypedAttrInterface (#59)
This allows to omit the type of an attribute from the assembly if it can be inferred from the attribute value. For example, the type is redundant in `#substrait.timestamp<100us> : !substrait.timestamp`. The built-in `TypedAttrInterface`, however, forces the appearance of the type in the assembly. The new interface is almost identical but does not enforce it. The PR also makes the two timestamp, the date, and the time attributes implement that interface. Signed-off-by: Ingo Müller <[email protected]>
1 parent 79a6132 commit 0a55131

File tree

9 files changed

+95
-51
lines changed

9 files changed

+95
-51
lines changed

include/substrait-mlir/Dialect/Substrait/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ add_dependencies(MLIRSubstraitDialect MLIRSubstraitAttrsIncGen)
1818

1919
# Add interfaces.
2020
set(LLVM_TARGET_DEFINITIONS SubstraitInterfaces.td)
21+
mlir_tablegen(SubstraitAttrInterfaces.h.inc -gen-attr-interface-decls)
22+
mlir_tablegen(SubstraitAttrInterfaces.cpp.inc -gen-attr-interface-defs)
2123
mlir_tablegen(SubstraitOpInterfaces.h.inc -gen-op-interface-decls)
2224
mlir_tablegen(SubstraitOpInterfaces.cpp.inc -gen-op-interface-defs)
2325
mlir_tablegen(SubstraitTypeInterfaces.h.inc -gen-type-interface-decls)

include/substrait-mlir/Dialect/Substrait/IR/Substrait.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export
2121

22+
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.h.inc" // IWYU: export
2223
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
2324
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
2425

@@ -31,4 +32,13 @@
3132
#define GET_OP_CLASSES
3233
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export
3334

35+
namespace mlir::substrait {
36+
37+
/// Returns the `Type` of the attribute through the `TypedAttrInterface` or the
38+
/// `TypeInferableAttrInterface`. Returns an empty `Type` if the given attribute
39+
/// does not implement one of the two interfaces.
40+
Type getAttrType(Attribute attr);
41+
42+
} // namespace mlir::substrait
43+
3444
#endif // SUBSTRAIT_MLIR_DIALECT_SUBSTRAIT_IR_SUBSTRAIT_H

include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,24 @@ include "mlir/IR/AttrTypeBase.td"
1515
include "mlir/IR/BuiltinAttributeInterfaces.td"
1616

1717
// Base class for Substrait dialect attribute types.
18-
class Substrait_Attr<string name, string typeMnemonic, list<Trait> traits = []>
18+
class Substrait_Attr<string name, string attrMnemonic, list<Trait> traits = []>
1919
: AttrDef<Substrait_Dialect, name, traits> {
20-
let mnemonic = typeMnemonic;
20+
let mnemonic = attrMnemonic;
21+
}
22+
23+
// Base class for Substrait dialect attribute types that have a statically known
24+
// value type.
25+
class Substrait_StaticallyTypedAttr<string name, string attrMnemonic,
26+
string typeName, list<Trait> traits = []>
27+
: Substrait_Attr<
28+
name, attrMnemonic,
29+
traits#[DeclareAttrInterfaceMethods<TypeInferableAttrInterface>]> {
30+
let extraClassDeclaration = [{
31+
/// Implement TypeInferableAttrInterface.
32+
::mlir::Type getType() {
33+
return ::mlir::substrait::}]#typeName#[{::get(getContext());
34+
}
35+
}];
2136
}
2237

2338
def Substrait_AdvancedExtensionAttr
@@ -34,64 +49,45 @@ def Substrait_AdvancedExtensionAttr
3449
let genVerifyDecl = 1;
3550
}
3651

37-
def Substrait_DateAttr : Substrait_Attr<"Date", "date",
38-
[TypedAttrInterface]> {
52+
def Substrait_DateAttr
53+
: Substrait_StaticallyTypedAttr<"Date", "date", "DateType"> {
3954
let summary = "Substrait date type";
4055
let description = [{
4156
This type represents a substrait date attribute type.
4257
}];
4358
let parameters = (ins "int32_t":$value);
4459
let assemblyFormat = [{ `<` $value `>` }];
45-
let extraClassDeclaration = [{
46-
::mlir::Type getType() const {
47-
return DateType::get(getContext());
48-
}
49-
}];
5060
}
5161

52-
def Substrait_TimeAttr : Substrait_Attr<"Time", "time",
53-
[TypedAttrInterface]> {
62+
def Substrait_TimeAttr
63+
: Substrait_StaticallyTypedAttr<"Time", "time", "TimeType"> {
5464
let summary = "Substrait time type";
5565
let description = [{
5666
This type represents a substrait time attribute type.
5767
}];
5868
let parameters = (ins "int64_t":$value);
5969
let assemblyFormat = [{ `<` $value `` `us` `>` }];
60-
let extraClassDeclaration = [{
61-
::mlir::Type getType() const {
62-
return TimeType::get(getContext());
63-
}
64-
}];
6570
}
6671

67-
def Substrait_TimestampAttr : Substrait_Attr<"Timestamp", "timestamp",
68-
[TypedAttrInterface]> {
72+
def Substrait_TimestampAttr
73+
: Substrait_StaticallyTypedAttr<"Timestamp", "timestamp", "TimestampType"> {
6974
let summary = "Substrait timezone-unaware timestamp type";
7075
let description = [{
7176
This type represents a substrait timezone-unaware timestamp attribute type.
7277
}];
7378
let parameters = (ins "int64_t":$value);
7479
let assemblyFormat = [{ `<` $value `` `us` `>` }];
75-
let extraClassDeclaration = [{
76-
::mlir::Type getType() const {
77-
return TimestampType::get(getContext());
78-
}
79-
}];
8080
}
8181

82-
def Substrait_TimestampTzAttr : Substrait_Attr<"TimestampTz", "timestamp_tz",
83-
[TypedAttrInterface]> {
82+
def Substrait_TimestampTzAttr
83+
: Substrait_StaticallyTypedAttr<"TimestampTz", "timestamp_tz",
84+
"TimestampTzType"> {
8485
let summary = "Substrait timezone-aware timestamp type";
8586
let description = [{
8687
This type represents a substrait timezone-aware timestamp attribute type.
8788
}];
8889
let parameters = (ins "int64_t":$value);
8990
let assemblyFormat = [{ `<` $value `` `us` `>` }];
90-
let extraClassDeclaration = [{
91-
::mlir::Type getType() const {
92-
return TimestampTzType::get(getContext());
93-
}
94-
}];
9591
}
9692

9793
/// Attributes of currently supported atomic types, listed in order of substrait

include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,26 @@
1111

1212
include "mlir/IR/OpBase.td"
1313

14+
def TypeInferableAttrInterface : AttrInterface<"TypeInferableAttrInterface"> {
15+
let cppNamespace = "::mlir::substrait";
16+
let description = [{
17+
This interface is used for attributes that have a type that can be inferred
18+
from the instance of the attribute. It is similar to the built-in
19+
`TypedAttrInterface` in that that type is understood to represent the type
20+
of the data contained in the attribute. However, it is different in that
21+
`TypedAttrInterface` is typically used for cases where the type is a
22+
parameter of the attribute such that there can be attribute instances with
23+
the same value but different types. With this interface, the type must be
24+
inferable from the value such that two instances with the same value always
25+
have the same type. Crucially, this allows to elide the type in the assembly
26+
format of the attribute.
27+
}];
28+
let methods = [InterfaceMethod<
29+
"Get the attribute's type",
30+
"::mlir::Type", "getType"
31+
>];
32+
}
33+
1434
def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
1535
let description = [{
1636
Interface for any expression in a Substrait plan. This corresponds to an

include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITTYPES
1111

1212
include "substrait-mlir/Dialect/Substrait/IR/SubstraitDialect.td"
13+
include "substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td"
1314
include "mlir/IR/CommonTypeConstraints.td"
1415

1516
// Base class for Substrait dialect types.

lib/Dialect/Substrait/IR/Substrait.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
#include "substrait-mlir/Dialect/Substrait/IR/Substrait.h"
1010

11-
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12-
#include "mlir/IR/DialectImplementation.h"
11+
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
1312
#include "mlir/Support/LogicalResult.h"
1413
#include "llvm/ADT/SmallSet.h"
15-
#include "llvm/ADT/TypeSwitch.h"
14+
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
1615

1716
using namespace mlir;
1817
using namespace mlir::substrait;
@@ -38,6 +37,22 @@ void SubstraitDialect::initialize() {
3837
>();
3938
}
4039

40+
//===----------------------------------------------------------------------===//
41+
// Free functions
42+
//===----------------------------------------------------------------------===//
43+
44+
namespace mlir::substrait {
45+
46+
Type getAttrType(Attribute attr) {
47+
if (auto typedAttr = mlir::dyn_cast<TypedAttr>(attr))
48+
return typedAttr.getType();
49+
if (auto typedAttr = mlir::dyn_cast<TypeInferableAttrInterface>(attr))
50+
return typedAttr.getType();
51+
return Type();
52+
}
53+
54+
} // namespace mlir::substrait
55+
4156
//===----------------------------------------------------------------------===//
4257
// Substrait attributes
4358
//===----------------------------------------------------------------------===//
@@ -62,6 +77,7 @@ LogicalResult AdvancedExtensionAttr::verify(
6277
// Substrait interfaces
6378
//===----------------------------------------------------------------------===//
6479

80+
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.cpp.inc"
6581
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.cpp.inc"
6682
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.cpp.inc"
6783

@@ -297,15 +313,14 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
297313
OpaqueProperties properties, RegionRange regions,
298314
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
299315
auto *typedProperties = properties.as<Properties *>();
316+
Attribute valueAttr = typedProperties->getValue();
300317

301-
auto attr = llvm::dyn_cast<TypedAttr>(typedProperties->getValue());
302-
if (!attr)
318+
Type resultType = getAttrType(valueAttr);
319+
if (!resultType)
303320
return emitOptionalError(loc, "unsuited attribute for literal value: ",
304321
typedProperties->getValue());
305322

306-
Type resultType = attr.getType();
307323
inferredReturnTypes.emplace_back(resultType);
308-
309324
return success();
310325
}
311326

lib/Target/SubstraitPB/Export.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,8 @@ SubstraitExporter::exportOperation(FilterOp op) {
626626
FailureOr<std::unique_ptr<Expression>>
627627
SubstraitExporter::exportOperation(LiteralOp op) {
628628
// Build `Literal` message depending on type.
629-
auto value = llvm::cast<TypedAttr>(op.getValue());
630-
mlir::Type literalType = value.getType();
629+
Attribute value = op.getValue();
630+
mlir::Type literalType = getAttrType(value);
631631
auto literal = std::make_unique<Expression::Literal>();
632632

633633
// `IntegerType`s.

test/Dialect/Substrait/literal.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
// CHECK: %[[V0:.*]] = named_table
77
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.time> {
88
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
9-
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us> : !substrait.time
9+
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us>{{$}}
1010
// CHECK-NEXT: yield %[[V2]] : !substrait.time
1111
// CHECK-NEXT: }
12-
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>
12+
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>
1313

1414
substrait.plan version 0 : 42 : 1 {
1515
relation {
@@ -19,7 +19,7 @@ substrait.plan version 0 : 42 : 1 {
1919
%time = literal #substrait.time<200000000us> : !substrait.time
2020
yield %time : !substrait.time
2121
}
22-
yield %1 : tuple<si1, !substrait.time>
22+
yield %1 : tuple<si1, !substrait.time>
2323
}
2424
}
2525

@@ -30,7 +30,7 @@ substrait.plan version 0 : 42 : 1 {
3030
// CHECK: %[[V0:.*]] = named_table
3131
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.date> {
3232
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
33-
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000> : !substrait.date
33+
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000>{{$}}
3434
// CHECK-NEXT: yield %[[V2]] : !substrait.date
3535
// CHECK-NEXT: }
3636
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.date>
@@ -43,7 +43,7 @@ substrait.plan version 0 : 42 : 1 {
4343
%date = literal #substrait.date<200000000> : !substrait.date
4444
yield %date : !substrait.date
4545
}
46-
yield %1 : tuple<si1, !substrait.date>
46+
yield %1 : tuple<si1, !substrait.date>
4747
}
4848
}
4949

@@ -54,8 +54,8 @@ substrait.plan version 0 : 42 : 1 {
5454
// CHECK: %[[V0:.*]] = named_table
5555
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
5656
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
57-
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
58-
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
57+
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>{{$}}
58+
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>{{$}}
5959
// CHECK-NEXT: yield %[[V2]], %[[V3]] : !substrait.timestamp, !substrait.timestamp_tz
6060
// CHECK-NEXT: }
6161
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.timestamp, !substrait.timestamp_tz>
@@ -65,7 +65,7 @@ substrait.plan version 0 : 42 : 1 {
6565
%0 = named_table @t1 as ["a"] : tuple<si1>
6666
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
6767
^bb0(%arg : tuple<si1>):
68-
%timestamp = literal #substrait.timestamp<10000000000us>
68+
%timestamp = literal #substrait.timestamp<10000000000us>
6969
%timestamp_tz = literal #substrait.timestamp_tz<10000000000us>
7070
yield %timestamp, %timestamp_tz : !substrait.timestamp, !substrait.timestamp_tz
7171
}

test/Target/SubstraitPB/Import/literal.textpb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# CHECK: %[[V0:.*]] = named_table
1616
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.time> {
1717
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
18-
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us> : !substrait.time
18+
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us>
1919
# CHECK-NEXT: yield %[[V2]] : !substrait.time
2020
# CHECK-NEXT: }
2121
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>
@@ -69,7 +69,7 @@ version {
6969
# CHECK: %[[V0:.*]] = named_table
7070
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.date> {
7171
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
72-
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000> : !substrait.date
72+
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000>
7373
# CHECK-NEXT: yield %[[V2]] : !substrait.date
7474
# CHECK-NEXT: }
7575
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.date>
@@ -123,8 +123,8 @@ version {
123123
# CHECK: %[[V0:.*]] = named_table
124124
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
125125
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
126-
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
127-
# CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
126+
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
127+
# CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
128128
# CHECK-NEXT: yield %[[V2]], %[[V3]] : !substrait.timestamp, !substrait.timestamp_tz
129129
# CHECK-NEXT: }
130130
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.timestamp, !substrait.timestamp_tz>

0 commit comments

Comments
 (0)