1+ import datetime
12import functools
3+ import string
24from typing import Any
35from typing import Callable
46from typing import Dict
57from typing import List
68from typing import Optional
79from typing import Union
810
11+ from . import dtypes
912from .dtypes import DataType
1013
14+ python_type_map : Dict [Any , Callable [..., str ]] = {  # Python type -> dtypes factory; _func calls the factory with no arguments when a bare Python type is given in args/returns
15+ str : dtypes .TEXT ,
16+ int : dtypes .BIGINT ,
17+ float : dtypes .DOUBLE ,
18+ bool : dtypes .BOOL ,  # lookup is by dict key, so bool resolves here and not through int
19+ bytes : dtypes .BINARY ,
20+ bytearray : dtypes .BINARY ,  # bytes and bytearray share the same factory
21+ datetime .datetime : dtypes .DATETIME ,
22+ datetime .date : dtypes .DATE ,
23+ datetime .timedelta : dtypes .TIME ,  # durations are mapped to the TIME factory
24+ }
25+
1126
1227def listify (x : Any ) -> List [Any ]:
1328 """Make sure value is a list."""
@@ -23,30 +38,37 @@ def _func(
2338 * ,
2439 name : Optional [str ] = None ,
2540 args : Optional [Union [DataType , List [DataType ], Dict [str , DataType ]]] = None ,
26- returns : Optional [str ] = None ,
41+ returns : Optional [Union [ str , List [ DataType ], List [ type ]] ] = None ,
2742 data_format : Optional [str ] = None ,
2843 include_masks : bool = False ,
2944 function_type : str = 'udf' ,
45+ output_fields : Optional [List [str ]] = None ,
3046) -> Callable [..., Any ]:
3147 """Generic wrapper for UDF and TVF decorators."""
3248 if args is None :
3349 pass
3450 elif isinstance (args , (list , tuple )):
3551 args = list (args )
3652 for i , item in enumerate (args ):
37- if callable (item ):
53+ if args [i ] in python_type_map :
54+ args [i ] = python_type_map [args [i ]]()
55+ elif callable (item ):
3856 args [i ] = item ()
3957 for item in args :
4058 if not isinstance (item , str ):
4159 raise TypeError (f'unrecognized type for parameter: { item } ' )
4260 elif isinstance (args , dict ):
4361 args = dict (args )
4462 for k , v in list (args .items ()):
45- if callable (v ):
63+ if args [k ] in python_type_map :
64+ args [k ] = python_type_map [args [k ]]()
65+ elif callable (v ):
4666 args [k ] = v ()
4767 for item in args .values ():
4868 if not isinstance (item , str ):
4969 raise TypeError (f'unrecognized type for parameter: { item } ' )
70+ elif args in python_type_map :
71+ args = python_type_map [args ]()
5072 elif callable (args ):
5173 args = args ()
5274 elif isinstance (args , str ):
@@ -56,16 +78,47 @@ def _func(
5678
5779 if returns is None :
5880 pass
81+ elif isinstance (returns , (list , tuple )):
82+ returns = list (returns )
83+ for i , item in enumerate (returns ):
84+ if item in python_type_map :
85+ returns [i ] = python_type_map [item ]()
86+ elif callable (item ):
87+ returns [i ] = item ()
88+ for item in returns :
89+ if not isinstance (item , str ):
90+ raise TypeError (f'unrecognized return type: { item } ' )
91+ elif returns in python_type_map :
92+ returns = python_type_map [returns ]()
5993 elif callable (returns ):
6094 returns = returns ()
6195 elif isinstance (returns , str ):
6296 returns = returns
6397 else :
6498 raise TypeError (f'unrecognized return type: { returns } ' )
6599
66- if returns is not None and not isinstance (returns , str ):
100+ if returns is None :
101+ pass
102+ elif isinstance (returns , list ):
103+ for item in returns :
104+ if not isinstance (item , str ):
105+ raise TypeError (f'unrecognized return type: { item } ' )
106+ elif not isinstance (returns , str ):
67107 raise TypeError (f'unrecognized return type: { returns } ' )
68108
109+ if not output_fields :
110+ if isinstance (returns , list ):
111+ output_fields = []
112+ for i , _ in enumerate (returns ):
113+ output_fields .append (string .ascii_letters [i ])
114+ else :
115+ output_fields = [string .ascii_letters [0 ]]
116+
117+ if isinstance (returns , list ) and len (output_fields ) != len (returns ):
118+ raise ValueError (
119+ 'The number of output fields must match the number of return types' ,
120+ )
121+
69122 if include_masks and data_format == 'python' :
70123 raise RuntimeError (
71124 'include_masks is only valid when using '
@@ -80,6 +133,7 @@ def _func(
80133 data_format = data_format ,
81134 include_masks = include_masks ,
82135 function_type = function_type ,
136+ output_fields = output_fields ,
83137 ).items () if v is not None
84138 }
85139
@@ -107,7 +161,7 @@ def udf(
107161 * ,
108162 name : Optional [str ] = None ,
109163 args : Optional [Union [DataType , List [DataType ], Dict [str , DataType ]]] = None ,
110- returns : Optional [str ] = None ,
164+ returns : Optional [Union [ str , List [ DataType ], List [ type ]] ] = None ,
111165 data_format : Optional [str ] = None ,
112166 include_masks : bool = False ,
113167) -> Callable [..., Any ]:
@@ -170,9 +224,10 @@ def tvf(
170224 * ,
171225 name : Optional [str ] = None ,
172226 args : Optional [Union [DataType , List [DataType ], Dict [str , DataType ]]] = None ,
173- returns : Optional [str ] = None ,
227+ returns : Optional [Union [ str , List [ DataType ], List [ type ]] ] = None ,
174228 data_format : Optional [str ] = None ,
175229 include_masks : bool = False ,
230+ output_fields : Optional [List [str ]] = None ,
176231) -> Callable [..., Any ]:
177232 """
178233 Apply attributes to a TVF.
@@ -205,6 +260,9 @@ def tvf(
205260 Should boolean masks be included with each input parameter to indicate
206261 which elements are NULL? This is only used when input parameters are
207262 configured to a vector type (numpy, pandas, polars, arrow).
263+ output_fields : List[str], optional
264+ The names of the output fields for the TVF. If not specified, the
265+ names are generated.
208266
209267 Returns
210268 -------
@@ -219,6 +277,7 @@ def tvf(
219277 data_format = data_format ,
220278 include_masks = include_masks ,
221279 function_type = 'tvf' ,
280+ output_fields = output_fields ,
222281 )
223282
224283
0 commit comments