1+ import yaml
2+ import itertools
13from substrait .gen .proto .type_pb2 import Type
24from importlib .resources import files as importlib_files
3- import itertools
45from collections import defaultdict
5- from collections .abc import Mapping
66from pathlib import Path
7- from typing import Any , Optional , Union
7+ from typing import Optional , Union
88from .derivation_expression import evaluate , _evaluate , _parse
9-
10- import yaml
119from substrait .gen .antlr .SubstraitTypeParser import SubstraitTypeParser
10+ from substrait .gen .json import simple_extensions as se
11+ from substrait .simple_extension_utils import build_simple_extensions
12+
1213
1314DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions"
1415
@@ -166,31 +167,35 @@ def covers(
166167
167168class FunctionEntry :
168169 def __init__ (
169- self , uri : str , name : str , impl : Mapping [ str , Any ], anchor : int
170+ self , uri : str , name : str , impl : Union [ se . Impl , se . Impl1 , se . Impl2 ], anchor : int
170171 ) -> None :
171172 self .name = name
173+ self .impl = impl
172174 self .normalized_inputs : list = []
173175 self .uri : str = uri
174176 self .anchor = anchor
175177 self .arguments = []
176- self .rtn = impl ["return" ]
177- self .nullability = impl .get ("nullability" , "MIRROR" )
178- self .variadic = impl .get ("variadic" , False )
179- if input_args := impl .get ("args" , []):
180- for val in input_args :
181- if typ := val .get ("value" ):
182- self .arguments .append (_parse (typ ))
183- self .normalized_inputs .append (normalize_substrait_type_names (typ ))
184- elif _ := val .get ("name" , None ):
185- self .arguments .append (val .get ("options" ))
178+ self .nullability = (
179+ impl .nullability if impl .nullability else se .NullabilityHandling .MIRROR
180+ )
181+
182+ if impl .args :
183+ for arg in impl .args :
184+ if isinstance (arg , se .ValueArg ):
185+ self .arguments .append (_parse (arg .value ))
186+ self .normalized_inputs .append (
187+ normalize_substrait_type_names (arg .value )
188+ )
189+ elif isinstance (arg , se .EnumerationArg ):
190+ self .arguments .append (arg .options )
186191 self .normalized_inputs .append ("req" )
187192
188193 def __repr__ (self ) -> str :
189194 return f"{ self .name } :{ '_' .join (self .normalized_inputs )} "
190195
191196 def satisfies_signature (self , signature : tuple ) -> Optional [str ]:
192- if self .variadic :
193- min_args_allowed = self .variadic .get ( " min" , 0 )
197+ if self .impl . variadic :
198+ min_args_allowed = self .impl . variadic .min or 0
194199 if len (signature ) < min_args_allowed :
195200 return None
196201 inputs = [self .arguments [0 ]] * len (signature )
@@ -209,13 +214,17 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
209214 return None
210215 else :
211216 if not covers (
212- y , x , parameters , check_nullability = self .nullability == "DISCRETE"
217+ y ,
218+ x ,
219+ parameters ,
220+ check_nullability = self .nullability
221+ == se .NullabilityHandling .DISCRETE ,
213222 ):
214223 return None
215224
216- output_type = evaluate (self .rtn , parameters )
225+ output_type = evaluate (self .impl . return_ , parameters )
217226
218- if self .nullability == " MIRROR" :
227+ if self .nullability == se . NullabilityHandling . MIRROR :
219228 sig_contains_nullable = any (
220229 [
221230 p .__getattribute__ (p .WhichOneof ("kind" )).nullability
@@ -265,19 +274,27 @@ def register_extension_yaml(
265274 def register_extension_dict (self , definitions : dict , uri : str ) -> None :
266275 self ._uri_mapping [uri ] = next (self ._uri_id_generator )
267276
268- for named_functions in definitions .values ():
269- for function in named_functions :
270- for impl in function .get ("impls" , []):
277+ simple_extensions = build_simple_extensions (definitions )
278+
279+ functions = (
280+ (simple_extensions .scalar_functions or [])
281+ + (simple_extensions .aggregate_functions or [])
282+ + (simple_extensions .window_functions or [])
283+ )
284+
285+ if functions :
286+ for function in functions :
287+ for impl in function .impls :
271288 func = FunctionEntry (
272- uri , function [ " name" ] , impl , next (self ._id_generator )
289+ uri , function . name , impl , next (self ._id_generator )
273290 )
274291 if (
275292 func .uri in self ._function_mapping
276- and function [ " name" ] in self ._function_mapping [func .uri ]
293+ and function . name in self ._function_mapping [func .uri ]
277294 ):
278- self ._function_mapping [func .uri ][function [ " name" ] ].append (func )
295+ self ._function_mapping [func .uri ][function . name ].append (func )
279296 else :
280- self ._function_mapping [func .uri ][function [ " name" ] ] = [func ]
297+ self ._function_mapping [func .uri ][function . name ] = [func ]
281298
282299 # TODO add an optional return type check
283300 def lookup_function (
0 commit comments