1+ import dataclasses
12import datetime
23import functools
3- import string
4+ import inspect
45from typing import Any
56from typing import Callable
67from typing import Dict
78from typing import List
89from typing import Optional
10+ from typing import Tuple
911from typing import Union
1012
1113from . import dtypes
1214from .dtypes import DataType
15+ from .signature import simplify_dtype
16+
17+ try :
18+ import pydantic
19+ has_pydantic = True
20+ except ImportError :
21+ has_pydantic = False
1322
1423python_type_map : Dict [Any , Callable [..., str ]] = {
1524 str : dtypes .TEXT ,
@@ -33,88 +42,123 @@ def listify(x: Any) -> List[Any]:
3342 return [x ]
3443
3544
45+ def process_annotation (annotation : Any ) -> Tuple [Any , bool ]:
46+ types = simplify_dtype (annotation )
47+ if isinstance (types , list ):
48+ nullable = False
49+ if type (None ) in types :
50+ nullable = True
51+ types = [x for x in types if x is not type (None )]
52+ if len (types ) > 1 :
53+ raise ValueError (f'multiple types not supported: { annotation } ' )
54+ return types [0 ], nullable
55+ return types , True
56+
57+
58+ def process_types (params : Any ) -> Any :
59+ if params is None :
60+ return params , []
61+
62+ elif isinstance (params , (list , tuple )):
63+ params = list (params )
64+ for i , item in enumerate (params ):
65+ if params [i ] in python_type_map :
66+ params [i ] = python_type_map [params [i ]]()
67+ elif callable (item ):
68+ params [i ] = item ()
69+ for item in params :
70+ if not isinstance (item , str ):
71+ raise TypeError (f'unrecognized type for parameter: { item } ' )
72+ return params , []
73+
74+ elif isinstance (params , dict ):
75+ names = []
76+ params = dict (params )
77+ for k , v in list (params .items ()):
78+ names .append (k )
79+ if params [k ] in python_type_map :
80+ params [k ] = python_type_map [params [k ]]()
81+ elif callable (v ):
82+ params [k ] = v ()
83+ for item in params .values ():
84+ if not isinstance (item , str ):
85+ raise TypeError (f'unrecognized type for parameter: { item } ' )
86+ return params , names
87+
88+ elif dataclasses .is_dataclass (params ):
89+ names = []
90+ out = []
91+ for item in dataclasses .fields (params ):
92+ typ , nullable = process_annotation (item .type )
93+ sql_type = process_types (typ )[0 ]
94+ if not nullable :
95+ sql_type = sql_type .replace ('NULL' , 'NOT NULL' )
96+ out .append (sql_type )
97+ names .append (item .name )
98+ return out , names
99+
100+ elif has_pydantic and inspect .isclass (params ) \
101+ and issubclass (params , pydantic .BaseModel ):
102+ names = []
103+ out = []
104+ for name , item in params .model_fields .items ():
105+ typ , nullable = process_annotation (item .annotation )
106+ sql_type = process_types (typ )[0 ]
107+ if not nullable :
108+ sql_type = sql_type .replace ('NULL' , 'NOT NULL' )
109+ out .append (sql_type )
110+ names .append (name )
111+ return out , names
112+
113+ elif params in python_type_map :
114+ return python_type_map [params ](), []
115+
116+ elif callable (params ):
117+ return params (), []
118+
119+ elif isinstance (params , str ):
120+ return params , []
121+
122+ raise TypeError (f'unrecognized data type for args: { params } ' )
123+
124+
36125def _func (
37126 func : Optional [Callable [..., Any ]] = None ,
38127 * ,
39128 name : Optional [str ] = None ,
40- args : Optional [Union [DataType , List [DataType ], Dict [str , DataType ]]] = None ,
41- returns : Optional [Union [str , List [DataType ], List [type ]]] = None ,
129+ args : Optional [
130+ Union [
131+ DataType ,
132+ List [DataType ],
133+ Dict [str , DataType ],
134+ 'pydantic.BaseModel' ,
135+ type ,
136+ ]
137+ ] = None ,
138+ returns : Optional [
139+ Union [
140+ str ,
141+ List [DataType ],
142+ List [type ],
143+ 'pydantic.BaseModel' ,
144+ type ,
145+ ]
146+ ] = None ,
42147 data_format : Optional [str ] = None ,
43148 include_masks : bool = False ,
44149 function_type : str = 'udf' ,
45150 output_fields : Optional [List [str ]] = None ,
46151) -> Callable [..., Any ]:
47152 """Generic wrapper for UDF and TVF decorators."""
48- if args is None :
49- pass
50- elif isinstance (args , (list , tuple )):
51- args = list (args )
52- for i , item in enumerate (args ):
53- if args [i ] in python_type_map :
54- args [i ] = python_type_map [args [i ]]()
55- elif callable (item ):
56- args [i ] = item ()
57- for item in args :
58- if not isinstance (item , str ):
59- raise TypeError (f'unrecognized type for parameter: { item } ' )
60- elif isinstance (args , dict ):
61- args = dict (args )
62- for k , v in list (args .items ()):
63- if args [k ] in python_type_map :
64- args [k ] = python_type_map [args [k ]]()
65- elif callable (v ):
66- args [k ] = v ()
67- for item in args .values ():
68- if not isinstance (item , str ):
69- raise TypeError (f'unrecognized type for parameter: { item } ' )
70- elif args in python_type_map :
71- args = python_type_map [args ]()
72- elif callable (args ):
73- args = args ()
74- elif isinstance (args , str ):
75- args = args
76- else :
77- raise TypeError (f'unrecognized data type for args: { args } ' )
78-
79- if returns is None :
80- pass
81- elif isinstance (returns , (list , tuple )):
82- returns = list (returns )
83- for i , item in enumerate (returns ):
84- if item in python_type_map :
85- returns [i ] = python_type_map [item ]()
86- elif callable (item ):
87- returns [i ] = item ()
88- for item in returns :
89- if not isinstance (item , str ):
90- raise TypeError (f'unrecognized return type: { item } ' )
91- elif returns in python_type_map :
92- returns = python_type_map [returns ]()
93- elif callable (returns ):
94- returns = returns ()
95- elif isinstance (returns , str ):
96- returns = returns
97- else :
98- raise TypeError (f'unrecognized return type: { returns } ' )
99-
100- if returns is None :
101- pass
102- elif isinstance (returns , list ):
103- for item in returns :
104- if not isinstance (item , str ):
105- raise TypeError (f'unrecognized return type: { item } ' )
106- elif not isinstance (returns , str ):
107- raise TypeError (f'unrecognized return type: { returns } ' )
108-
109- if not output_fields :
110- if isinstance (returns , list ):
111- output_fields = []
112- for i , _ in enumerate (returns ):
113- output_fields .append (string .ascii_letters [i ])
114- else :
115- output_fields = [string .ascii_letters [0 ]]
116-
117- if isinstance (returns , list ) and len (output_fields ) != len (returns ):
153+ args , _ = process_types (args )
154+ returns , fields = process_types (returns )
155+
156+ if not output_fields and fields :
157+ output_fields = fields
158+
159+ if isinstance (returns , list ) \
160+ and isinstance (output_fields , list ) \
161+ and len (output_fields ) != len (returns ):
118162 raise ValueError (
119163 'The number of output fields must match the number of return types' ,
120164 )
@@ -133,7 +177,7 @@ def _func(
133177 data_format = data_format ,
134178 include_masks = include_masks ,
135179 function_type = function_type ,
136- output_fields = output_fields ,
180+ output_fields = output_fields or None ,
137181 ).items () if v is not None
138182 }
139183
0 commit comments