Skip to content

Commit 411d3c3

Browse files
feat: add aggregate op and extend functions accordingly (#70)
* feat: add `aggregate` op and extend functions accordingly This PR adds support for the `AggregateRel` from the Substrait spec in the form of the `aggregate` op. This is arguably the most complex op implemented so far. It has an optional enum argument that requires custom parsing, several optional regions that require custom parsing, an attribute that depends on the presence and contents of the regions and requires custom parsing to omit it in the common case, and return types that depend on the two regions and the attribute. What's more, the current version of the spec is such that it is almost impossibly to interpret "grouping sets" because it relies on protobuf message equality, which is something can protobuf does not offer. The current implementation, thus, implements a best effort by using op equality instead (but needs to run CSE during export to ensure op uniqueness). The PR also extends the `call` op to represent also `AggregateFunction` messages (in addition to `ScalarFunction` messages), which are used by the new `aggregate` op. Finally, the PR replaces some usages of `UnknownLoc` with the location from the `ImplicitOpLocBuilder` wherever one is available (including places there aren't otherwise affected by this PR). Signed-off-by: Ingo Müller <[email protected]>
1 parent 85be57d commit 411d3c3

File tree

13 files changed

+2031
-57
lines changed

13 files changed

+2031
-57
lines changed

include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ class Substrait_StaticallyTypedAttr<string name, string attrMnemonic,
3535
}];
3636
}
3737

38+
//===----------------------------------------------------------------------===//
39+
// Substrait attributes
40+
//===----------------------------------------------------------------------===//
41+
3842
def Substrait_AdvancedExtensionAttr
3943
: Substrait_Attr<"AdvancedExtension", "advanced_extension"> {
4044
let summary = "Represents the `AdvancedExtenssion` message of Substrait";
@@ -90,6 +94,10 @@ def Substrait_TimestampTzAttr
9094
let assemblyFormat = [{ `<` $value `` `us` `>` }];
9195
}
9296

