Skip to content

Commit 6f79b67

Browse files
authored
chore: switch to substrait antlr grammar (#74)
switches antlr generation to grammar definition from the main substrait repo.
1 parent b12b2ab commit 6f79b67

11 files changed

+1105
-1724
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
antlr:
2-
java -jar ${ANTLR_JAR} -o src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4
2+
cd third_party/substrait/grammar \
3+
&& java -jar ${ANTLR_JAR} -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \
4+
&& rm ../../../src/substrait/gen/antlr/*.tokens \
5+
&& rm ../../../src/substrait/gen/antlr/*.interp

SubstraitType.g4

Lines changed: 0 additions & 209 deletions
This file was deleted.

src/substrait/derivation_expression.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def _evaluate(x, values: dict):
2727
else:
2828
raise Exception(f"Unknown binary op {x.op.text}")
2929
elif type(x) == SubstraitTypeParser.LiteralNumberContext:
30-
return int(x.number.text)
31-
elif type(x) == SubstraitTypeParser.TypeParamContext:
32-
return values[x.identifier.text]
30+
return int(x.Number().symbol.text)
31+
elif type(x) == SubstraitTypeParser.ParameterNameContext:
32+
return values[x.Identifier().symbol.text]
3333
elif type(x) == SubstraitTypeParser.NumericParameterNameContext:
3434
return values[x.Identifier().symbol.text]
3535
elif type(x) == SubstraitTypeParser.ParenExpressionContext:
@@ -43,9 +43,10 @@ def _evaluate(x, values: dict):
4343
return max(*exprs)
4444
else:
4545
raise Exception(f"Unknown function {func}")
46-
elif type(x) == SubstraitTypeParser.TypeContext:
46+
elif type(x) == SubstraitTypeParser.TypeDefContext:
4747
scalar_type = x.scalarType()
4848
parametrized_type = x.parameterizedType()
49+
any_type = x.anyType()
4950
if scalar_type:
5051
nullability = (
5152
Type.NULLABILITY_NULLABLE if x.isnull else Type.NULLABILITY_REQUIRED
@@ -81,8 +82,14 @@ def _evaluate(x, values: dict):
8182
)
8283
)
8384
raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
85+
elif any_type:
86+
any_var = any_type.AnyVar()
87+
if any_var:
88+
return values[any_var.symbol.text]
89+
else:
90+
raise Exception()
8491
else:
85-
raise Exception("either scalar_type or parametrized_type is required")
92+
raise Exception(f"either scalar_type, parametrized_type or any_type is required")
8693
elif type(x) == SubstraitTypeParser.NumericExpressionContext:
8794
return _evaluate(x.expr(), values)
8895
elif type(x) == SubstraitTypeParser.TernaryContext:
@@ -101,7 +108,7 @@ def _evaluate(x, values: dict):
101108

102109
return _evaluate(x.finalType, values)
103110
elif type(x) == SubstraitTypeParser.TypeLiteralContext:
104-
return _evaluate(x.type_(), values)
111+
return _evaluate(x.typeDef(), values)
105112
elif type(x) == SubstraitTypeParser.NumericLiteralContext:
106113
return int(str(x.Number()))
107114
else:

src/substrait/extension_registry.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,33 @@ def types_equal(type1: Type, type2: Type, check_nullability=False):
9696
).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED
9797
return x == y
9898

99+
def handle_parameter_cover(covered: Type, parameter_name: str, parameters: dict, check_nullability: bool):
100+
if parameter_name in parameters:
101+
covering = parameters[parameter_name]
102+
return types_equal(covering, covered, check_nullability)
103+
else:
104+
parameters[parameter_name] = covered
105+
return True
99106

100107
def covers(
101108
covered: Type,
102109
covering: SubstraitTypeParser.TypeLiteralContext,
103110
parameters: dict,
104111
check_nullability=False,
105-
):
106-
if isinstance(covering, SubstraitTypeParser.TypeParamContext):
112+
):
113+
if isinstance(covering, SubstraitTypeParser.ParameterNameContext):
107114
parameter_name = str(covering.Identifier())
115+
return handle_parameter_cover(covered, parameter_name, parameters, check_nullability)
108116

109-
if parameter_name in parameters:
110-
covering = parameters[parameter_name]
117+
covering: SubstraitTypeParser.TypeDefContext = covering.typeDef()
111118

112-
return types_equal(covering, covered, check_nullability)
119+
any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType()
120+
if any_type:
121+
if any_type.AnyVar():
122+
return handle_parameter_cover(covered, any_type.AnyVar().symbol.text, parameters, check_nullability)
113123
else:
114-
parameters[parameter_name] = covered
115124
return True
116125

117-
covering = covering.type_()
118126
scalar_type = covering.scalarType()
119127
if scalar_type:
120128
covering = _evaluate(covering, {})
@@ -150,10 +158,6 @@ def covers(
150158
else:
151159
raise Exception(f"Unhandled type {type(parameterized_type)}")
152160

153-
any_type = covering.anyType()
154-
if any_type:
155-
return True
156-
157161

158162
class FunctionEntry:
159163
def __init__(

0 commit comments

Comments
 (0)