Skip to content

Commit 1fce1dd

Browse files
authored
Merge branch 'main' into feat/build-unify-setup
2 parents 3b633ef + 9bb6bb9 commit 1fce1dd

22 files changed

+376
-137
lines changed

.github/workflows/codegen-check.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ jobs:
2020
uses: devcontainers/ci@v0.3
2121
with:
2222
runCmd: |
23+
# fetch submodule tags since actions/checkout does not
24+
git submodule foreach 'git fetch --unshallow || true'
2325
# Ensure dependencies are installed
2426
uv sync --extra test --extra gen_proto
2527
# Run all code generation steps
26-
make antlr
27-
./gen_proto.sh
28-
make codegen-extensions
28+
make codegen
2929
3030
- name: Check for uncommitted changes
3131
run: |
@@ -36,9 +36,7 @@ jobs:
3636
git diff src/substrait/gen/
3737
echo ""
3838
echo "To fix this, run:"
39-
echo " make antlr"
40-
echo " ./gen_proto.sh"
41-
echo " make codegen-extensions"
39+
echo " make codegen"
4240
echo "Then commit the changes."
4341
exit 1
4442
fi

CONTRIBUTING.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ git submodule update --init --recursive
2222

2323
# Code generation
2424

25+
You can run the full code generation using the following command or use the individual commands to selectively regenerate the generated code. This does not update the Substrait Git submodule.
26+
27+
```
28+
make codegen
29+
```
30+
2531
## Protobuf stubs
2632