97+
//===----------------------------------------------------------------------===//
98+
// Helpers and constraints
99+
//===----------------------------------------------------------------------===//
100+
93101
/// Attributes of currently supported atomic types, listed in order of substrait
94102
/// specification.
95103
def Substrait_AtomicAttributes {
@@ -113,4 +121,14 @@ def Substrait_AtomicAttributes {
113121
/// Attribute of one of the currently supported atomic types.
114122
def Substrait_AtomicAttribute : AnyAttrOf<Substrait_AtomicAttributes.attrs>;
115123

124+
/// `ArrayAttr` of `ArrayAttr`s if `i64`s.
125+
def I64ArrayArrayAttr : TypedArrayAttrBase<
126+
I64ArrayAttr, "64-bit integer array array attribute"
127+
>;
128+
129+
/// `ArrayAttr` of `ArrayAttr`s if `i64`s with at least one element.
130+
def NonEmptyI64ArrayArrayAttr :
131+
ConfinedAttr<I64ArrayArrayAttr, [ArrayMinCount<1>]>;
132+
133+
116134
#endif // SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITATTRS

include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,20 @@
1111

1212
include "mlir/IR/EnumAttr.td"
1313

14-
def AggregationInvocationUnspecified: I32EnumAttrCase<"unspecified", 0>;
15-
def AggregationInvocationAll: I32EnumAttrCase<"all", 1>;
16-
def AggregationInvocationDistinct: I32EnumAttrCase<"distinct", 2>;
14+
/// Represents the `AggregationInvocation` protobuf enum.
15+
//
16+
/// The enum values correspond exactly to those in the `JoinRel.JoinType` enum,
17+
/// i.e., conversion through integers is possible.
18+
def AggregationInvocation
19+
: I32EnumAttr<"AggregationInvocation", "aggregate invocation type", [
20+
// clang-format off
21+
I32EnumAttrCase<"unspecified", 0>,
22+
I32EnumAttrCase<"all", 1>,
23+
I32EnumAttrCase<"distinct", 2>,
24+
// clang-format on
25+
]> {
26+
let cppNamespace = "::mlir::substrait";
27+
}
1728

1829
/// Represents the `JoinType` protobuf enum.
1930
def JoinTypeKind : I32EnumAttr<"JoinTypeKind",

include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def Substrait_PlanRelOp : Substrait_Op<"relation", [
246246
def Substrait_YieldOp : Substrait_Op<"yield", [
247247
Terminator,
248248
ParentOneOf<[
249+
"::mlir::substrait::AggregateOp",
249250
"::mlir::substrait::FilterOp",
250251
"::mlir::substrait::PlanRelOp",
251252
"::mlir::substrait::ProjectOp"
@@ -323,9 +324,13 @@ def Substrait_CallOp : Substrait_ExpressionOp<"call", [
323324
]> {
324325
let summary = "Function call expression";
325326
let description = [{
326-
Represents a `ScalarFunction` message (or, in the future, other `*Function`
327-
messages) together with all messages it contains and the `Expression`
328-
message it is contained in.
327+
Represents a `ScalarFunction` or `AggregateFunction` message (or, in the
328+
future, a `WindowFunction` message) together with all messages it contains
329+
and, where applicable, the `Expression` message it is contained in. Which of
330+
the message types this op corresponds to depends on the presence of the
331+
(otherwise optional) aggregate or window-related attributes. For aggregate
332+
functions, the invocation type is omitted from the custom assembly if it is
333+
set to `all`.
329334

330335
Currently, the specification of the function, which is in an external YAML
331336
file, is not taken into account, for example, to verify whether a matching
@@ -347,11 +352,33 @@ def Substrait_CallOp : Substrait_ExpressionOp<"call", [
347352
// TODO(ingomueller): Add support for `enum` and `type` argument types.
348353
let arguments = (ins
349354
FlatSymbolRefAttr:$callee,
350-
Variadic<Substrait_FieldType>:$args
355+
Variadic<Substrait_FieldType>:$args,
356+
OptionalAttr<AggregationInvocation>:$aggregation_invocation
351357
);
352358
let results = (outs Substrait_FieldType:$result);
353359
let assemblyFormat = [{
354-
$callee `(` $args `)` attr-dict `:` `(` type($args) `)` `->` type($result)
360+
$callee `(` $args `)`
361+
(`aggregate` `` custom<AggregationInvocation>($aggregation_invocation)^)?
362+
attr-dict `:` `(` type($args) `)` `->` type($result)
363+
}];
364+
let builders = [
365+
OpBuilder<(ins "::mlir::Type":$result,
366+
"::mlir::FlatSymbolRefAttr":$callee,
367+
"::mlir::ValueRange":$args), [{
368+
build($_builder, $_state, result, callee, args,
369+
AggregationInvocationAttr());
370+
}]>,
371+
OpBuilder<(ins "::mlir::Type":$result, "::llvm::StringRef":$callee,
372+
"::mlir::ValueRange":$args), [{
373+
build($_builder, $_state, result, callee, args,
374+
AggregationInvocationAttr());
375+
}]>
376+
];
377+
let extraClassDeclaration = [{
378+
// Helpers to distinguish function types.
379+
bool isAggregate() { return getAggregationInvocation().has_value(); }
380+
bool isScalar() { return !isAggregate() && !isWindow(); }
381+
bool isWindow() { return false; } // TODO: change once supported.
355382
}];
356383
}
357384

@@ -375,6 +402,77 @@ class Substrait_RelOp<string mnemonic, list<Trait> traits = []> :
375402
]>>
376403
]>;
377404

405+
def Substrait_AggregateOp : Substrait_RelOp<"aggregate", [
406+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
407+
SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
408+
DeclareOpInterfaceMethods<InferTypeOpInterface>,
409+
]> {
410+
let summary = "Aggregate operation";
411+
let description = [{
412+
Represents an `AggregateRel ` message together with the `RelCommon` and the
413+
messages it contains. The `measures` field is represented as a region where
414+
the yielded values correspond to the `AggregateFunction`s (and thus have
415+
to be produced by a `CallOp` representing an aggregate function). Filters
416+
are currently not supported. The `groupings` field is represented as a
417+
region yielding the unique (deduplicated) grouping expressions and an array
418+
of array of references to these expressions representing the grouping sets.
419+
An empty array of grouping sets corresponds to *no* `groupings` messages;
420+
an array with an empty grouping set corresponds to an *empty* `groupings`
421+
messages. These two protobuf representations are different even though their
422+
semantic is equivalent. The op can only be exported to the protobuf format
423+
if the expressions yielded by the `groupings` region are all distinct after
424+
CSE. The assembly format omits an empty region of groupings, an empty region
425+
of measures, and the grouping sets attribute with one grouping set that
426+
consists of all values yielded from `groupings` (or the empty grouping set
427+
if that region is empty).
428+
429+
Example:
430+
431+
```mlir
432+
%0 = ...
433+
%1 = aggregate %0 : tuple<si32> -> tuple<si32, si32>
434+
groupings {
435+
^bb0(%arg : tuple<si32>):
436+
%2 = field_reference %arg[0] : tuple<si32>
437+
yield %2 : si32
438+
}
439+
grouping_sets [[0]]
440+
measures {
441+
^bb0(%arg : tuple<si32>):
442+
%2 = field_reference %arg[0] : tuple<si32>
443+
%3 = call @function(%2) aggregate : (si32) -> si32
444+
yield %3 : si32
445+
}
446+
```
447+
}];
448+
let arguments = (ins
449+
Substrait_Relation:$input,
450+
I64ArrayArrayAttr:$grouping_sets
451+
);
452+
let results = (outs Substrait_Relation:$result);
453+
let regions = (region
454+
AnyRegion:$groupings,
455+
AnyRegion:$measures
456+
);
457+
let assemblyFormat = [{
458+
$input attr-dict `:` type($input) `->` type($result)
459+
custom<AggregateRegions>($groupings, $measures, $grouping_sets)
460+
}];
461+
let hasRegionVerifier = 1;
462+
let builders = [
463+
OpBuilder<(ins
464+
"::mlir::Value":$input, "::mlir::ArrayAttr":$grouping_sets,
465+
"::mlir::Region *":$groupings, "::mlir::Region *":$measures
466+
)>,
467+
];
468+
let extraClassDefinition = [{
469+
/// Implement OpAsmOpInterface.
470+
::llvm::StringRef $cppClass::getDefaultDialect() {
471+
return SubstraitDialect::getDialectNamespace();
472+
}
473+
}];
474+
}
475+
378476
def Substrait_CrossOp : Substrait_RelOp<"cross", [
379477
DeclareOpInterfaceMethods<InferTypeOpInterface>
380478
]> {

0 commit comments

Comments
 (0)