Skip to content

Commit c0b9f6e

Browse files
committed
feat: allow passing multiple functions to function builders
1 parent 42e979b commit c0b9f6e

File tree

6 files changed

+71
-48
lines changed

6 files changed

+71
-48
lines changed

src/substrait/builders/extended_expression.py

Lines changed: 46 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -204,12 +204,12 @@ def resolve(
204204

205205

206206
def scalar_function(
207-
uri: str,
208-
function: str,
207+
function: Union[Iterable[str], str],
209208
expressions: Iterable[ExtendedExpressionOrUnbound],
210209
alias: Union[Iterable[str], str] = None,
211210
):
212211
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
212+
functions = [function] if isinstance(function, str) else function
213213

214214
def resolve(
215215
base_schema: stp.NamedStruct, registry: ExtensionRegistry
@@ -224,23 +224,30 @@ def resolve(
224224

225225
signature = [typ for es in expression_schemas for typ in es.types]
226226

227-
func = registry.lookup_function(uri, function, signature)
227+
for f in functions:
228+
uri, name = f.split(":")
229+
func = registry.lookup_function(uri, name, signature)
230+
if func:
231+
break
228232

229233
if not func:
230234
raise Exception(f"Unknown function {function} for {signature}")
231235

236+
resolved_func, return_type = func
237+
232238
func_extension_uris = [
233239
ste.SimpleExtensionURI(
234-
extension_uri_anchor=registry.lookup_uri(uri), uri=uri
240+
extension_uri_anchor=registry.lookup_uri(resolved_func.uri),
241+
uri=resolved_func.uri,
235242
)
236243
]
237244

238245
func_extensions = [
239246
ste.SimpleExtensionDeclaration(
240247
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
241-
extension_uri_reference=registry.lookup_uri(uri),
242-
function_anchor=func[0].anchor,
243-
name=str(func[0]),
248+
extension_uri_reference=registry.lookup_uri(resolved_func.uri),
249+
function_anchor=resolved_func.anchor,
250+
name=str(resolved_func),
244251
)
245252
)
246253
]
@@ -258,14 +265,14 @@ def resolve(
258265
stee.ExpressionReference(
259266
expression=stalg.Expression(
260267
scalar_function=stalg.Expression.ScalarFunction(
261-
function_reference=func[0].anchor,
268+
function_reference=resolved_func.anchor,
262269
arguments=[
263270
stalg.FunctionArgument(
264271
value=e.referred_expr[0].expression
265272
)
266273
for e in bound_expressions
267274
],
268-
output_type=func[1],
275+
output_type=return_type,
269276
)
270277
),
271278
output_names=_alias_or_inferred(
@@ -284,12 +291,12 @@ def resolve(
284291

285292

286293
def aggregate_function(
287-
uri: str,
288-
function: str,
294+
function: Union[Iterable[str], str],
289295
expressions: Iterable[ExtendedExpressionOrUnbound],
290296
alias: Union[Iterable[str], str] = None,
291297
):
292298
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
299+
functions = [function] if isinstance(function, str) else function
293300

294301
def resolve(
295302
base_schema: stp.NamedStruct, registry: ExtensionRegistry
@@ -304,23 +311,30 @@ def resolve(
304311

305312
signature = [typ for es in expression_schemas for typ in es.types]
306313

307-
func = registry.lookup_function(uri, function, signature)
314+
for f in functions:
315+
uri, name = f.split(":")
316+
func = registry.lookup_function(uri, name, signature)
317+
if func:
318+
break
308319

309320
if not func:
310321
raise Exception(f"Unknown function {function} for {signature}")
311322

323+
resolved_func, return_type = func
324+
312325
func_extension_uris = [
313326
ste.SimpleExtensionURI(
314-
extension_uri_anchor=registry.lookup_uri(uri), uri=uri
327+
extension_uri_anchor=registry.lookup_uri(resolved_func.uri),
328+
uri=resolved_func.uri,
315329
)
316330
]
317331

318332
func_extensions = [
319333
ste.SimpleExtensionDeclaration(
320334
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
321-
extension_uri_reference=registry.lookup_uri(uri),
322-
function_anchor=func[0].anchor,
323-
name=str(func[0]),
335+
extension_uri_reference=registry.lookup_uri(resolved_func.uri),
336+
function_anchor=resolved_func.anchor,
337+
name=str(resolved_func),
324338
)
325339
)
326340
]
@@ -342,7 +356,7 @@ def resolve(
342356
stalg.FunctionArgument(value=e.referred_expr[0].expression)
343357
for e in bound_expressions
344358
],
345-
output_type=func[1],
359+
output_type=return_type,
346360
),
347361
output_names=_alias_or_inferred(
348362
alias,
@@ -361,13 +375,13 @@ def resolve(
361375

362376
# TODO bounds, sorts
363377
def window_function(
364-
uri: str,
365-
function: str,
378+
function: Union[Iterable[str], str],
366379
expressions: Iterable[ExtendedExpressionOrUnbound],
367380
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
368381
alias: Union[Iterable[str], str] = None,
369382
):
370383
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
384+
functions = [function] if isinstance(function, str) else function
371385

372386
def resolve(
373387
base_schema: stp.NamedStruct, registry: ExtensionRegistry
@@ -386,23 +400,30 @@ def resolve(
386400

387401
signature = [typ for es in expression_schemas for typ in es.types]
388402

389-
func = registry.lookup_function(uri, function, signature)
403+
for f in functions:
404+
uri, name = f.split(":")
405+
func = registry.lookup_function(uri, name, signature)
406+
if func:
407+
break
390408

391409
if not func:
392410
raise Exception(f"Unknown function {function} for {signature}")
393411

412+
resolved_func, return_type = func
413+
394414
func_extension_uris = [
395415
ste.SimpleExtensionURI(
396-
extension_uri_anchor=registry.lookup_uri(uri), uri=uri
416+
extension_uri_anchor=registry.lookup_uri(resolved_func.uri),
417+
uri=resolved_func.uri,
397418
)
398419
]
399420

400421
func_extensions = [
401422
ste.SimpleExtensionDeclaration(
402423
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
403-
extension_uri_reference=registry.lookup_uri(uri),
404-
function_anchor=func[0].anchor,
405-
name=str(func[0]),
424+
extension_uri_reference=registry.lookup_uri(resolved_func.uri),
425+
function_anchor=resolved_func.anchor,
426+
name=str(resolved_func),
406427
)
407428
)
408429
]
@@ -431,7 +452,7 @@ def resolve(
431452
)
432453
for e in bound_expressions
433454
],
434-
output_type=func[1],
455+
output_type=return_type,
435456
partitions=[
436457
e.referred_expr[0].expression for e in bound_partitions
437458
],

src/substrait/sql/sql_to_substrait.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,21 @@
2828
SchemaResolver = Callable[[str], stt.NamedStruct]
2929

3030
function_mapping = {
31-
"Plus": ("functions_arithmetic.yaml", "add"),
32-
"Minus": ("functions_arithmetic.yaml", "subtract"),
33-
"Gt": ("functions_comparison.yaml", "gt"),
34-
"GtEq": ("functions_comparison.yaml", "gte"),
35-
"Lt": ("functions_comparison.yaml", "lt"),
36-
"Eq": ("functions_comparison.yaml", "equal"),
31+
"Plus": ["functions_arithmetic.yaml:add", "functions_arithmetic_decimal.yaml:add"],
32+
"Minus": [
33+
"functions_arithmetic.yaml:subtract",
34+
"functions_arithmetic_decimal.yaml:subtract",
35+
],
36+
"Gt": ["functions_comparison.yaml:gt"],
37+
"GtEq": ["functions_comparison.yaml:gte"],
38+
"Lt": ["functions_comparison.yaml:lt"],
39+
"Eq": ["functions_comparison.yaml:equal"],
3740
}
3841

39-
aggregate_function_mapping = {"SUM": ("functions_arithmetic.yaml", "sum")}
42+
aggregate_function_mapping = {"SUM": ["functions_arithmetic.yaml:sum"]}
4043

4144
window_function_mapping = {
42-
"row_number": ("functions_arithmetic.yaml", "row_number"),
45+
"row_number": ["functions_arithmetic.yaml:row_number"],
4346
}
4447

4548

@@ -105,7 +108,7 @@ def translate_expression(
105108
),
106109
]
107110
func = function_mapping[ast["op"]]
108-
return scalar_function(func[0], func[1], expressions=expressions, alias=alias)
111+
return scalar_function(func, expressions=expressions, alias=alias)
109112
elif op == "Value":
110113
return literal(
111114
int(ast["value"]["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias
@@ -138,7 +141,7 @@ def translate_expression(
138141
random_name = "".join(
139142
random.choices(string.ascii_uppercase + string.digits, k=5)
140143
) # TODO make this deterministic
141-
aggr = aggregate_function(func[0], func[1], expressions, alias=random_name)
144+
aggr = aggregate_function(func, expressions, alias=random_name)
142145
measures.append((aggr, ast, random_name))
143146
return column(random_name, alias=alias)
144147
elif name in window_function_mapping:
@@ -156,7 +159,7 @@ def translate_expression(
156159
]
157160

158161
return window_function(
159-
func[0], func[1], expressions, partitions=partitions, alias=alias
162+
func, expressions, partitions=partitions, alias=alias
160163
)
161164

162165
else:

tests/builders/extended_expression/test_aggregate_function.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@
4141

4242
def test_aggregate_count():
4343
e = aggregate_function(
44-
"test_uri",
45-
"count",
44+
"test_uri:count",
4645
expressions=[
4746
literal(
4847
10,

tests/builders/extended_expression/test_scalar_function.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@
4545

4646
def test_sclar_add():
4747
e = scalar_function(
48-
"test_uri",
49-
"test_func",
48+
"test_uri:test_func",
5049
expressions=[
5150
literal(
5251
10,
@@ -68,7 +67,7 @@ def test_sclar_add():
6867
extensions=[
6968
ste.SimpleExtensionDeclaration(
7069
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
71-
extension_uri_reference=1, function_anchor=1, name="test_func:i8"
70+
extension_uri_reference=1, function_anchor=1, name="test_func:i8"
7271
)
7372
)
7473
],
@@ -98,23 +97,24 @@ def test_sclar_add():
9897
),
9998
)
10099
),
101-
output_names=["test_func(Literal(10),Literal(20))"],
100+
output_names=["test_uri:test_func(Literal(10),Literal(20))"],
102101
)
103102
],
104103
base_schema=named_struct,
105104
)
106105

106+
print(e)
107+
print(expected)
108+
107109
assert e == expected
108110

109111

110112
def test_nested_scalar_calls():
111113
e = scalar_function(
112-
"test_uri",
113-
"is_positive",
114+
"test_uri:is_positive",
114115
expressions=[
115116
scalar_function(
116-
"test_uri",
117-
"test_func",
117+
"test_uri:test_func",
118118
expressions=[
119119
literal(
120120
10,

tests/builders/extended_expression/test_window_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747

4848
def test_row_number():
49-
e = window_function("test_uri", "row_number", expressions=[], alias="rn")(
49+
e = window_function("test_uri:row_number", expressions=[], alias="rn")(
5050
named_struct, registry
5151
)
5252

tests/builders/plan/test_aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_aggregate():
3838

3939
group_expr = column("id")
4040
measure_expr = aggregate_function(
41-
"test_uri", "count", expressions=[column("is_applicable")], alias=["count"]
41+
"test_uri:count", expressions=[column("is_applicable")], alias=["count"]
4242
)
4343

4444
actual = aggregate(

0 commit comments

Comments
 (0)