55See `examples/builder_example.py` for usage.
66"""
77
8- from typing import Iterable , Optional , Union , Callable
8+ import re
9+ from typing import Callable , Iterable , Optional , Union
910
1011import substrait .gen .proto .algebra_pb2 as stalg
11- from substrait .gen .proto .extensions . extensions_pb2 import AdvancedExtension
12+ import substrait .gen .proto .extended_expression_pb2 as stee
1213import substrait .gen .proto .plan_pb2 as stp
1314import substrait .gen .proto .type_pb2 as stt
14- import substrait .gen .proto .extended_expression_pb2 as stee
15- from substrait .extension_registry import ExtensionRegistry
1615from substrait .builders .extended_expression import (
1716 ExtendedExpressionOrUnbound ,
1817 resolve_expression ,
1918)
19+ from substrait .extension_registry import ExtensionRegistry
20+ from substrait .gen .proto .extensions .extensions_pb2 import AdvancedExtension
2021from substrait .type_inference import infer_plan_schema
2122from substrait .utils import (
2223 merge_extension_declarations ,
23- merge_extension_urns ,
2424 merge_extension_uris ,
25+ merge_extension_urns ,
2526)
27+ from substrait .gen .version import substrait_version
2628
2729UnboundPlan = Callable [[ExtensionRegistry ], stp .Plan ]
2830
2931PlanOrUnbound = Union [stp .Plan , UnboundPlan ]
3032
3133
34+ def _create_default_version ():
35+ p = re .compile (r"(\d+)\.(\d+)\.(\d+)" )
36+ m = p .match (substrait_version )
37+ global default_version
38+ default_version = stp .Version (
39+ major_number = int (m .group (1 )),
40+ minor_number = int (m .group (2 )),
41+ patch_number = int (m .group (3 )),
42+ )
43+
44+
45+ _create_default_version ()
46+
47+
3248def _merge_extensions (* objs ):
3349 """Merge extension URIs, URNs, and declarations from multiple plan/expression objects.
3450
@@ -65,9 +81,10 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
6581 )
6682
6783 return stp .Plan (
84+ version = default_version ,
6885 relations = [
6986 stp .PlanRel (root = stalg .RelRoot (input = rel , names = named_struct .names ))
70- ]
87+ ],
7188 )
7289
7390 return resolve
@@ -169,6 +186,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
169186 )
170187
171188 return stp .Plan (
189+ version = default_version ,
172190 relations = [stp .PlanRel (root = stalg .RelRoot (input = rel , names = names ))],
173191 ** _merge_extensions (_plan , * bound_expressions ),
174192 )
@@ -199,6 +217,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
199217 names = ns .names
200218
201219 return stp .Plan (
220+ version = default_version ,
202221 relations = [stp .PlanRel (root = stalg .RelRoot (input = rel , names = names ))],
203222 ** _merge_extensions (bound_plan , bound_expression ),
204223 )
@@ -245,6 +264,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
245264 )
246265
247266 return stp .Plan (
267+ version = default_version ,
248268 relations = [stp .PlanRel (root = stalg .RelRoot (input = rel , names = ns .names ))],
249269 ** _merge_extensions (bound_plan , * [e [0 ] for e in bound_expressions ]),
250270 )
@@ -262,6 +282,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
262282 )
263283
264284 return stp .Plan (
285+ version = default_version ,
265286 relations = [
266287 stp .PlanRel (
267288 root = stalg .RelRoot (
@@ -300,6 +321,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
300321 )
301322
302323 return stp .Plan (
324+ version = default_version ,
303325 relations = [
304326 stp .PlanRel (
305327 root = stalg .RelRoot (
@@ -348,6 +370,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
348370 )
349371
350372 return stp .Plan (
373+ version = default_version ,
351374 relations = [stp .PlanRel (root = stalg .RelRoot (input = rel , names = ns .names ))],
352375 ** _merge_extensions (bound_left , bound_right , bound_expression ),
353376 )
@@ -383,6 +406,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
383406 )
384407
385408 return stp .Plan (
409+ version = default_version ,
386410 relations = [stp .PlanRel (root = stalg .RelRoot (input = rel , names = ns .names ))],
387411 ** _merge_extensions (bound_left , bound_right ),
388412 )
@@ -434,10 +458,41 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
434458 ] + [e .referred_expr [0 ].output_names [0 ] for e in bound_measures ]
435459
436460 return stp .Plan (
461+ version = default_version ,
437462 relations = [stp .PlanRel (root = stalg .RelRoot (input = rel , names = names ))],
438463 ** _merge_extensions (
439464 bound_input , * bound_grouping_expressions , * bound_measures
440465 ),
441466 )
442467
443468 return resolve
469+
470+
471+ def write_named_table (
472+ table_names : Union [str , Iterable [str ]],
473+ input : PlanOrUnbound ,
474+ create_mode : Union [stalg .WriteRel .CreateMode .ValueType , None ] = None ,
475+ ) -> UnboundPlan :
476+ def resolve (registry : ExtensionRegistry ) -> stp .Plan :
477+ bound_input = input if isinstance (input , stp .Plan ) else input (registry )
478+ ns = infer_plan_schema (bound_input )
479+ _table_names = [table_names ] if isinstance (table_names , str ) else table_names
480+ _create_mode = create_mode or stalg .WriteRel .CREATE_MODE_ERROR_IF_EXISTS
481+
482+ write_rel = stalg .Rel (
483+ write = stalg .WriteRel (
484+ input = bound_input .relations [- 1 ].root .input ,
485+ table_schema = ns ,
486+ op = stalg .WriteRel .WRITE_OP_CTAS ,
487+ create_mode = _create_mode ,
488+ named_table = stalg .NamedObjectWrite (names = _table_names ),
489+ )
490+ )
491+ return stp .Plan (
492+ relations = [
493+ stp .PlanRel (root = stalg .RelRoot (input = write_rel , names = ns .names ))
494+ ],
495+ ** _merge_extensions (bound_input ),
496+ )
497+
498+ return resolve
0 commit comments