Skip to content

Commit d3d8b0a

Browse files
authored
feat: introduce fixedbinary<L> type (#83)
introduces the `fixedbinary<L>` type, which represents a binary string of L bytes.
1 parent 6284e0f commit d3d8b0a

File tree

11 files changed

+277
-8
lines changed

11 files changed

+277
-8
lines changed

include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ def Substrait_FixedCharAttr
120120
}];
121121
}
122122

123+
def Substrait_FixedBinaryAttr : Substrait_Attr<"FixedBinary", "fixed_binary",
124+
[TypedAttrInterface]> {
125+
let summary = "Substrait fixed-length binary type";
126+
let description = [{
127+
This type represents a substrait binary string of L bytes.
128+
}];
129+
let parameters = (ins "StringAttr":$value, "FixedBinaryType":$type);
130+
let genVerifyDecl = 1;
131+
let assemblyFormat = "`<` custom<FixedBinaryLiteral>($value, $type) `>`";
132+
}
133+
123134
def Substrait_IntervalDaySecondAttr
124135
: Substrait_StaticallyTypedAttr<"IntervalDaySecond", "interval_day_second",
125136
"IntervalDaySecondType"> {
@@ -271,6 +282,7 @@ def Substrait_ParametrizedAttributes {
271282
list<Attr> attrs = [
272283
Substrait_FixedCharAttr, // FixedChar
273284
Substrait_VarCharAttr, // VarChar
285+
Substrait_FixedBinaryAttr, // FixedBinary
274286
Substrait_DecimalAttr, // Decimal
275287
];
276288
}

include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def Substrait_DecimalType : Substrait_Type<"Decimal", "decimal"> {
4343
let genVerifyDecl = 1;
4444
}
4545

46+
def Substrait_FixedBinaryType : Substrait_Type<"FixedBinary", "fixed_binary"> {
47+
let summary = "Substrait fixed-length binary type";
48+
let description = [{
49+
This type represents a substrait binary string of L bytes.
50+
}];
51+
let parameters = (ins "int32_t":$length);
52+
let assemblyFormat = [{ `<` $length `>` }];
53+
}
54+
4655
def Substrait_FixedCharType : Substrait_Type<"FixedChar", "fixed_char"> {
4756
let summary = "Substrait fixed-length char type";
4857
let description = [{
@@ -145,6 +154,7 @@ def Substrait_ParametrizedTypes {
145154
list<Type> types = [
146155
Substrait_FixedCharType, // FixedChar
147156
Substrait_VarCharType, // VarChar
157+
Substrait_FixedBinaryType, // FixedBinary
148158
Substrait_DecimalType, // Decimal
149159
];
150160
}

lib/Dialect/Substrait/IR/Substrait.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ LogicalResult mlir::substrait::FixedCharAttr::verify(
8080
return success();
8181
}
8282

83+
LogicalResult mlir::substrait::FixedBinaryAttr::verify(
84+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError, StringAttr value,
85+
FixedBinaryType type) {
86+
FixedBinaryType fixedBinaryType = mlir::dyn_cast<FixedBinaryType>(type);
87+
if (fixedBinaryType == nullptr)
88+
return emitError() << "expected a fixed binary type";
89+
int32_t value_length = value.size();
90+
if (value_length != fixedBinaryType.getLength())
91+
return emitError() << "value length must be " << fixedBinaryType.getLength()
92+
<< " characters.";
93+
return success();
94+
}
95+
8396
LogicalResult mlir::substrait::IntervalYearMonthAttr::verify(
8497
llvm::function_ref<mlir::InFlightDiagnostic()> emitError, int32_t year,
8598
int32_t month) {
@@ -358,6 +371,32 @@ void printDecimalNumber(AsmPrinter &printer, DecimalType type,
358371
printer << "P = " << type.getPrecision() << ", S = " << type.getScale();
359372
}
360373

374+
ParseResult parseFixedBinaryLiteral(AsmParser &parser, StringAttr &value,
375+
FixedBinaryType &type) {
376+
std::string valueStr;
377+
// Parse fixed binary value as quoted string.
378+
if (parser.parseString(&valueStr))
379+
return failure();
380+
381+
// Create `FixedBinaryType`.
382+
auto emitError = [&]() {
383+
return parser.emitError(parser.getCurrentLocation());
384+
};
385+
MLIRContext *context = parser.getContext();
386+
uint32_t length = valueStr.size();
387+
if (!(type = FixedBinaryType::getChecked(emitError, context, length)))
388+
return failure();
389+
390+
value = parser.getBuilder().getStringAttr(valueStr);
391+
392+
return success();
393+
}
394+
395+
void printFixedBinaryLiteral(AsmPrinter &printer, StringAttr value,
396+
FixedBinaryType type) {
397+
printer << value;
398+
}
399+
361400
//===----------------------------------------------------------------------===//
362401
// Substrait operations
363402
//===----------------------------------------------------------------------===//

lib/Target/SubstraitPB/Export.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,19 @@ SubstraitExporter::exportType(Location loc, mlir::Type mlirType) {
403403
return std::move(type);
404404
}
405405

406+
// Handle fixed binary.
407+
if (mlir::isa<FixedBinaryType>(mlirType)) {
408+
// TODO(ingomueller): support other nullability modes.
409+
auto fixedBinaryType = std::make_unique<proto::Type::FixedBinary>();
410+
fixedBinaryType->set_length(
411+
mlir::cast<FixedBinaryType>(mlirType).getLength());
412+
fixedBinaryType->set_nullability(
413+
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
414+
auto type = std::make_unique<proto::Type>();
415+
type->set_allocated_fixed_binary(fixedBinaryType.release());
416+
return std::move(type);
417+
}
418+
406419
// Handle decimal.
407420
if (auto decimalType = llvm::dyn_cast<DecimalType>(mlirType)) {
408421
auto decimalTypeProto = std::make_unique<proto::Type::Decimal>();
@@ -1020,6 +1033,10 @@ SubstraitExporter::exportOperation(LiteralOp op) {
10201033
std::make_unique<::substrait::proto::Expression_Literal_VarChar>();
10211034
varChar->set_value(mlir::cast<VarCharAttr>(value).getValue().str());
10221035
literal->set_allocated_var_char(varChar.release());
1036+
// `FixedBinaryType`.
1037+
} else if (auto fixedBinaryType = dyn_cast<FixedBinaryType>(literalType)) {
1038+
literal->set_allocated_fixed_binary(
1039+
new std::string(mlir::cast<FixedBinaryAttr>(value).getValue().str()));
10231040
} // `DecimalType`.
10241041
else if (auto decimalType = dyn_cast<DecimalType>(literalType)) {
10251042
auto decimal =

lib/Target/SubstraitPB/Import.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
221221
return FixedCharType::get(context, type.fixed_char().length());
222222
case proto::Type::kVarchar:
223223
return VarCharType::get(context, type.varchar().length());
224+
case proto::Type::kFixedBinary:
225+
return FixedBinaryType::get(context, type.fixed_binary().length());
224226
case proto::Type::kDecimal: {
225227
const proto::Type::Decimal &decimalType = type.decimal();
226228
return mlir::substrait::DecimalType::get(context, decimalType.precision(),
@@ -712,6 +714,13 @@ importLiteral(ImplicitLocOpBuilder builder,
712714
auto attr = VarCharAttr::get(context, stringAttr, varCharType);
713715
return builder.create<LiteralOp>(attr);
714716
}
717+
case Expression::Literal::LiteralTypeCase::kFixedBinary: {
718+
StringAttr stringAttr = StringAttr::get(context, message.fixed_binary());
719+
FixedBinaryType fixedBinaryType =
720+
FixedBinaryType::get(context, message.fixed_binary().size());
721+
auto attr = FixedBinaryAttr::get(context, stringAttr, fixedBinaryType);
722+
return builder.create<LiteralOp>(attr);
723+
}
715724
case Expression::Literal::LiteralTypeCase::kDecimal: {
716725
APInt var(128, 0);
717726
llvm::LoadIntFromMemory(

test/Dialect/Substrait/literal.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,30 @@
3131

3232
// -----
3333

34+
// CHECK: substrait.plan version 0 : 42 : 1 {
35+
// CHECK-NEXT: relation
36+
// CHECK: %[[V0:.*]] = named_table
37+
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.fixed_binary<10>> {
38+
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
39+
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.fixed_binary<"8181818181">
40+
// CHECK-NEXT: yield %[[V2]] : !substrait.fixed_binary<10>
41+
// CHECK-NEXT: }
42+
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.fixed_binary<10>
43+
44+
substrait.plan version 0 : 42 : 1 {
45+
relation {
46+
%0 = named_table @t1 as ["a"] : tuple<si1>
47+
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.fixed_binary<10>> {
48+
^bb0(%arg : tuple<si1>):
49+
%bytes = literal #substrait.fixed_binary<"8181818181">
50+
yield %bytes : !substrait.fixed_binary<10>
51+
}
52+
yield %1 : tuple<si1, !substrait.fixed_binary<10>>
53+
}
54+
}
55+
56+
// -----
57+
3458
// CHECK: substrait.plan version 0 : 42 : 1 {
3559
// CHECK-NEXT: relation
3660
// CHECK: %[[V0:.*]] = named_table

test/Dialect/Substrait/types.mlir

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,41 @@
33

44
// CHECK-LABEL: substrait.plan
55
// CHECK: relation
6-
// CHECK: %[[V0:.*]] = named_table @t1 as ["a"] : tuple<!substrait.var_char<6>>
7-
// CHECK-NEXT: yield %0 : tuple<!substrait.var_char<6>>
6+
// CHECK: %[[V0:.*]] = named_table @t1 as ["a"] : tuple<!substrait.decimal<12, 2>>
7+
// CHECK-NEXT: yield %0 : tuple<!substrait.decimal<12, 2>>
88

99
substrait.plan version 0 : 42 : 1 {
1010
relation {
11-
%0 = named_table @t1 as ["a"] : tuple<!substrait.var_char<6>>
12-
yield %0 : tuple<!substrait.var_char<6>>
11+
%0 = named_table @t1 as ["a"] : tuple<!substrait.decimal<12, 2>>
12+
yield %0 : tuple<!substrait.decimal<12, 2>>
1313
}
1414
}
1515

1616
// -----
1717

1818
// CHECK-LABEL: substrait.plan
1919
// CHECK: relation
20-
// CHECK: %[[V0:.*]] = named_table @t1 as ["a"] : tuple<!substrait.decimal<12, 2>>
21-
// CHECK-NEXT: yield %0 : tuple<!substrait.decimal<12, 2>>
20+
// CHECK: %[[V0:.*]] = named_table @t1 as ["a"] : tuple<!substrait.fixed_binary<4>>
21+
// CHECK-NEXT: yield %0 : tuple<!substrait.fixed_binary<4>>
2222

2323
substrait.plan version 0 : 42 : 1 {
2424
relation {
25-
%0 = named_table @t1 as ["a"] : tuple<!substrait.decimal<12, 2>>
26-
yield %0 : tuple<!substrait.decimal<12, 2>>
25+
%0 = named_table @t1 as ["a"] : tuple<!substrait.fixed_binary<4>>
26+
yield %0 : tuple<!substrait.fixed_binary<4>>
27+
}
28+
}
29+
30+
// -----
31+
32+
// CHECK-LABEL: substrait.plan
33+
// CHECK: relation
34+
// CHECK: %[[V0:.*]] = named_table @t1 as ["a"] : tuple<!substrait.var_char<6>>
35+
// CHECK-NEXT: yield %0 : tuple<!substrait.var_char<6>>
36+
37+
substrait.plan version 0 : 42 : 1 {
38+
relation {
39+
%0 = named_table @t1 as ["a"] : tuple<!substrait.var_char<6>>
40+
yield %0 : tuple<!substrait.var_char<6>>
2741
}
2842
}
2943

test/Target/SubstraitPB/Export/literal.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,33 @@ substrait.plan version 0 : 42 : 1 {
3939

4040
// -----
4141

42+
// CHECK-LABEL: relations {
43+
// CHECK-NEXT: rel {
44+
// CHECK-NEXT: project {
45+
// CHECK-NEXT: common {
46+
// CHECK-NEXT: direct {
47+
// CHECK-NEXT: }
48+
// CHECK-NEXT: }
49+
// CHECK-NEXT: input {
50+
// CHECK-NEXT: read {
51+
// CHECK: expressions {
52+
// CHECK-NEXT: literal {
53+
// CHECK-NEXT: fixed_binary: "8181818181"
54+
55+
substrait.plan version 0 : 42 : 1 {
56+
relation {
57+
%0 = named_table @t1 as ["a"] : tuple<si1>
58+
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.fixed_binary<10>> {
59+
^bb0(%arg : tuple<si1>):
60+
%fixed_binary = literal #substrait.fixed_binary<"8181818181">
61+
yield %fixed_binary : !substrait.fixed_binary<10>
62+
}
63+
yield %1 : tuple<si1, !substrait.fixed_binary<10>>
64+
}
65+
}
66+
67+
// -----
68+
4269
// CHECK-LABEL: relations {
4370
// CHECK-NEXT: rel {
4471
// CHECK-NEXT: project {

test/Target/SubstraitPB/Export/types.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,32 @@ substrait.plan version 0 : 42 : 1 {
3636

3737
// -----
3838

39+
// CHECK-LABEL: relations {
40+
// CHECK-NEXT: rel {
41+
// CHECK-NEXT: read {
42+
// CHECK: base_schema {
43+
// CHECK-NEXT: names: "a"
44+
// CHECK-NEXT: struct {
45+
// CHECK-NEXT: types {
46+
// CHECK-NEXT: fixed_binary {
47+
// CHECK-NEXT: length: 4
48+
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
49+
// CHECK-NEXT: }
50+
// CHECK-NEXT: }
51+
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
52+
// CHECK-NEXT: }
53+
// CHECK-NEXT: }
54+
// CHECK-NEXT: named_table {
55+
56+
substrait.plan version 0 : 42 : 1 {
57+
relation {
58+
%0 = named_table @t1 as ["a"] : tuple<!substrait.fixed_binary<4>>
59+
yield %0 : tuple<!substrait.fixed_binary<4>>
60+
}
61+
}
62+
63+
// -----
64+
3965
// CHECK-LABEL: relations {
4066
// CHECK-NEXT: rel {
4167
// CHECK-NEXT: read {

test/Target/SubstraitPB/Import/literal.textpb

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,60 @@ version {
6868

6969
# -----
7070

71+
# CHECK: substrait.plan version 0 : 42 : 1 {
72+
# CHECK-NEXT: relation
73+
# CHECK: %[[V0:.*]] = named_table
74+
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.fixed_binary<10>> {
75+
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
76+
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.fixed_binary<"8181818181">
77+
# CHECK-NEXT: yield %[[V2]] : !substrait.fixed_binary<10>
78+
# CHECK-NEXT: }
79+
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.fixed_binary<10>
80+
81+
relations {
82+
rel {
83+
project {
84+
common {
85+
direct {
86+
}
87+
}
88+
input {
89+
read {
90+
common {
91+
direct {
92+
}
93+
}
94+
base_schema {
95+
names: "a"
96+
struct {
97+
types {
98+
bool {
99+
nullability: NULLABILITY_REQUIRED
100+
}
101+
}
102+
nullability: NULLABILITY_REQUIRED
103+
}
104+
}
105+
named_table {
106+
names: "t1"
107+
}
108+
}
109+
}
110+
expressions {
111+
literal {
112+
fixed_binary: "8181818181"
113+
}
114+
}
115+
}
116+
}
117+
}
118+
version {
119+
minor_number: 42
120+
patch_number: 1
121+
}
122+
123+
# -----
124+
71125
# CHECK: substrait.plan version 0 : 42 : 1 {
72126
# CHECK-NEXT: relation
73127
# CHECK: %[[V0:.*]] = named_table

0 commit comments

Comments
 (0)