2733
Run the upgrade script to upgrade the submodule and regenerate the protobuf stubs.
@@ -31,6 +37,12 @@ uv sync --extra gen_proto
3137
uv run scripts/update_proto.sh <version>
3238
```
3339

40+
Or run the proto codegen without updating the Substrait Git submodule:
41+
42+
```
43+
make codegen-proto
44+
```
45+
3446
## Antlr grammar
3547

3648
Substrait uses an ANTLR grammar to derive the output types of extension functions. Make sure Java is installed, then set up ANTLR by running `make setup-antlr`.

Makefile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
setup-antlr:
22
@bash scripts/setup_antlr.sh > /dev/null
33

4+
5+
codegen: antlr codegen-proto codegen-extensions codegen-version
6+
7+
48
antlr: setup-antlr
59
cd third_party/substrait/grammar \
610
&& java -jar ../../../lib/antlr-complete.jar -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \
711
&& rm ../../../src/substrait/gen/antlr/*.tokens \
812
&& rm ../../../src/substrait/gen/antlr/*.interp
913

14+
# Write src/substrait/gen/version.py from the Substrait submodule's latest tag.
# `printf` is used instead of `echo -n`, whose handling of `-n` is not portable
# across the /bin/sh implementations that make may invoke.
# `tr -d 'v\n'` strips the tag's `v` characters and the trailing newline.
codegen-version:
	printf 'substrait_version = "%s"\n' \
		"$$(git -C third_party/substrait describe --tags | tr -d 'v\n')" \
		> src/substrait/gen/version.py
18+
19+
codegen-proto:
20+
./gen_proto.sh
21+
1022
codegen-extensions:
1123
uv run --with datamodel-code-generator datamodel-codegen \
1224
--input-file-type jsonschema \

src/substrait/builders/extended_expression.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from datetime import date
21
import itertools
2+
from datetime import date
3+
from typing import Any, Callable, Iterable, Union
4+
35
import substrait.gen.proto.algebra_pb2 as stalg
4-
import substrait.gen.proto.type_pb2 as stp
56
import substrait.gen.proto.extended_expression_pb2 as stee
67
import substrait.gen.proto.extensions.extensions_pb2 as ste
8+
import substrait.gen.proto.type_pb2 as stp
79
from substrait.extension_registry import ExtensionRegistry
10+
from substrait.type_inference import infer_extended_expression_schema
811
from substrait.utils import (
9-
type_num_names,
10-
merge_extension_urns,
11-
merge_extension_uris,
1212
merge_extension_declarations,
13+
merge_extension_uris,
14+
merge_extension_urns,
15+
type_num_names,
1316
)
14-
from substrait.type_inference import infer_extended_expression_schema
15-
from typing import Callable, Any, Union, Iterable
1617

1718
UnboundExtendedExpression = Callable[
1819
[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression
@@ -21,7 +22,7 @@
2122

2223

2324
def _alias_or_inferred(
24-
alias: Union[Iterable[str], str],
25+
alias: Union[Iterable[str], str, None],
2526
op: str,
2627
args: Iterable[str],
2728
):
@@ -44,7 +45,7 @@ def resolve_expression(
4445

4546

4647
def literal(
47-
value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None
48+
value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None
4849
) -> UnboundExtendedExpression:
4950
"""Builds a resolver for ExtendedExpression containing a literal expression"""
5051

@@ -154,7 +155,7 @@ def resolve(
154155
return resolve
155156

156157

157-
def column(field: Union[str, int], alias: Union[Iterable[str], str] = None):
158+
def column(field: Union[str, int], alias: Union[Iterable[str], str, None] = None):
158159
"""Builds a resolver for ExtendedExpression containing a FieldReference expression
159160
160161
Accepts either an index or a field name of a desired field.
@@ -208,7 +209,7 @@ def scalar_function(
208209
urn: str,
209210
function: str,
210211
expressions: Iterable[ExtendedExpressionOrUnbound],
211-
alias: Union[Iterable[str], str] = None,
212+
alias: Union[Iterable[str], str, None] = None,
212213
):
213214
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
214215

@@ -306,7 +307,7 @@ def aggregate_function(
306307
urn: str,
307308
function: str,
308309
expressions: Iterable[ExtendedExpressionOrUnbound],
309-
alias: Union[Iterable[str], str] = None,
310+
alias: Union[Iterable[str], str, None] = None,
310311
):
311312
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
312313

@@ -402,7 +403,7 @@ def window_function(
402403
function: str,
403404
expressions: Iterable[ExtendedExpressionOrUnbound],
404405
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
405-
alias: Union[Iterable[str], str] = None,
406+
alias: Union[Iterable[str], str, None] = None,
406407
):
407408
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
408409

@@ -512,7 +513,7 @@ def resolve(
512513
def if_then(
513514
ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]],
514515
_else: ExtendedExpressionOrUnbound,
515-
alias: Union[Iterable[str], str] = None,
516+
alias: Union[Iterable[str], str, None] = None,
516517
):
517518
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
518519

@@ -767,7 +768,11 @@ def resolve(
767768
return resolve
768769

769770

770-
def cast(input: ExtendedExpressionOrUnbound, type: stp.Type):
771+
def cast(
772+
input: ExtendedExpressionOrUnbound,
773+
type: stp.Type,
774+
alias: Union[Iterable[str], str, None] = None,
775+
):
771776
"""Builds a resolver for ExtendedExpression containing a cast expression"""
772777

773778
def resolve(
@@ -785,7 +790,9 @@ def resolve(
785790
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
786791
)
787792
),
788-
output_names=["cast"], # TODO construct name from inputs
793+
output_names=_alias_or_inferred(
794+
alias, "cast", [bound_input.referred_expr[0].output_names[0]]
795+
),
789796
)
790797
],
791798
base_schema=base_schema,

src/substrait/builders/plan.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,46 @@
55
See `examples/builder_example.py` for usage.
66
"""
77

8-
from typing import Iterable, Optional, Union, Callable
8+
import re
9+
from typing import Callable, Iterable, Optional, Union
910

1011
import substrait.gen.proto.algebra_pb2 as stalg
11-
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
12+
import substrait.gen.proto.extended_expression_pb2 as stee
1213
import substrait.gen.proto.plan_pb2 as stp
1314
import substrait.gen.proto.type_pb2 as stt
14-
import substrait.gen.proto.extended_expression_pb2 as stee
15-
from substrait.extension_registry import ExtensionRegistry
1615
from substrait.builders.extended_expression import (
1716
ExtendedExpressionOrUnbound,
1817
resolve_expression,
1918
)
19+
from substrait.extension_registry import ExtensionRegistry
20+
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
2021
from substrait.type_inference import infer_plan_schema
2122
from substrait.utils import (
2223
merge_extension_declarations,
23-
merge_extension_urns,
2424
merge_extension_uris,
25+
merge_extension_urns,
2526
)
27+
from substrait.gen.version import substrait_version
2628

2729
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
2830

2931
PlanOrUnbound = Union[stp.Plan, UnboundPlan]
3032

3133

34+
def _create_default_version():
    """Parse ``substrait_version`` into the module-level ``default_version``.

    The generated version string is expected to start with
    ``<major>.<minor>.<patch>``; anything after the patch number is ignored
    because the pattern is only anchored at the start of the string.

    Returns:
        The ``stp.Version`` message that was stored in ``default_version``.

    Raises:
        ValueError: if ``substrait_version`` does not start with three
            dot-separated integers (e.g. the generated file is empty or stale).
    """
    global default_version

    match = re.match(r"(\d+)\.(\d+)\.(\d+)", substrait_version)
    if match is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise ValueError(
            f"Cannot parse substrait version {substrait_version!r}; "
            "expected '<major>.<minor>.<patch>'"
        )

    major, minor, patch = (int(part) for part in match.groups())
    default_version = stp.Version(
        major_number=major,
        minor_number=minor,
        patch_number=patch,
    )
    return default_version
43+
44+
45+
_create_default_version()
46+
47+
3248
def _merge_extensions(*objs):
3349
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.
3450
@@ -65,9 +81,10 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
6581
)
6682

6783
return stp.Plan(
84+
version=default_version,
6885
relations=[
6986
stp.PlanRel(root=stalg.RelRoot(input=rel, names=named_struct.names))
70-
]
87+
],
7188
)
7289

7390
return resolve
@@ -107,6 +124,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
107124
)
108125

109126
return stp.Plan(
127+
version=default_version,
110128
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
111129
**_merge_extensions(_plan, *bound_expressions),
112130
)
@@ -137,6 +155,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
137155
names = ns.names
138156

139157
return stp.Plan(
158+
version=default_version,
140159
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
141160
**_merge_extensions(bound_plan, bound_expression),
142161
)
@@ -183,6 +202,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
183202
)
184203

185204
return stp.Plan(
205+
version=default_version,
186206
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
187207
**_merge_extensions(bound_plan, *[e[0] for e in bound_expressions]),
188208
)
@@ -200,6 +220,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
200220
)
201221

202222
return stp.Plan(
223+
version=default_version,
203224
relations=[
204225
stp.PlanRel(
205226
root=stalg.RelRoot(
@@ -238,6 +259,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
238259
)
239260

240261
return stp.Plan(
262+
version=default_version,
241263
relations=[
242264
stp.PlanRel(
243265
root=stalg.RelRoot(
@@ -286,6 +308,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
286308
)
287309

288310
return stp.Plan(
311+
version=default_version,
289312
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
290313
**_merge_extensions(bound_left, bound_right, bound_expression),
291314
)
@@ -321,6 +344,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
321344
)
322345

323346
return stp.Plan(
347+
version=default_version,
324348
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
325349
**_merge_extensions(bound_left, bound_right),
326350
)
@@ -372,10 +396,41 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
372396
] + [e.referred_expr[0].output_names[0] for e in bound_measures]
373397

374398
return stp.Plan(
399+
version=default_version,
375400
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
376401
**_merge_extensions(
377402
bound_input, *bound_grouping_expressions, *bound_measures
378403
),
379404
)
380405

381406
return resolve
407+
408+
409+
def write_named_table(
    table_names: Union[str, Iterable[str]],
    input: PlanOrUnbound,
    create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
) -> UnboundPlan:
    """Builds a resolver for a Plan that writes ``input`` into a named table.

    The resulting plan wraps the root input of the last relation of ``input``
    in a ``WriteRel`` with ``WRITE_OP_CTAS`` and the inferred schema of
    ``input`` as the table schema.

    Args:
        table_names: A single name, or the name parts, of the target table.
        input: The plan (bound or unbound) whose output is written.
        create_mode: Create mode for the write. Any falsy value (``None`` or
            the enum value ``0``) falls back to ``CREATE_MODE_ERROR_IF_EXISTS``.

    Returns:
        A callable that takes an ``ExtensionRegistry`` and returns the Plan.
    """

    def resolve(registry: ExtensionRegistry) -> stp.Plan:
        bound_input = input if isinstance(input, stp.Plan) else input(registry)
        ns = infer_plan_schema(bound_input)
        # A bare string is treated as a single-part table name.
        _table_names = [table_names] if isinstance(table_names, str) else table_names
        _create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS

        write_rel = stalg.Rel(
            write=stalg.WriteRel(
                input=bound_input.relations[-1].root.input,
                table_schema=ns,
                op=stalg.WriteRel.WRITE_OP_CTAS,
                create_mode=_create_mode,
                named_table=stalg.NamedObjectWrite(names=_table_names),
            )
        )
        return stp.Plan(
            # Set the version like every other builder in this module does.
            version=default_version,
            relations=[
                stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
            ],
            **_merge_extensions(bound_input),
        )

    return resolve

src/substrait/gen/__init__.pyi

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/substrait/gen/version.py

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)