Skip to content

Commit 9f9eb42

Browse files
authored
Merge branch 'main' into project-select
2 parents 0b68a6d + 02a65f4 commit 9f9eb42

27 files changed

+946
-224
lines changed

.github/workflows/codegen-check.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ jobs:
2020
uses: devcontainers/ci@v0.3
2121
with:
2222
runCmd: |
23+
# fetch submodule tags since actions/checkout does not
24+
git submodule foreach 'git fetch --unshallow || true'
2325
# Ensure dependencies are installed
2426
uv sync --extra test --extra gen_proto
2527
# Run all code generation steps
26-
make antlr
27-
./gen_proto.sh
28-
make codegen-extensions
28+
make codegen
2929
3030
- name: Check for uncommitted changes
3131
run: |
@@ -36,9 +36,7 @@ jobs:
3636
git diff src/substrait/gen/
3737
echo ""
3838
echo "To fix this, run:"
39-
echo " make antlr"
40-
echo " ./gen_proto.sh"
41-
echo " make codegen-extensions"
39+
echo " make codegen"
4240
echo "Then commit the changes."
4341
exit 1
4442
fi

CONTRIBUTING.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ git submodule update --init --recursive
2222

2323
# Code generation
2424

25+
You can run the full code generation using the following command or use the individual commands to selectively regenerate the generated code. This does not update the Substrait Git submodule.
26+
27+
```
28+
make codegen
29+
```
30+
2531
## Protobuf stubs
2632

