1010from typing import Callable , Any , Union , Iterable
1111
1212UnboundExtendedExpression = Callable [[stp .NamedStruct , ExtensionRegistry ], stee .ExtendedExpression ]
13+ ExtendedExpressionOrUnbound = Union [stee .ExtendedExpression , UnboundExtendedExpression ]
1314
1415def _alias_or_inferred (
1516 alias : Union [Iterable [str ], str ],
@@ -21,6 +22,13 @@ def _alias_or_inferred(
2122 else :
2223 return [f'{ op } ({ "," .join (args )} )' ]
2324
25+ def resolve_expression (
26+ expression : ExtendedExpressionOrUnbound ,
27+ base_schema : stp .NamedStruct ,
28+ registry : ExtensionRegistry
29+ ) -> stee .ExtendedExpression :
30+ return expression if isinstance (expression , stee .ExtendedExpression ) else expression (base_schema , registry )
31+
2432def literal (value : Any , type : stp .Type , alias : Union [Iterable [str ], str ] = None ) -> UnboundExtendedExpression :
2533 """Builds a resolver for ExtendedExpression containing a literal expression"""
2634 def resolve (base_schema : stp .NamedStruct , registry : ExtensionRegistry ) -> stee .ExtendedExpression :
@@ -139,14 +147,14 @@ def resolve(
139147 return resolve
140148
141149def scalar_function (
142- uri : str , function : str , * expressions : UnboundExtendedExpression , alias : Union [Iterable [str ], str ] = None
150+ uri : str , function : str , expressions : Iterable [ ExtendedExpressionOrUnbound ] , alias : Union [Iterable [str ], str ] = None
143151):
144152 """Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
145153 def resolve (
146154 base_schema : stp .NamedStruct , registry : ExtensionRegistry
147155 ) -> stee .ExtendedExpression :
148- bound_expressions : Iterable [ stee . ExtendedExpression ] = [
149- e ( base_schema , registry ) for e in expressions
156+ bound_expressions = [
157+ resolve_expression ( e , base_schema , registry ) for e in expressions
150158 ]
151159
152160 expression_schemas = [
@@ -210,14 +218,14 @@ def resolve(
210218 return resolve
211219
212220def aggregate_function (
213- uri : str , function : str , * expressions : UnboundExtendedExpression , alias : Union [Iterable [str ], str ] = None
221+ uri : str , function : str , expressions : Iterable [ ExtendedExpressionOrUnbound ] , alias : Union [Iterable [str ], str ] = None
214222):
215223 """Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
216224 def resolve (
217225 base_schema : stp .NamedStruct , registry : ExtensionRegistry
218226 ) -> stee .ExtendedExpression :
219227 bound_expressions : Iterable [stee .ExtendedExpression ] = [
220- e ( base_schema , registry ) for e in expressions
228+ resolve_expression ( e , base_schema , registry ) for e in expressions
221229 ]
222230
223231 expression_schemas = [
@@ -281,19 +289,19 @@ def resolve(
281289def window_function (
282290 uri : str ,
283291 function : str ,
284- * expressions : UnboundExtendedExpression ,
285- partitions : Iterable [UnboundExtendedExpression ] = [],
292+ expressions : Iterable [ ExtendedExpressionOrUnbound ] ,
293+ partitions : Iterable [ExtendedExpressionOrUnbound ] = [],
286294 alias : Union [Iterable [str ], str ] = None
287295):
288296 """Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
289297 def resolve (
290298 base_schema : stp .NamedStruct , registry : ExtensionRegistry
291299 ) -> stee .ExtendedExpression :
292300 bound_expressions : Iterable [stee .ExtendedExpression ] = [
293- e ( base_schema , registry ) for e in expressions
301+ resolve_expression ( e , base_schema , registry ) for e in expressions
294302 ]
295303
296- bound_partitions = [e ( base_schema , registry ) for e in partitions ]
304+ bound_partitions = [resolve_expression ( e , base_schema , registry ) for e in partitions ]
297305
298306 expression_schemas = [
299307 infer_extended_expression_schema (b ) for b in bound_expressions
@@ -363,17 +371,17 @@ def resolve(
363371 return resolve
364372
365373
366- def if_then (ifs : Iterable [tuple [UnboundExtendedExpression , UnboundExtendedExpression ]], _else : UnboundExtendedExpression , alias : Union [Iterable [str ], str ] = None ):
374+ def if_then (ifs : Iterable [tuple [ExtendedExpressionOrUnbound , ExtendedExpressionOrUnbound ]], _else : ExtendedExpressionOrUnbound , alias : Union [Iterable [str ], str ] = None ):
367375 """Builds a resolver for ExtendedExpression containing an IfThen expression"""
368376 def resolve (
369377 base_schema : stp .NamedStruct , registry : ExtensionRegistry
370378 ) -> stee .ExtendedExpression :
371379 bound_ifs = [
372- (if_clause [0 ]( base_schema , registry ), if_clause [1 ]( base_schema , registry ))
380+ (resolve_expression ( if_clause [0 ], base_schema , registry ), resolve_expression ( if_clause [1 ], base_schema , registry ))
373381 for if_clause in ifs
374382 ]
375383
376- bound_else = _else ( base_schema , registry )
384+ bound_else = resolve_expression ( _else , base_schema , registry )
377385
378386 extension_uris = merge_extension_uris (
379387 * [b [0 ].extension_uris for b in bound_ifs ],
@@ -413,3 +421,169 @@ def resolve(
413421 )
414422
415423 return resolve
424+
425+ def switch (match : ExtendedExpressionOrUnbound ,
426+ ifs : Iterable [tuple [ExtendedExpressionOrUnbound , ExtendedExpressionOrUnbound ]],
427+ _else : ExtendedExpressionOrUnbound ):
428+ """Builds a resolver for ExtendedExpression containing a switch expression"""
429+ def resolve (
430+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
431+ ) -> stee .ExtendedExpression :
432+ bound_match = resolve_expression (match , base_schema , registry )
433+ bound_ifs = [
434+ (
435+ resolve_expression (a , base_schema , registry ),
436+ resolve_expression (b , base_schema , registry )
437+ ) for a , b in ifs ]
438+ bound_else = resolve_expression (_else , base_schema , registry )
439+
440+ extension_uris = merge_extension_uris (
441+ bound_match .extension_uris ,
442+ * [b .extension_uris for _ , b in bound_ifs ],
443+ bound_else .extension_uris
444+ )
445+
446+ extensions = merge_extension_declarations (
447+ bound_match .extensions ,
448+ * [b .extensions for _ , b in bound_ifs ],
449+ bound_else .extensions
450+ )
451+
452+ return stee .ExtendedExpression (
453+ referred_expr = [
454+ stee .ExpressionReference (
455+ expression = stalg .Expression (
456+ switch_expression = stalg .Expression .SwitchExpression (
457+ match = bound_match .referred_expr [0 ].expression ,
458+ ifs = [
459+ stalg .Expression .SwitchExpression .IfValue (** {
460+ 'if' : i .referred_expr [0 ].expression .literal ,
461+ 'then' : t .referred_expr [0 ].expression
462+ })
463+ for i , t in bound_ifs
464+ ],
465+ ** {
466+ 'else' : bound_else .referred_expr [0 ].expression
467+ }
468+ )
469+ ),
470+ output_names = ['switch' ] #TODO construct name from inputs
471+ )
472+ ],
473+ base_schema = base_schema ,
474+ extension_uris = extension_uris ,
475+ extensions = extensions ,
476+ )
477+
478+ return resolve
479+
480+ def singular_or_list (value : ExtendedExpressionOrUnbound , options : Iterable [ExtendedExpressionOrUnbound ]):
481+ """Builds a resolver for ExtendedExpression containing a SingularOrList expression"""
482+ def resolve (
483+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
484+ ) -> stee .ExtendedExpression :
485+ bound_value = resolve_expression (value , base_schema , registry )
486+ bound_options = [resolve_expression (o , base_schema , registry ) for o in options ]
487+
488+ extension_uris = merge_extension_uris (
489+ bound_value .extension_uris ,
490+ * [b .extension_uris for b in bound_options ]
491+ )
492+
493+ extensions = merge_extension_declarations (
494+ bound_value .extensions ,
495+ * [b .extensions for b in bound_options ]
496+ )
497+
498+ return stee .ExtendedExpression (
499+ referred_expr = [
500+ stee .ExpressionReference (
501+ expression = stalg .Expression (
502+ singular_or_list = stalg .Expression .SingularOrList (
503+ value = bound_value .referred_expr [0 ].expression ,
504+ options = [
505+ o .referred_expr [0 ].expression
506+ for o in bound_options
507+ ]
508+ )
509+ ),
510+ output_names = ['singular_or_list' ] #TODO construct name from inputs
511+ )
512+ ],
513+ base_schema = base_schema ,
514+ extension_uris = extension_uris ,
515+ extensions = extensions ,
516+ )
517+
518+ return resolve
519+
520+ def multi_or_list (value : Iterable [ExtendedExpressionOrUnbound ], options : Iterable [Iterable [ExtendedExpressionOrUnbound ]]):
521+ """Builds a resolver for ExtendedExpression containing a MultiOrList expression"""
522+ def resolve (
523+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
524+ ) -> stee .ExtendedExpression :
525+ bound_value = [resolve_expression (e , base_schema , registry ) for e in value ]
526+ bound_options = [
527+ [resolve_expression (e , base_schema , registry ) for e in o ] for o in options
528+ ]
529+
530+ extension_uris = merge_extension_uris (
531+ * [b .extension_uris for b in bound_value ],
532+ * [e .extension_uris for b in bound_options for e in b ],
533+ )
534+
535+ extensions = merge_extension_uris (
536+ * [b .extensions for b in bound_value ],
537+ * [e .extensions for b in bound_options for e in b ],
538+ )
539+
540+ return stee .ExtendedExpression (
541+ referred_expr = [
542+ stee .ExpressionReference (
543+ expression = stalg .Expression (
544+ multi_or_list = stalg .Expression .MultiOrList (
545+ value = [e .referred_expr [0 ].expression for e in bound_value ],
546+ options = [
547+ stalg .Expression .MultiOrList .Record (
548+ fields = [e .referred_expr [0 ].expression for e in option ]
549+ )
550+ for option in bound_options
551+ ]
552+ )
553+ ),
554+ output_names = ['multi_or_list' ] #TODO construct name from inputs
555+ )
556+ ],
557+ base_schema = base_schema ,
558+ extension_uris = extension_uris ,
559+ extensions = extensions ,
560+ )
561+
562+ return resolve
563+
564+ def cast (input : ExtendedExpressionOrUnbound , type : stp .Type ):
565+ """Builds a resolver for ExtendedExpression containing a cast expression"""
566+ def resolve (
567+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
568+ ) -> stee .ExtendedExpression :
569+ bound_input = resolve_expression (input , base_schema , registry )
570+
571+ return stee .ExtendedExpression (
572+ referred_expr = [
573+ stee .ExpressionReference (
574+ expression = stalg .Expression (
575+ cast = stalg .Expression .Cast (
576+ input = bound_input .referred_expr [0 ].expression ,
577+ type = type ,
578+ failure_behavior = stalg .Expression .Cast .FAILURE_BEHAVIOR_RETURN_NULL
579+ )
580+ ),
581+ output_names = ['cast' ] #TODO construct name from inputs
582+ )
583+ ],
584+ base_schema = base_schema ,
585+ extension_uris = bound_input .extension_uris ,
586+ extensions = bound_input .extensions ,
587+ )
588+
589+ return resolve
0 commit comments