Skip to content

Commit 1d9fbda

Browse files
authored
feat: sql to substrait (substrait-io#80)
Adds basic support for turning SQL strings into Substrait plans. Covers the most common SQL operators, but function support is minimal.
1 parent bb4de46 commit 1d9fbda

File tree

5 files changed

+860
-6
lines changed

5 files changed

+860
-6
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ write_to = "src/substrait/_version.py"
1414
[project.optional-dependencies]
1515
extensions = ["antlr4-python3-runtime", "pyyaml"]
1616
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
17-
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
17+
sql = ["sqloxide", "deepdiff"]
18+
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml", "sqloxide", "deepdiff", "duckdb<=1.2.2", "datafusion"]
1819

1920
[tool.pytest.ini_options]
2021
pythonpath = "src"

src/substrait/builders/plan.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
def _merge_extensions(*objs):
2828
return {
29-
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]),
30-
"extensions": merge_extension_declarations(*[b.extensions for b in objs]),
29+
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs if b]),
30+
"extensions": merge_extension_declarations(*[b.extensions for b in objs if b]),
3131
}
3232

3333

@@ -193,13 +193,15 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
193193
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
194194
ns = infer_plan_schema(bound_plan)
195195

196-
bound_offset = resolve_expression(offset, ns, registry)
196+
bound_offset = resolve_expression(offset, ns, registry) if offset else None
197197
bound_count = resolve_expression(count, ns, registry)
198198

199199
rel = stalg.Rel(
200200
fetch=stalg.FetchRel(
201201
input=bound_plan.relations[-1].root.input,
202-
offset_expr=bound_offset.referred_expr[0].expression,
202+
offset_expr=bound_offset.referred_expr[0].expression
203+
if bound_offset
204+
else None,
203205
count_expr=bound_count.referred_expr[0].expression,
204206
)
205207
)
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
import random
2+
import string
3+
from sqloxide import parse_sql
4+
from substrait.builders.extended_expression import (
5+
UnboundExtendedExpression,
6+
column,
7+
scalar_function,
8+
literal,
9+
aggregate_function,
10+
window_function,
11+
)
12+
from substrait.builders.plan import (
13+
read_named_table,
14+
project,
15+
filter,
16+
sort,
17+
fetch,
18+
set,
19+
join,
20+
aggregate,
21+
)
22+
from substrait.gen.proto import type_pb2 as stt
23+
from substrait.gen.proto import algebra_pb2 as stalg
24+
from substrait.extension_registry import ExtensionRegistry
25+
from typing import Callable
26+
from deepdiff import DeepDiff
27+
28+
SchemaResolver = Callable[[str], stt.NamedStruct]
29+
30+
# Maps the parser's binary-operator names to (extension file, function name).
function_mapping = {
    "Plus": ("functions_arithmetic.yaml", "add"),
    "Minus": ("functions_arithmetic.yaml", "subtract"),
    "Multiply": ("functions_arithmetic.yaml", "multiply"),
    "Divide": ("functions_arithmetic.yaml", "divide"),
    "Gt": ("functions_comparison.yaml", "gt"),
    "GtEq": ("functions_comparison.yaml", "gte"),
    "Lt": ("functions_comparison.yaml", "lt"),
    "LtEq": ("functions_comparison.yaml", "lte"),
    "Eq": ("functions_comparison.yaml", "equal"),
    "NotEq": ("functions_comparison.yaml", "not_equal"),
}

# Maps SQL aggregate function names to (extension file, function name).
aggregate_function_mapping = {"SUM": ("functions_arithmetic.yaml", "sum")}

# Maps SQL window function names to (extension file, function name).
window_function_mapping = {
    "row_number": ("functions_arithmetic.yaml", "row_number"),
}
44+
45+
46+
def compare_dicts(dict1, dict2):
    """Return True when DeepDiff reports no difference between the two dicts.

    Paths matching ``span`` are excluded from the comparison (presumably
    source-location metadata attached by the SQL parser -- TODO confirm).
    """
    return len(DeepDiff(dict1, dict2, exclude_regex_paths=["span"])) == 0