2733
Run the upgrade script to upgrade the submodule and regenerate the protobuf stubs.
@@ -31,6 +37,12 @@ uv sync --extra gen_proto
3137
uv run ./update_proto.sh <version>
3238
```
3339

40+
Or run the proto codegen without updating the Substrait Git submodule:
41+
42+
```
43+
make codegen-proto
44+
```
45+
3446
## Antlr grammar
3547

3648
Substrait uses antlr grammar to derive output types of extension functions. Make sure java is installed and ANTLR_JAR environment variable is set. Take a look at .devcontainer/Dockerfile for example setup.

Makefile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
1+
codegen: antlr codegen-proto codegen-extensions codegen-version
2+
3+
14
antlr:
25
cd third_party/substrait/grammar \
36
&& java -jar ${ANTLR_JAR} -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \
47
&& rm ../../../src/substrait/gen/antlr/*.tokens \
58
&& rm ../../../src/substrait/gen/antlr/*.interp
69

10+
codegen-version:
11+
echo -n 'substrait_version = "' > src/substrait/gen/version.py \
12+
&& cd third_party/substrait && git describe --tags | tr -d 'v\n' >> ../../src/substrait/gen/version.py && cd ../.. \
13+
&& echo '"' >> src/substrait/gen/version.py
14+
15+
codegen-proto:
16+
./gen_proto.sh
17+
718
codegen-extensions:
819
uv run --with datamodel-code-generator datamodel-codegen \
920
--input-file-type jsonschema \
1021
--input third_party/substrait/text/simple_extensions_schema.yaml \
1122
--output src/substrait/gen/json/simple_extensions.py \
1223
--output-model-type dataclasses.dataclass \
24+
--target-python-version 3.10 \
1325
--disable-timestamp
1426

1527
lint:

src/substrait/builders/extended_expression.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from datetime import date
21
import itertools
2+
from datetime import date
3+
from typing import Any, Callable, Iterable, Union
4+
35
import substrait.gen.proto.algebra_pb2 as stalg
4-
import substrait.gen.proto.type_pb2 as stp
56
import substrait.gen.proto.extended_expression_pb2 as stee
67
import substrait.gen.proto.extensions.extensions_pb2 as ste
8+
import substrait.gen.proto.type_pb2 as stp
79
from substrait.extension_registry import ExtensionRegistry
10+
from substrait.type_inference import infer_extended_expression_schema
811
from substrait.utils import (
9-
type_num_names,
10-
merge_extension_urns,
11-
merge_extension_uris,
1212
merge_extension_declarations,
13+
merge_extension_uris,
14+
merge_extension_urns,
15+
type_num_names,
1316
)
14-
from substrait.type_inference import infer_extended_expression_schema
15-
from typing import Callable, Any, Union, Iterable
1617

1718
UnboundExtendedExpression = Callable[
1819
[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression
@@ -21,7 +22,7 @@
2122

2223

2324
def _alias_or_inferred(
24-
alias: Union[Iterable[str], str],
25+
alias: Union[Iterable[str], str, None],
2526
op: str,
2627
args: Iterable[str],
2728
):
@@ -44,7 +45,7 @@ def resolve_expression(
4445

4546

4647
def literal(
47-
value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None
48+
value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None
4849
) -> UnboundExtendedExpression:
4950
"""Builds a resolver for ExtendedExpression containing a literal expression"""
5051

@@ -154,7 +155,7 @@ def resolve(
154155
return resolve
155156

156157

157-
def column(field: Union[str, int], alias: Union[Iterable[str], str] = None):
158+
def column(field: Union[str, int], alias: Union[Iterable[str], str, None] = None):
158159
"""Builds a resolver for ExtendedExpression containing a FieldReference expression
159160
160161
Accepts either an index or a field name of a desired field.
@@ -208,7 +209,7 @@ def scalar_function(
208209
urn: str,
209210
function: str,
210211
expressions: Iterable[ExtendedExpressionOrUnbound],
211-
alias: Union[Iterable[str], str] = None,
212+
alias: Union[Iterable[str], str, None] = None,
212213
):
213214
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
214215

@@ -306,7 +307,7 @@ def aggregate_function(
306307
urn: str,
307308
function: str,
308309
expressions: Iterable[ExtendedExpressionOrUnbound],
309-
alias: Union[Iterable[str], str] = None,
310+
alias: Union[Iterable[str], str, None] = None,
310311
):
311312
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
312313

@@ -402,7 +403,7 @@ def window_function(
402403
function: str,
403404
expressions: Iterable[ExtendedExpressionOrUnbound],
404405
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
405-
alias: Union[Iterable[str], str] = None,
406+
alias: Union[Iterable[str], str, None] = None,
406407
):
407408
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
408409

@@ -512,7 +513,7 @@ def resolve(
512513
def if_then(
513514
ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]],
514515
_else: ExtendedExpressionOrUnbound,
515-
alias: Union[Iterable[str], str] = None,
516+
alias: Union[Iterable[str], str, None] = None,
516517
):
517518
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
518519

@@ -767,7 +768,11 @@ def resolve(
767768
return resolve
768769

769770

770-
def cast(input: ExtendedExpressionOrUnbound, type: stp.Type):
771+
def cast(
772+
input: ExtendedExpressionOrUnbound,
773+
type: stp.Type,
774+
alias: Union[Iterable[str], str, None] = None,
775+
):
771776
"""Builds a resolver for ExtendedExpression containing a cast expression"""
772777

773778
def resolve(
@@ -785,7 +790,9 @@ def resolve(
785790
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
786791
)
787792
),
788-
output_names=["cast"], # TODO construct name from inputs
793+
output_names=_alias_or_inferred(
794+
alias, "cast", [bound_input.referred_expr[0].output_names[0]]
795+
),
789796
)
790797
],
791798
base_schema=base_schema,

src/substrait/builders/plan.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,46 @@
55
See `examples/builder_example.py` for usage.
66
"""
77

8-
from typing import Iterable, Optional, Union, Callable
8+
import re
9+
from typing import Callable, Iterable, Optional, Union
910

1011
import substrait.gen.proto.algebra_pb2 as stalg
11-
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
12+
import substrait.gen.proto.extended_expression_pb2 as stee
1213
import substrait.gen.proto.plan_pb2 as stp
1314
import substrait.gen.proto.type_pb2 as stt
14-
import substrait.gen.proto.extended_expression_pb2 as stee
15-
from substrait.extension_registry import ExtensionRegistry
1615
from substrait.builders.extended_expression import (
1716
ExtendedExpressionOrUnbound,
1817
resolve_expression,
1918
)
19+
from substrait.extension_registry import ExtensionRegistry
20+
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
2021
from substrait.type_inference import infer_plan_schema
2122
from substrait.utils import (
2223
merge_extension_declarations,
23-
merge_extension_urns,
2424
merge_extension_uris,
25+
merge_extension_urns,
2526
)
27+
from substrait.gen.version import substrait_version
2628

2729
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
2830

2931
PlanOrUnbound = Union[stp.Plan, UnboundPlan]
3032

3133

34+
def _create_default_version():
35+
p = re.compile(r"(\d+)\.(\d+)\.(\d+)")
36+
m = p.match(substrait_version)
37+
global default_version
38+
default_version = stp.Version(
39+
major_number=int(m.group(1)),
40+
minor_number=int(m.group(2)),
41+
patch_number=int(m.group(3)),
42+
)
43+
44+
45+
_create_default_version()
46+
47+
3248
def _merge_extensions(*objs):
3349
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.
3450
@@ -65,9 +81,10 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
6581
)
6682

6783
return stp.Plan(
84+
version=default_version,
6885
relations=[
6986
stp.PlanRel(root=stalg.RelRoot(input=rel, names=named_struct.names))
70-
]
87+
],
7188
)
7289

7390
return resolve
@@ -169,6 +186,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
169186
)
170187

171188
return stp.Plan(
189+
version=default_version,
172190
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
173191
**_merge_extensions(_plan, *bound_expressions),
174192
)
@@ -199,6 +217,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
199217
names = ns.names
200218

201219
return stp.Plan(
220+
version=default_version,
202221
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
203222
**_merge_extensions(bound_plan, bound_expression),
204223
)
@@ -245,6 +264,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
245264
)
246265

247266
return stp.Plan(
267+
version=default_version,
248268
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
249269
**_merge_extensions(bound_plan, *[e[0] for e in bound_expressions]),
250270
)
@@ -262,6 +282,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
262282
)
263283

264284
return stp.Plan(
285+
version=default_version,
265286
relations=[
266287
stp.PlanRel(
267288
root=stalg.RelRoot(
@@ -300,6 +321,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
300321
)
301322

302323
return stp.Plan(
324+
version=default_version,
303325
relations=[
304326
stp.PlanRel(
305327
root=stalg.RelRoot(
@@ -348,6 +370,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
348370
)
349371

350372
return stp.Plan(
373+
version=default_version,
351374
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
352375
**_merge_extensions(bound_left, bound_right, bound_expression),
353376
)
@@ -383,6 +406,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
383406
)
384407

385408
return stp.Plan(
409+
version=default_version,
386410
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
387411
**_merge_extensions(bound_left, bound_right),
388412
)
@@ -434,10 +458,41 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
434458
] + [e.referred_expr[0].output_names[0] for e in bound_measures]
435459

436460
return stp.Plan(
461+
version=default_version,
437462
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
438463
**_merge_extensions(
439464
bound_input, *bound_grouping_expressions, *bound_measures
440465
),
441466
)
442467

443468
return resolve
469+
470+
471+
def write_named_table(
472+
table_names: Union[str, Iterable[str]],
473+
input: PlanOrUnbound,
474+
create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
475+
) -> UnboundPlan:
476+
def resolve(registry: ExtensionRegistry) -> stp.Plan:
477+
bound_input = input if isinstance(input, stp.Plan) else input(registry)
478+
ns = infer_plan_schema(bound_input)
479+
_table_names = [table_names] if isinstance(table_names, str) else table_names
480+
_create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS
481+
482+
write_rel = stalg.Rel(
483+
write=stalg.WriteRel(
484+
input=bound_input.relations[-1].root.input,
485+
table_schema=ns,
486+
op=stalg.WriteRel.WRITE_OP_CTAS,
487+
create_mode=_create_mode,
488+
named_table=stalg.NamedObjectWrite(names=_table_names),
489+
)
490+
)
491+
return stp.Plan(
492+
relations=[
493+
stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
494+
],
495+
**_merge_extensions(bound_input),
496+
)
497+
498+
return resolve

src/substrait/builders/type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Iterable
2+
23
import substrait.gen.proto.type_pb2 as stt
34

45

0 commit comments

Comments
 (0)