Skip to content

Commit 1e44dd8

Browse files
authored
feat: extended expression builders (#71)
Adds extended expression builder functions: `column`, `literal` and `scalar_function`
1 parent 912fc09 commit 1e44dd8

File tree

7 files changed

+571
-1
lines changed

7 files changed

+571
-1
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import itertools
2+
import substrait.gen.proto.algebra_pb2 as stalg
3+
import substrait.gen.proto.type_pb2 as stp
4+
import substrait.gen.proto.extended_expression_pb2 as stee
5+
import substrait.gen.proto.extensions.extensions_pb2 as ste
6+
from substrait.function_registry import FunctionRegistry
7+
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations
8+
from substrait.type_inference import infer_extended_expression_schema
9+
from typing import Callable, Any, Union
10+
11+
UnboundExpression = Callable[[stp.NamedStruct, FunctionRegistry], stee.ExtendedExpression]
12+
13+
def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExpression:
14+
"""Builds a resolver for ExtendedExpression containing a literal expression"""
15+
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
16+
kind = type.WhichOneof('kind')
17+
18+
if kind == "bool":
19+
literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE)
20+
elif kind == "i8":
21+
literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE)
22+
elif kind == "i16":
23+
literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE)
24+
elif kind == "i32":
25+
literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE)
26+
elif kind == "i64":
27+
literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE)
28+
elif kind == "fp32":
29+
literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE)
30+
elif kind == "fp64":
31+
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE)
32+
elif kind == "string":
33+
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE)
34+
else:
35+
raise Exception(f"Unknown literal type - {type}")
36+
37+
return stee.ExtendedExpression(
38+
referred_expr=[
39+
stee.ExpressionReference(
40+
expression=stalg.Expression(
41+
literal=literal
42+
),
43+
output_names=[alias if alias else f'literal_{kind}'],
44+
)
45+
],
46+
base_schema=base_schema,
47+
)
48+
49+
return resolve
50+
51+
def column(field: Union[str, int]):
52+
"""Builds a resolver for ExtendedExpression containing a FieldReference expression
53+
54+
Accepts either an index or a field name of a desired field.
55+
"""
56+
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
57+
if isinstance(field, str):
58+
column_index = list(base_schema.names).index(field)
59+
lengths = [type_num_names(t) for t in base_schema.struct.types]
60+
flat_indices = [0] + list(itertools.accumulate(lengths))[:-1]
61+
field_index = flat_indices.index(column_index)
62+
else:
63+
field_index = field
64+
65+
names_start = flat_indices[field_index]
66+
names_end = (
67+
flat_indices[field_index + 1]
68+
if len(flat_indices) > field_index + 1
69+
else None
70+
)
71+
72+
return stee.ExtendedExpression(
73+
referred_expr=[
74+
stee.ExpressionReference(
75+
expression=stalg.Expression(
76+
selection=stalg.Expression.FieldReference(
77+
root_reference=stalg.Expression.FieldReference.RootReference(),
78+
direct_reference=stalg.Expression.ReferenceSegment(
79+
struct_field=stalg.Expression.ReferenceSegment.StructField(
80+
field=field_index
81+
)
82+
),
83+
)
84+
),
85+
output_names=list(base_schema.names)[names_start:names_end],
86+
)
87+
],
88+
base_schema=base_schema,
89+
)
90+
91+
return resolve
92+
93+
def scalar_function(uri: str, function: str, *expressions: UnboundExpression, alias: str = None):
94+
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
95+
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
96+
bound_expressions: list[stee.ExtendedExpression] = [e(base_schema, registry) for e in expressions]
97+
98+
expression_schemas = [infer_extended_expression_schema(b) for b in bound_expressions]
99+
100+
signature = [typ for es in expression_schemas for typ in es.types]
101+
102+
func = registry.lookup_function(uri, function, signature)
103+
104+
if not func:
105+
raise Exception('')
106+
107+
func_extension_uris = [
108+
ste.SimpleExtensionURI(
109+
extension_uri_anchor=registry.lookup_uri(uri),
110+
uri=uri
111+
)
112+
]
113+
114+
func_extensions = [
115+
ste.SimpleExtensionDeclaration(
116+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
117+
extension_uri_reference=registry.lookup_uri(uri),
118+
function_anchor=func[0].anchor,
119+
name=function
120+
)
121+
)
122+
]
123+
124+
extension_uris = merge_extension_uris(
125+
func_extension_uris,
126+
*[b.extension_uris for b in bound_expressions]
127+
)
128+
129+
extensions = merge_extension_declarations(
130+
func_extensions,
131+
*[b.extensions for b in bound_expressions]
132+
)
133+
134+
return stee.ExtendedExpression(
135+
referred_expr=[
136+
stee.ExpressionReference(
137+
expression=stalg.Expression(
138+
scalar_function=stalg.Expression.ScalarFunction(
139+
function_reference=func[0].anchor,
140+
arguments=[
141+
stalg.FunctionArgument(
142+
value=e.referred_expr[0].expression
143+
) for e in bound_expressions
144+
],
145+
output_type=func[1]
146+
)
147+
),
148+
output_names=[alias if alias else 'scalar_function'],
149+
)
150+
],
151+
base_schema=base_schema,
152+
extension_uris=extension_uris,
153+
extensions=extensions
154+
)
155+
156+
return resolve