49+
50+
51+
def translate_expression(
    ast: dict,
    schema_resolver: SchemaResolver,
    registry: ExtensionRegistry,
    measures: list[UnboundExtendedExpression],
    groupings: list[dict],
    alias: str = None,
) -> UnboundExtendedExpression:
    """Translate a parsed SQL expression node into an unbound substrait expression.

    Args:
        ast: Single-key dict ``{node_kind: payload}`` taken from the parsed SQL.
        schema_resolver: Forwarded unchanged to recursive calls.
        registry: Forwarded unchanged to recursive calls.
        measures: Accumulator of ``(aggregate, function_ast, name)`` tuples.
            Every new aggregate call found in ``ast`` is appended to it and
            replaced by a column reference to ``name``. Pass ``None`` where
            aggregate calls are not allowed.
        groupings: ASTs of the GROUP BY expressions when translating on top of
            an aggregation, otherwise ``None`` or empty. A node identical to
            the i-th grouping is replaced by a reference to column ``i``.
        alias: Optional alias given to the resulting expression.

    Returns:
        The unbound expression for ``ast``.

    Raises:
        Exception: For an unknown node kind or function name, or for an
            aggregate call when ``measures`` is ``None``.
    """
    assert len(ast) == 1
    op = list(ast.keys())[0]

    if groupings:
        # We are translating on top of a grouping: reuse the grouping's output
        # column for an identical expression instead of recomputing it.
        for index, grouping in enumerate(groupings):
            if compare_dicts(ast, grouping):
                return column(index, alias=alias)

    node = ast[op]

    def translate_child(child: dict, child_alias: str = None):
        # Recurse with the same translation context as the current call.
        return translate_expression(
            child,
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=groupings,
            alias=child_alias,
        )

    if op == "Identifier":
        return column(node["value"], alias=alias)

    if op in ("UnnamedExpr", "expr", "Unnamed", "Expr"):
        # Pure wrapper nodes: translate whatever they wrap.
        return translate_child(node)

    if op == "ExprWithAlias":
        return translate_child(node["expr"], node["alias"]["value"])

    if op == "BinaryOp":
        operands = [translate_child(node["left"]), translate_child(node["right"])]
        uri, function = function_mapping[node["op"]]
        return scalar_function(uri, function, expressions=operands, alias=alias)

    if op == "Value":
        return literal(
            int(node["value"]["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias
        )  # TODO infer type

    if op == "Function":
        name = node["name"][0]["Identifier"]["value"]
        raw_arguments = node["args"]["List"]["args"]

        if name in function_mapping:
            uri, function = function_mapping[name]
            arguments = [translate_child(a) for a in raw_arguments]
            # Pass the argument list the same way the BinaryOp branch does.
            return scalar_function(uri, function, expressions=arguments, alias=alias)

        if name in aggregate_function_mapping:
            if measures is None:
                raise Exception(
                    f"Aggregate function {name} is not supported in this context"
                )

            # All measures need to be extracted out because substrait
            # calculates measures in a separate rel. An identical aggregate
            # seen earlier is reused through its generated name.
            for _, known_ast, known_name in measures:
                if compare_dicts(node, known_ast):
                    return column(known_name, alias=alias)

            # Measure arguments are evaluated on the input of the aggregation,
            # so they must not be rewritten into references to grouping output
            # columns; nested aggregates are rejected by passing measures=None.
            arguments = [
                translate_expression(
                    a,
                    schema_resolver=schema_resolver,
                    registry=registry,
                    measures=None,
                    groupings=None,
                )
                for a in raw_arguments
            ]

            uri, function = aggregate_function_mapping[name]
            # A generated name lets the projection refer to the measure column.
            random_name = "".join(
                random.choices(string.ascii_uppercase + string.digits, k=5)
            )  # TODO make this deterministic
            aggr = aggregate_function(uri, function, arguments, alias=random_name)
            measures.append((aggr, node, random_name))
            return column(random_name, alias=alias)

        if name in window_function_mapping:
            uri, function = window_function_mapping[name]
            arguments = [translate_child(a) for a in raw_arguments]
            partitions = [
                translate_child(p) for p in node["over"]["WindowSpec"]["partition_by"]
            ]
            return window_function(
                uri, function, arguments, partitions=partitions, alias=alias
            )

        raise Exception(f"Unknown function {name}")

    # TODO: "Wildcard" nodes are not supported yet.
    raise Exception(f"Unknown op {op}")
168+
169+
170+
_JOIN_TYPE_MAPPING = {
    "Inner": stalg.JoinRel.JOIN_TYPE_INNER,
    "Left": stalg.JoinRel.JOIN_TYPE_LEFT,
    "LeftOuter": stalg.JoinRel.JOIN_TYPE_LEFT,
    "RightOuter": stalg.JoinRel.JOIN_TYPE_RIGHT,
    "Right": stalg.JoinRel.JOIN_TYPE_RIGHT,
}


def _translate_plain_expression(ast, schema_resolver, registry):
    """Translate an expression with no measure accumulator and no groupings."""
    return translate_expression(
        ast,
        schema_resolver=schema_resolver,
        registry=registry,
        measures=None,
        groupings=None,
    )


def _translate_query(ast, schema_resolver, registry):
    """Translate a ``Query`` payload: its body, then ORDER BY, then LIMIT/OFFSET."""
    relation = translate(
        ast["body"], schema_resolver=schema_resolver, registry=registry
    )

    if ast["order_by"]:
        # NOTE(review): only the sort expressions are read here; any sort
        # direction carried by the AST is not forwarded to `sort`.
        expressions = [
            _translate_plain_expression(e["expr"], schema_resolver, registry)
            for e in ast["order_by"]["kind"]["Expressions"]
        ]
        relation = sort(relation, expressions)(registry)

    if ast["limit_clause"]:
        limit_offset = ast["limit_clause"]["LimitOffset"]

        if not limit_offset["limit"]:
            # `fetch` needs a count expression; fail with a clear message
            # instead of crashing while translating a missing node.
            raise Exception("OFFSET without LIMIT is not supported")

        limit_expression = _translate_plain_expression(
            limit_offset["limit"], schema_resolver, registry
        )
        offset_expression = (
            _translate_plain_expression(
                limit_offset["offset"]["value"], schema_resolver, registry
            )
            if limit_offset["offset"]
            else None
        )

        relation = fetch(relation, offset_expression, limit_expression)(registry)

    return relation


def _translate_select(ast, schema_resolver, registry):
    """Translate a ``Select`` payload.

    Applies, in order: FROM with its joins, WHERE, GROUP BY / aggregates,
    HAVING, and finally the projection.
    """
    from_items = ast["from"]
    if len(from_items) != 1:
        # Only the first FROM item used to be read; refuse instead of
        # silently dropping the other tables.
        raise Exception(
            f"Expected exactly one FROM item, got {len(from_items)}"
        )
    source = from_items[0]

    relation = translate(
        source["relation"], schema_resolver=schema_resolver, registry=registry
    )

    for join_ast in source["joins"] or []:
        right = translate(
            join_ast["relation"], schema_resolver=schema_resolver, registry=registry
        )
        join_type = list(join_ast["join_operator"].keys())[0]
        condition = _translate_plain_expression(
            join_ast["join_operator"][join_type]["On"], schema_resolver, registry
        )
        relation = join(relation, right, condition, _JOIN_TYPE_MAPPING[join_type])(
            registry
        )

    if ast.get("selection"):
        where_expression = _translate_plain_expression(
            ast["selection"], schema_resolver, registry
        )
        relation = filter(relation, where_expression)(registry)

    if ast["group_by"] and ast["group_by"]["Expressions"][0]:
        groupings = ast["group_by"]["Expressions"][0]
        grouping_expressions = [
            _translate_plain_expression(e, schema_resolver, registry)
            for e in groupings
        ]
    else:
        groupings = []
        grouping_expressions = []

    # Filled by translate_expression with every aggregate call it finds in
    # the projection and in the HAVING predicate.
    measures = []

    projection = [
        translate_expression(
            p,
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=groupings,
        )
        for p in ast["projection"]
    ]

    if ast["having"]:
        having_predicate = translate_expression(
            ast["having"],
            schema_resolver=schema_resolver,
            registry=registry,
            measures=measures,
            groupings=[],
        )
    else:
        having_predicate = None

    if measures or groupings:
        relation = aggregate(
            relation, grouping_expressions, [m[0] for m in measures]
        )(registry)

    if having_predicate:
        relation = filter(relation, having_predicate)(registry)

    return project(relation, expressions=projection)(registry)


def _translate_set_operation(ast, schema_resolver, registry):
    """Translate a ``SetOperation`` payload; only UNION [ALL] is supported."""
    # TODO more than 2 inputs to a set operation
    left = translate(ast["left"], schema_resolver=schema_resolver, registry=registry)
    right = translate(
        ast["right"], schema_resolver=schema_resolver, registry=registry
    )

    if ast["op"] != "Union":
        raise Exception(f"Unsupported set operation {ast['op']}")

    set_op = (
        stalg.SetRel.SET_OP_UNION_ALL
        if ast["set_quantifier"] == "All"
        else stalg.SetRel.SET_OP_UNION_DISTINCT
    )
    return set([left, right], set_op)(registry)


def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionRegistry):
    """Translate a parsed SQL relational node into a substrait plan.

    Args:
        ast: Single-key dict ``{node_kind: payload}``; the supported kinds are
            ``Query``, ``Select``, ``Table`` and ``SetOperation``.
        schema_resolver: Called with a table name to obtain the schema passed
            to ``read_named_table``.
        registry: Extension registry used to bind every generated relation.

    Returns:
        The result of the plan builder for ``ast``.

    Raises:
        Exception: For an unknown node kind, an unsupported set operation,
            OFFSET without LIMIT, or a FROM clause that does not hold exactly
            one item.
    """
    assert len(ast) == 1
    op = list(ast.keys())[0]
    ast = ast[op]

    if op == "Query":
        return _translate_query(ast, schema_resolver, registry)
    elif op == "Select":
        return _translate_select(ast, schema_resolver, registry)
    elif op == "Table":
        name = ast["name"][0]["Identifier"]["value"]
        return read_named_table(name, schema_resolver(name))
    elif op == "SetOperation":
        return _translate_set_operation(ast, schema_resolver, registry)
    else:
        raise Exception(f"Unknown op {op}")
334+
335+
336+
def convert(query: str, dialect: str, schema_resolver: SchemaResolver):
    """Parse ``query`` in the given SQL ``dialect`` and translate it.

    Only the first parsed statement is translated. A fresh
    ``ExtensionRegistry`` with the default extensions loaded is used.
    """
    statements = parse_sql(sql=query, dialect=dialect)
    return translate(
        statements[0],
        schema_resolver=schema_resolver,
        registry=ExtensionRegistry(load_default_extensions=True),
    )

0 commit comments

Comments
 (0)