Skip to content

Commit 85be57d

Browse files
feat: factor out advanced_extension logic and add project op (#65)
This PR factors out the handling of the `advanced_extension` field from the `plan` op and adds that logic to the `project` op. This mainly consisted of moving the import and export logic from the functions related to the `plan` op to dedicated functions. The PR also introduces the new `ExtensibleOpInterface`, which enforces an attribute called `advanced_extension` on the ops that implement it and allows dealing with all such ops transparently. Since that interface depends on an attribute, the include order of the generated code of interfaces and attributes also had to be adapted. Unfortunately, the field names in the Substrait spec also vary (singular or plural), so the PR also introduces some template magic to be able to deal with protobuf message types with both spellings. With this PR, message types with an `advanced_extension` field should be able to support it by (1) adding the `ExtensibleOpInterface` to their traits, (2) adding an `advanced_extension` parameter, and (3) adding that parameter to their assembly format (although that's technically optional; otherwise, the attribute is set through the `attributes` dictionary).
1 parent e1d3751 commit 85be57d

File tree

9 files changed

+308
-54
lines changed

9 files changed

+308
-54
lines changed

include/substrait-mlir/Dialect/Substrait/IR/Substrait.h

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,39 @@
1515
#include "mlir/IR/SymbolTable.h" // IWYU: keep
1616
#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU: keep
1717

18-
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export
18+
//===----------------------------------------------------------------------===//
19+
// Substrait dialect
20+
//===----------------------------------------------------------------------===//
1921

2022
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export
2123

22-
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.h.inc" // IWYU: export
23-
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
24-
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
24+
//===----------------------------------------------------------------------===//
25+
// Substrait enums
26+
//===----------------------------------------------------------------------===//
27+
28+
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export
29+
30+
//===----------------------------------------------------------------------===//
31+
// Substrait types
32+
//===----------------------------------------------------------------------===//
2533

34+
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
2635
#define GET_TYPEDEF_CLASSES
2736
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsTypes.h.inc" // IWYU: export
2837

38+
//===----------------------------------------------------------------------===//
39+
// Substrait attributes
40+
//===----------------------------------------------------------------------===//
41+
42+
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.h.inc" // IWYU: export
2943
#define GET_ATTRDEF_CLASSES
3044
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsAttrs.h.inc" // IWYU: export
3145

46+
//===----------------------------------------------------------------------===//
47+
// Substrait ops
48+
//===----------------------------------------------------------------------===//
49+
50+
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
3251
#define GET_OP_CLASSES
3352
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export
3453

include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
4141
let cppNamespace = "::mlir::substrait";
4242
}
4343

44+
def Substrait_ExtensibleOpInterface : OpInterface<"ExtensibleOpInterface"> {
45+
let description = [{
46+
Interface for ops with the `advanced_extension` attribute. Several relations
47+
and other message types of the Substrait specification have a field with the
48+
same name (or the variant `advanced_extensions`, which has the same meaning)
49+
and the interface enables handling all of them transparently.
50+
}];
51+
let cppNamespace = "::mlir::substrait";
52+
let methods = [
53+
InterfaceMethod<
54+
"Get the `advanced_extension` attribute",
55+
"std::optional<::mlir::substrait::AdvancedExtensionAttr>",
56+
"getAdvancedExtension">,
57+
InterfaceMethod<
58+
"Set the `advanced_extension` attribute",
59+
"void", "setAdvancedExtensionAttr",
60+
(ins "::mlir::substrait::AdvancedExtensionAttr":$attr)>,
61+
];
62+
}
63+
4464
def Substrait_RelOpInterface : OpInterface<"RelOpInterface"> {
4565
let description = [{
4666
Interface for any relational operation in a Substrait plan. This corresponds

include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def PlanBodyOp : AnyOf<[
151151

152152
def Substrait_PlanOp : Substrait_Op<"plan", [
153153
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
154+
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
154155
NoTerminator, NoRegionArguments, SingleBlock, SymbolTable
155156
]> {
156157
let summary = "Represents a Substrait plan";
@@ -180,9 +181,13 @@ def Substrait_PlanOp : Substrait_Op<"plan", [
180181
let builders = [
181182
OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{
182183
build($_builder, $_state, major, minor, patch,
183-
/*git_hash=*/StringAttr(), /*producer*/StringAttr(),
184-
/*advanced_extension=*/AdvancedExtensionAttr(),
185-
/*expected_type_urls=*/ArrayAttr());
184+
/*git_hash=*/StringAttr(), /*producer*/StringAttr());
185+
}]>,
186+
OpBuilder<
187+
(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch,
188+
"::llvm::StringRef":$git_hash, "::llvm::StringRef":$producer), [{
189+
build($_builder, $_state, major, minor, patch, git_hash, producer,
190+
/*advanced_extension=*/AdvancedExtensionAttr());
186191
}]>,
187192
OpBuilder<
188193
(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch,
@@ -537,6 +542,7 @@ def Substrait_NamedTableOp : Substrait_RelOp<"named_table", [
537542

538543
def Substrait_ProjectOp : Substrait_RelOp<"project", [
539544
SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
545+
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
540546
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>
541547
]> {
542548
let summary = "Project operation";
@@ -561,14 +567,18 @@ def Substrait_ProjectOp : Substrait_RelOp<"project", [
561567
}
562568
```
563569
}];
564-
let arguments = (ins Substrait_Relation:$input);
570+
let arguments = (ins
571+
Substrait_Relation:$input,
572+
OptionalAttr<Substrait_AdvancedExtensionAttr>:$advanced_extension
573+
);
565574
let regions = (region AnyRegion:$expressions);
566575
let results = (outs Substrait_Relation:$result);
567576
// TODO(ingomueller): We could elide/shorten the block argument from the
568577
// assembly by writing custom printers/parsers similar to
569578
// `scf.for` etc.
570579
let assemblyFormat = [{
571-
$input attr-dict `:` type($input) `->` type($result) $expressions
580+
$input (`advanced_extension` `` $advanced_extension^)?
581+
attr-dict `:` type($input) `->` type($result) $expressions
572582
}];
573583
let hasRegionVerifier = 1;
574584
let hasFolder = 1;

lib/Target/SubstraitPB/Export.cpp

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
using namespace mlir;
2626
using namespace mlir::substrait;
27+
using namespace mlir::substrait::protobuf_utils;
2728
using namespace ::substrait;
2829
using namespace ::substrait::proto;
2930

@@ -60,6 +61,8 @@ class SubstraitExporter {
6061
DECLARE_EXPORT_FUNC(RelOpInterface, Rel)
6162
DECLARE_EXPORT_FUNC(SetOp, Rel)
6263

64+
template <typename MessageType>
65+
void exportAdvancedExtension(ExtensibleOpInterface op, MessageType &message);
6366
std::unique_ptr<pb::Any> exportAny(StringAttr attr);
6467
FailureOr<std::unique_ptr<pb::Message>> exportOperation(Operation *op);
6568
FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
@@ -91,6 +94,36 @@ class SubstraitExporter {
9194
std::unique_ptr<SymbolTable> symbolTable; // Symbol table cache.
9295
};
9396

97+
template <typename MessageType>
98+
void SubstraitExporter::exportAdvancedExtension(ExtensibleOpInterface op,
99+
MessageType &message) {
100+
if (!op.getAdvancedExtension())
101+
return;
102+
103+
// Build the base `AdvancedExtension` message.
104+
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
105+
auto extension = std::make_unique<extensions::AdvancedExtension>();
106+
107+
StringAttr optimizationAttr = extensionAttr.getOptimization();
108+
StringAttr enhancementAttr = extensionAttr.getEnhancement();
109+
110+
// Set `optimization` field if present.
111+
if (optimizationAttr) {
112+
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
113+
extension->set_allocated_optimization(optimization.release());
114+
}
115+
116+
// Set `enhancement` field if present.
117+
if (enhancementAttr) {
118+
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
119+
extension->set_allocated_enhancement(enhancement.release());
120+
}
121+
122+
// Set the `advanced_extension` field in the provided message.
123+
using Trait = advanced_extension_trait<MessageType>;
124+
Trait::set_allocated_advanced_extension(message, extension.release());
125+
}
126+
94127
std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
95128
auto any = std::make_unique<pb::Any>();
96129
auto anyType = mlir::cast<AnyType>(attr.getType());
@@ -874,26 +907,8 @@ FailureOr<std::unique_ptr<Plan>> SubstraitExporter::exportOperation(PlanOp op) {
874907
version->set_git_hash(op.getGitHash().str());
875908
plan->set_allocated_version(version.release());
876909

877-
// Build `AdvancedExtension` message.
878-
if (op.getAdvancedExtension()) {
879-
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
880-
auto extension = std::make_unique<extensions::AdvancedExtension>();
881-
882-
StringAttr optimizationAttr = extensionAttr.getOptimization();
883-
StringAttr enhancementAttr = extensionAttr.getEnhancement();
884-
885-
if (optimizationAttr) {
886-
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
887-
extension->set_allocated_optimization(optimization.release());
888-
}
889-
890-
if (enhancementAttr) {
891-
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
892-
extension->set_allocated_enhancement(enhancement.release());
893-
}
894-
895-
plan->set_allocated_advanced_extensions(extension.release());
896-
}
910+
// Attach the `AdvancedExtension` message if the attribute exists.
911+
exportAdvancedExtension(op, *plan);
897912

898913
// Add `expected_type_urls` to plan if present.
899914
if (op.getExpectedTypeUrls()) {
@@ -1031,6 +1046,9 @@ SubstraitExporter::exportOperation(ProjectOp op) {
10311046
*projectRel->add_expressions() = *expression.value();
10321047
}
10331048

1049+
// Attach the `AdvancedExtension` message if the attribute exists.
1050+
exportAdvancedExtension(op, *projectRel);
1051+
10341052
// Build `Rel` message.
10351053
auto rel = std::make_unique<Rel>();
10361054
rel->set_allocated_project(projectRel.release());

lib/Target/SubstraitPB/Import.cpp

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
using namespace mlir;
2727
using namespace mlir::substrait;
28+
using namespace mlir::substrait::protobuf_utils;
2829
using namespace ::substrait;
2930
using namespace ::substrait::proto;
3031

@@ -64,6 +65,46 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
6465
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
6566
DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp)
6667

68+
/// If present, imports the `advanced_extension` or `advanced_extensions` field
69+
/// from the given message and sets the obtained attribute on the given op.
70+
template <typename MessageType>
71+
void importAdvancedExtension(ImplicitLocOpBuilder builder,
72+
ExtensibleOpInterface op,
73+
const MessageType &message);
74+
75+
template <typename MessageType>
76+
void importAdvancedExtension(ImplicitLocOpBuilder builder,
77+
ExtensibleOpInterface op,
78+
const MessageType &message) {
79+
using Trait = advanced_extension_trait<MessageType>;
80+
if (!Trait::has_advanced_extension(message))
81+
return;
82+
83+
// Get the `advanced_extension(s)` field.
84+
const extensions::AdvancedExtension &advancedExtension =
85+
Trait::advanced_extension(message);
86+
87+
// Import `optimization` field if present.
88+
StringAttr optimizationAttr;
89+
if (advancedExtension.has_optimization()) {
90+
const pb::Any &optimization = advancedExtension.optimization();
91+
optimizationAttr = importAny(builder, optimization).value();
92+
}
93+
94+
// Import `enhancement` field if present.
95+
StringAttr enhancementAttr;
96+
if (advancedExtension.has_enhancement()) {
97+
const pb::Any &enhancement = advancedExtension.enhancement();
98+
enhancementAttr = importAny(builder, enhancement).value();
99+
}
100+
101+
// Build attribute and set it on the op.
102+
MLIRContext *context = builder.getContext();
103+
auto advancedExtensionAttr =
104+
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
105+
op.setAdvancedExtensionAttr(advancedExtensionAttr);
106+
}
107+
67108
FailureOr<StringAttr> importAny(ImplicitLocOpBuilder builder,
68109
const pb::Any &message) {
69110
MLIRContext *context = builder.getContext();
@@ -494,32 +535,10 @@ static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
494535
// Import version.
495536
const Version &version = message.version();
496537

497-
// Import advanced extension.
498-
AdvancedExtensionAttr advancedExtensionAttr;
499-
if (message.has_advanced_extensions()) {
500-
const extensions::AdvancedExtension &advancedExtension =
501-
message.advanced_extensions();
502-
503-
StringAttr optimizationAttr;
504-
if (advancedExtension.has_optimization()) {
505-
const pb::Any &optimization = advancedExtension.optimization();
506-
optimizationAttr = importAny(builder, optimization).value();
507-
}
508-
509-
StringAttr enhancementAttr;
510-
if (advancedExtension.has_enhancement()) {
511-
const pb::Any &enhancement = advancedExtension.enhancement();
512-
enhancementAttr = importAny(builder, enhancement).value();
513-
}
514-
515-
advancedExtensionAttr =
516-
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
517-
}
518-
519538
// Build `PlanOp`.
520539
auto planOp = builder.create<PlanOp>(
521540
version.major_number(), version.minor_number(), version.patch_number(),
522-
version.git_hash(), version.producer(), advancedExtensionAttr);
541+
version.git_hash(), version.producer());
523542
planOp.getBody().push_back(new Block());
524543

525544
// Import `expected_type_urls` if present.
@@ -531,6 +550,9 @@ static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
531550
planOp.setExpectedTypeUrlsAttr(ArrayAttr::get(context, expected_type_urls));
532551
}
533552

553+
// Import advanced extension if it is present.
554+
importAdvancedExtension(builder, planOp, message);
555+
534556
OpBuilder::InsertionGuard insertGuard(builder);
535557
builder.setInsertionPointToEnd(&planOp.getBody().front());
536558

@@ -691,6 +713,9 @@ static mlir::FailureOr<ProjectOp> importProjectRel(ImplicitLocOpBuilder builder,
691713
builder.create<ProjectOp>(resultType, inputOp.value()->getResult(0));
692714
projectOp.getExpressions().push_back(conditionBlock.release());
693715

716+
// Import advanced extension if it is present.
717+
importAdvancedExtension(builder, projectOp, projectRel);
718+
694719
return projectOp;
695720
}
696721

lib/Target/SubstraitPB/ProtobufUtils.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H
1010
#define LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H
1111

12+
#include <type_traits>
13+
1214
#include "mlir/IR/Location.h"
1315

1416
namespace substrait::proto {
@@ -28,6 +30,59 @@ getCommon(const ::substrait::proto::Rel &rel, Location loc);
2830
FailureOr<::substrait::proto::RelCommon *>
2931
getMutableCommon(::substrait::proto::Rel *rel, Location loc);
3032

33+
/// SFINAE-based template that checks if the given (message) type has a field
34+
/// called `advanced_extensions`: the `value` member is `true` iff it does. This
35+
/// is useful to deal with the two different names, `advanced_extension` and
36+
/// `advanced_extensions`, that are used for the same thing across different
37+
/// message types in the Substrait spec.
38+
template <typename T>
39+
class has_advanced_extensions {
40+
template <typename C>
41+
static std::true_type test(decltype(&C::advanced_extensions));
42+
template <typename C>
43+
static std::false_type test(...);
44+
45+
public:
46+
static constexpr bool value = decltype(test<T>(0))::value;
47+
};
48+
49+
/// Trait class for accessing the `advanced_extension` field. The default
50+
/// instance is automatically used for message types that call this field
51+
/// `advanced_extension`; the specialization below is automatically used for
52+
/// message types that call it `advanced_extensions`.
53+
template <typename T, typename = void>
54+
struct advanced_extension_trait {
55+
static auto has_advanced_extension(const T &message) {
56+
return message.has_advanced_extension();
57+
}
58+
static auto advanced_extension(const T &message) {
59+
return message.advanced_extension();
60+
}
61+
template <typename S>
62+
static auto set_allocated_advanced_extension(T &message,
63+
S &&advanced_extensions) {
64+
message.set_allocated_advanced_extension(
65+
std::forward<S>(advanced_extensions));
66+
}
67+
};
68+
69+
template <typename T>
70+
struct advanced_extension_trait<
71+
T, std::enable_if_t<has_advanced_extensions<T>::value>> {
72+
static auto has_advanced_extension(const T &message) {
73+
return message.has_advanced_extensions();
74+
}
75+
static auto advanced_extension(const T &message) {
76+
return message.advanced_extensions();
77+
}
78+
template <typename S>
79+
static auto set_allocated_advanced_extension(T &message,
80+
S &&advanced_extensions) {
81+
message.set_allocated_advanced_extensions(
82+
std::forward<S>(advanced_extensions));
83+
}
84+
};
85+
3186
} // namespace mlir::substrait::protobuf_utils
3287

3388
#endif // LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H

0 commit comments

Comments
 (0)