src/substrait/function_registry.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType
21
from substrait.gen.proto.type_pb2 import Type
32
from importlib.resources import files as importlib_files
43
import itertools
@@ -227,6 +226,9 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
227226

228227
class FunctionRegistry:
229228
def __init__(self, load_default_extensions=True) -> None:
229+
self._uri_mapping: dict = defaultdict(dict)
230+
self._uri_id_generator = itertools.count(1)
231+
230232
self._function_mapping: dict = defaultdict(dict)
231233
self._id_generator = itertools.count(1)
232234

@@ -252,6 +254,8 @@ def register_extension_yaml(
252254
self.register_extension_dict(extension_definitions, uri)
253255

254256
def register_extension_dict(self, definitions: dict, uri: str) -> None:
257+
self._uri_mapping[uri] = next(self._uri_id_generator)
258+
255259
for named_functions in definitions.values():
256260
for function in named_functions:
257261
for impl in function.get("impls", []):
@@ -285,3 +289,7 @@ def lookup_function(
285289
return (f, rtn)
286290

287291
return None
292+
293+
def lookup_uri(self, uri: str) -> Optional[int]:
294+
uri = self._uri_aliases.get(uri, uri)
295+
return self._uri_mapping.get(uri, None)

src/substrait/type_inference.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.extended_expression_pb2 as stee
23
import substrait.gen.proto.type_pb2 as stt
34

45

@@ -220,6 +221,17 @@ def infer_expression_type(
220221
raise Exception(f"Unknown rex_type {rex_type}")
221222

222223

224+
def infer_extended_expression_schema(ee: stee.ExtendedExpression) -> stt.Type.Struct:
225+
exprs = [e for e in ee.referred_expr]
226+
227+
types = [infer_expression_type(e.expression, ee.base_schema.struct) for e in exprs]
228+
229+
return stt.Type.Struct(
230+
types=types,
231+
nullability=stt.Type.NULLABILITY_REQUIRED,
232+
)
233+
234+
223235
def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
224236
rel_type = rel.WhichOneof("rel_type")
225237

src/substrait/utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import substrait.gen.proto.type_pb2 as stp
2+
import substrait.gen.proto.extensions.extensions_pb2 as ste
3+
from typing import Iterable
4+
5+
def type_num_names(typ: stp.Type):
6+
kind = typ.WhichOneof("kind")
7+
if kind == "struct":
8+
lengths = [type_num_names(t) for t in typ.struct.types]
9+
return sum(lengths) + 1
10+
elif kind == "list":
11+
return type_num_names(typ.list.type)
12+
elif kind == "map":
13+
return type_num_names(typ.map.key) + type_num_names(typ.map.value)
14+
else:
15+
return 1
16+
17+
def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]):
18+
"""Merges multiple sets of SimpleExtensionURI objects into a single set.
19+
The order of extensions is kept intact, while duplicates are discarded.
20+
Assumes that there are no collisions (different extensions having identical anchors).
21+
"""
22+
seen_uris = set()
23+
ret = []
24+
25+
for uris in extension_uris:
26+
for uri in uris:
27+
if uri.uri not in seen_uris:
28+
seen_uris.add(uri.uri)
29+
ret.append(uri)
30+
31+
return ret
32+
33+
def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExtensionDeclaration]):
34+
"""Merges multiple sets of SimpleExtensionDeclaration objects into a single set.
35+
The order of extension declarations is kept intact, while duplicates are discarded.
36+
Assumes that there are no collisions (different extension declarations having identical anchors).
37+
"""
38+
39+
seen_extension_functions = set()
40+
ret = []
41+
42+
for declarations in extension_declarations:
43+
for declaration in declarations:
44+
if declaration.WhichOneof('mapping_type') == 'extension_function':
45+
ident = (declaration.extension_function.extension_uri_reference, declaration.extension_function.name)
46+
if ident not in seen_extension_functions:
47+
seen_extension_functions.add(ident)
48+
ret.append(declaration)
49+
else:
50+
raise Exception('') #TODO handle extension types
51+
52+
return ret
53+
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.type_pb2 as stt
3+
import substrait.gen.proto.extended_expression_pb2 as stee
4+
from substrait.extended_expression import column
5+
6+
7+
struct = stt.Type.Struct(
8+
types=[
9+
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
10+
stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
11+
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
12+
]
13+
)
14+
15+
named_struct = stt.NamedStruct(
16+
names=["order_id", "description", "order_total"], struct=struct
17+
)
18+
19+
nested_struct = stt.Type.Struct(
20+
types=[
21+
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
22+
stt.Type(
23+
struct=stt.Type.Struct(
24+
types=[
25+
stt.Type(
26+
i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)
27+
),
28+
stt.Type(
29+
fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)
30+
),
31+
],
32+
nullability=stt.Type.NULLABILITY_NULLABLE,
33+
)
34+
),
35+
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
36+
]
37+
)
38+
39+
nested_named_struct = stt.NamedStruct(
40+
names=["order_id", "shop_details", "shop_id", "shop_total", "order_total"],
41+
struct=nested_struct,
42+
)
43+
44+
45+
def test_column_no_nesting():
46+
assert column("description")(named_struct, None) == stee.ExtendedExpression(
47+
referred_expr=[
48+
stee.ExpressionReference(
49+
expression=stalg.Expression(
50+
selection=stalg.Expression.FieldReference(
51+
root_reference=stalg.Expression.FieldReference.RootReference(),
52+
direct_reference=stalg.Expression.ReferenceSegment(
53+
struct_field=stalg.Expression.ReferenceSegment.StructField(
54+
field=1
55+
)
56+
),
57+
)
58+
),
59+
output_names=["description"],
60+
)
61+
],
62+
base_schema=named_struct,
63+
)
64+
65+
66+
def test_column_nesting():
67+
assert column("order_total")(nested_named_struct, None) == stee.ExtendedExpression(
68+
referred_expr=[
69+
stee.ExpressionReference(
70+
expression=stalg.Expression(
71+
selection=stalg.Expression.FieldReference(
72+
root_reference=stalg.Expression.FieldReference.RootReference(),
73+
direct_reference=stalg.Expression.ReferenceSegment(
74+
struct_field=stalg.Expression.ReferenceSegment.StructField(
75+
field=2
76+
)
77+
),
78+
)
79+
),
80+
output_names=["order_total"],
81+
)
82+
],
83+
base_schema=nested_named_struct,
84+
)
85+
86+
87+
def test_column_nested_struct():
88+
assert column("shop_details")(nested_named_struct, None) == stee.ExtendedExpression(
89+
referred_expr=[
90+
stee.ExpressionReference(
91+
expression=stalg.Expression(
92+
selection=stalg.Expression.FieldReference(
93+
root_reference=stalg.Expression.FieldReference.RootReference(),
94+
direct_reference=stalg.Expression.ReferenceSegment(
95+
struct_field=stalg.Expression.ReferenceSegment.StructField(
96+
field=1
97+
)
98+
),
99+
)
100+
),
101+
output_names=["shop_details", "shop_id", "shop_total"],
102+
)
103+
],
104+
base_schema=nested_named_struct,
105+
)

0 commit comments

Comments
 (0)