@@ -204,12 +204,12 @@ def resolve(
204204
205205
206206def scalar_function (
207- uri : str ,
208- function : str ,
207+ function : Union [Iterable [str ], str ],
209208 expressions : Iterable [ExtendedExpressionOrUnbound ],
210209 alias : Union [Iterable [str ], str ] = None ,
211210):
212211 """Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
212+ functions = [function ] if isinstance (function , str ) else function
213213
214214 def resolve (
215215 base_schema : stp .NamedStruct , registry : ExtensionRegistry
@@ -224,23 +224,30 @@ def resolve(
224224
225225 signature = [typ for es in expression_schemas for typ in es .types ]
226226
227- func = registry .lookup_function (uri , function , signature )
227+ for f in functions :
228+ uri , name = f .split (":" )
229+ func = registry .lookup_function (uri , name , signature )
230+ if func :
231+ break
228232
229233 if not func :
230234 raise Exception (f"Unknown function { function } for { signature } " )
231235
236+ resolved_func , return_type = func
237+
232238 func_extension_uris = [
233239 ste .SimpleExtensionURI (
234- extension_uri_anchor = registry .lookup_uri (uri ), uri = uri
240+ extension_uri_anchor = registry .lookup_uri (resolved_func .uri ),
241+ uri = resolved_func .uri ,
235242 )
236243 ]
237244
238245 func_extensions = [
239246 ste .SimpleExtensionDeclaration (
240247 extension_function = ste .SimpleExtensionDeclaration .ExtensionFunction (
241- extension_uri_reference = registry .lookup_uri (uri ),
242- function_anchor = func [ 0 ] .anchor ,
243- name = str (func [ 0 ] ),
248+ extension_uri_reference = registry .lookup_uri (resolved_func . uri ),
249+ function_anchor = resolved_func .anchor ,
250+ name = str (resolved_func ),
244251 )
245252 )
246253 ]
@@ -258,14 +265,14 @@ def resolve(
258265 stee .ExpressionReference (
259266 expression = stalg .Expression (
260267 scalar_function = stalg .Expression .ScalarFunction (
261- function_reference = func [ 0 ] .anchor ,
268+ function_reference = resolved_func .anchor ,
262269 arguments = [
263270 stalg .FunctionArgument (
264271 value = e .referred_expr [0 ].expression
265272 )
266273 for e in bound_expressions
267274 ],
268- output_type = func [ 1 ] ,
275+ output_type = return_type ,
269276 )
270277 ),
271278 output_names = _alias_or_inferred (
@@ -284,12 +291,12 @@ def resolve(
284291
285292
286293def aggregate_function (
287- uri : str ,
288- function : str ,
294+ function : Union [Iterable [str ], str ],
289295 expressions : Iterable [ExtendedExpressionOrUnbound ],
290296 alias : Union [Iterable [str ], str ] = None ,
291297):
292298 """Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
299+ functions = [function ] if isinstance (function , str ) else function
293300
294301 def resolve (
295302 base_schema : stp .NamedStruct , registry : ExtensionRegistry
@@ -304,23 +311,30 @@ def resolve(
304311
305312 signature = [typ for es in expression_schemas for typ in es .types ]
306313
307- func = registry .lookup_function (uri , function , signature )
314+ for f in functions :
315+ uri , name = f .split (":" )
316+ func = registry .lookup_function (uri , name , signature )
317+ if func :
318+ break
308319
309320 if not func :
310321 raise Exception (f"Unknown function { function } for { signature } " )
311322
323+ resolved_func , return_type = func
324+
312325 func_extension_uris = [
313326 ste .SimpleExtensionURI (
314- extension_uri_anchor = registry .lookup_uri (uri ), uri = uri
327+ extension_uri_anchor = registry .lookup_uri (resolved_func .uri ),
328+ uri = resolved_func .uri ,
315329 )
316330 ]
317331
318332 func_extensions = [
319333 ste .SimpleExtensionDeclaration (
320334 extension_function = ste .SimpleExtensionDeclaration .ExtensionFunction (
321- extension_uri_reference = registry .lookup_uri (uri ),
322- function_anchor = func [ 0 ] .anchor ,
323- name = str (func [ 0 ] ),
335+ extension_uri_reference = registry .lookup_uri (resolved_func . uri ),
336+ function_anchor = resolved_func .anchor ,
337+ name = str (resolved_func ),
324338 )
325339 )
326340 ]
@@ -342,7 +356,7 @@ def resolve(
342356 stalg .FunctionArgument (value = e .referred_expr [0 ].expression )
343357 for e in bound_expressions
344358 ],
345- output_type = func [ 1 ] ,
359+ output_type = return_type ,
346360 ),
347361 output_names = _alias_or_inferred (
348362 alias ,
@@ -361,13 +375,13 @@ def resolve(
361375
362376# TODO bounds, sorts
363377def window_function (
364- uri : str ,
365- function : str ,
378+ function : Union [Iterable [str ], str ],
366379 expressions : Iterable [ExtendedExpressionOrUnbound ],
367380 partitions : Iterable [ExtendedExpressionOrUnbound ] = [],
368381 alias : Union [Iterable [str ], str ] = None ,
369382):
370383 """Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
384+ functions = [function ] if isinstance (function , str ) else function
371385
372386 def resolve (
373387 base_schema : stp .NamedStruct , registry : ExtensionRegistry
@@ -386,23 +400,30 @@ def resolve(
386400
387401 signature = [typ for es in expression_schemas for typ in es .types ]
388402
389- func = registry .lookup_function (uri , function , signature )
403+ for f in functions :
404+ uri , name = f .split (":" )
405+ func = registry .lookup_function (uri , name , signature )
406+ if func :
407+ break
390408
391409 if not func :
392410 raise Exception (f"Unknown function { function } for { signature } " )
393411
412+ resolved_func , return_type = func
413+
394414 func_extension_uris = [
395415 ste .SimpleExtensionURI (
396- extension_uri_anchor = registry .lookup_uri (uri ), uri = uri
416+ extension_uri_anchor = registry .lookup_uri (resolved_func .uri ),
417+ uri = resolved_func .uri ,
397418 )
398419 ]
399420
400421 func_extensions = [
401422 ste .SimpleExtensionDeclaration (
402423 extension_function = ste .SimpleExtensionDeclaration .ExtensionFunction (
403- extension_uri_reference = registry .lookup_uri (uri ),
404- function_anchor = func [ 0 ] .anchor ,
405- name = str (func [ 0 ] ),
424+ extension_uri_reference = registry .lookup_uri (resolved_func . uri ),
425+ function_anchor = resolved_func .anchor ,
426+ name = str (resolved_func ),
406427 )
407428 )
408429 ]
@@ -431,7 +452,7 @@ def resolve(
431452 )
432453 for e in bound_expressions
433454 ],
434- output_type = func [ 1 ] ,
455+ output_type = return_type ,
435456 partitions = [
436457 e .referred_expr [0 ].expression for e in bound_partitions
437458 ],
0 commit comments