Skip to content

Commit 893d80d

Browse files
committed
Add support for multiple return value types and return value field names
1 parent 4dc4ead commit 893d80d

File tree

5 files changed

+96
-27
lines changed

5 files changed

+96
-27
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ exclude =
8181
docs/*
8282
resources/*
8383
licenses/*
84-
max-complexity = 35
84+
max-complexity = 45
8585
max-line-length = 90
8686
per-file-ignores =
8787
singlestoredb/__init__.py:F401

singlestoredb/functions/decorator.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
1+
import datetime
12
import functools
3+
import string
24
from typing import Any
35
from typing import Callable
46
from typing import Dict
57
from typing import List
68
from typing import Optional
79
from typing import Union
810

11+
from . import dtypes
912
from .dtypes import DataType
1013

14+
python_type_map: Dict[Any, Callable[..., str]] = {
15+
str: dtypes.TEXT,
16+
int: dtypes.BIGINT,
17+
float: dtypes.DOUBLE,
18+
bool: dtypes.BOOL,
19+
bytes: dtypes.BINARY,
20+
bytearray: dtypes.BINARY,
21+
datetime.datetime: dtypes.DATETIME,
22+
datetime.date: dtypes.DATE,
23+
datetime.timedelta: dtypes.TIME,
24+
}
25+
1126

1227
def listify(x: Any) -> List[Any]:
1328
"""Make sure sure value is a list."""
@@ -23,30 +38,37 @@ def _func(
2338
*,
2439
name: Optional[str] = None,
2540
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
26-
returns: Optional[str] = None,
41+
returns: Optional[Union[str, List[DataType], List[type]]] = None,
2742
data_format: Optional[str] = None,
2843
include_masks: bool = False,
2944
function_type: str = 'udf',
45+
output_fields: Optional[List[str]] = None,
3046
) -> Callable[..., Any]:
3147
"""Generic wrapper for UDF and TVF decorators."""
3248
if args is None:
3349
pass
3450
elif isinstance(args, (list, tuple)):
3551
args = list(args)
3652
for i, item in enumerate(args):
37-
if callable(item):
53+
if args[i] in python_type_map:
54+
args[i] = python_type_map[args[i]]()
55+
elif callable(item):
3856
args[i] = item()
3957
for item in args:
4058
if not isinstance(item, str):
4159
raise TypeError(f'unrecognized type for parameter: {item}')
4260
elif isinstance(args, dict):
4361
args = dict(args)
4462
for k, v in list(args.items()):
45-
if callable(v):
63+
if args[k] in python_type_map:
64+
args[k] = python_type_map[args[k]]()
65+
elif callable(v):
4666
args[k] = v()
4767
for item in args.values():
4868
if not isinstance(item, str):
4969
raise TypeError(f'unrecognized type for parameter: {item}')
70+
elif args in python_type_map:
71+
args = python_type_map[args]()
5072
elif callable(args):
5173
args = args()
5274
elif isinstance(args, str):
@@ -56,16 +78,47 @@ def _func(
5678

5779
if returns is None:
5880
pass
81+
elif isinstance(returns, (list, tuple)):
82+
returns = list(returns)
83+
for i, item in enumerate(returns):
84+
if item in python_type_map:
85+
returns[i] = python_type_map[item]()
86+
elif callable(item):
87+
returns[i] = item()
88+
for item in returns:
89+
if not isinstance(item, str):
90+
raise TypeError(f'unrecognized return type: {item}')
91+
elif returns in python_type_map:
92+
returns = python_type_map[returns]()
5993
elif callable(returns):
6094
returns = returns()
6195
elif isinstance(returns, str):
6296
returns = returns
6397
else:
6498
raise TypeError(f'unrecognized return type: {returns}')
6599

66-
if returns is not None and not isinstance(returns, str):
100+
if returns is None:
101+
pass
102+
elif isinstance(returns, list):
103+
for item in returns:
104+
if not isinstance(item, str):
105+
raise TypeError(f'unrecognized return type: {item}')
106+
elif not isinstance(returns, str):
67107
raise TypeError(f'unrecognized return type: {returns}')
68108

109+
if not output_fields:
110+
if isinstance(returns, list):
111+
output_fields = []
112+
for i, _ in enumerate(returns):
113+
output_fields.append(string.ascii_letters[i])
114+
else:
115+
output_fields = [string.ascii_letters[0]]
116+
117+
if isinstance(returns, list) and len(output_fields) != len(returns):
118+
raise ValueError(
119+
'The number of output fields must match the number of return types',
120+
)
121+
69122
if include_masks and data_format == 'python':
70123
raise RuntimeError(
71124
'include_masks is only valid when using '
@@ -80,6 +133,7 @@ def _func(
80133
data_format=data_format,
81134
include_masks=include_masks,
82135
function_type=function_type,
136+
output_fields=output_fields,
83137
).items() if v is not None
84138
}
85139

@@ -107,7 +161,7 @@ def udf(
107161
*,
108162
name: Optional[str] = None,
109163
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
110-
returns: Optional[str] = None,
164+
returns: Optional[Union[str, List[DataType], List[type]]] = None,
111165
data_format: Optional[str] = None,
112166
include_masks: bool = False,
113167
) -> Callable[..., Any]:
@@ -170,9 +224,10 @@ def tvf(
170224
*,
171225
name: Optional[str] = None,
172226
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
173-
returns: Optional[str] = None,
227+
returns: Optional[Union[str, List[DataType], List[type]]] = None,
174228
data_format: Optional[str] = None,
175229
include_masks: bool = False,
230+
output_fields: Optional[List[str]] = None,
176231
) -> Callable[..., Any]:
177232
"""
178233
Apply attributes to a TVF.
@@ -205,6 +260,9 @@ def tvf(
205260
Should boolean masks be included with each input parameter to indicate
206261
which elements are NULL? This is only used when a input parameters are
207262
configured to a vector type (numpy, pandas, polars, arrow).
263+
output_fields : List[str], optional
264+
The names of the output fields for the TVF. If not specified, the
265+
names are generated.
208266
209267
Returns
210268
-------
@@ -219,6 +277,7 @@ def tvf(
219277
data_format=data_format,
220278
include_masks=include_masks,
221279
function_type='tvf',
280+
output_fields=output_fields,
222281
)
223282

224283

singlestoredb/functions/ext/asgi.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ async def do_func( # type: ignore
190190
assert isinstance(out, tuple)
191191
return row_ids, [out]
192192

193-
out_ids, out = [], []
193+
out = []
194194
res = func(*[x[0] for x in cols])
195195
rtype = str(type(res)).lower()
196196

@@ -209,20 +209,20 @@ async def do_func( # type: ignore
209209
# NOTE: There is no way to determine which row ID belongs to
210210
# each result row, so we just have to use the same
211211
# row ID for all rows in the result.
212-
if data_format == 'numpy':
213-
import numpy as np
214-
out_ids = np.array([row_ids[0]] * len(out[0][0]))
215-
elif data_format == 'polars':
212+
if data_format == 'polars':
216213
import polars as pl
217-
out_ids = pl.Series([row_ids[0]] * len(out[0][0]))
214+
array_cls = pl.Series
218215
elif data_format == 'arrow':
219216
import pyarrow as pa
220-
out_ids = pa.array([row_ids[0]] * len(out[0][0]))
217+
array_cls = pa.array
221218
elif data_format == 'pandas':
222219
import pandas as pd
223-
out_ids = pd.Series([row_ids[0]] * len(out[0][0]))
220+
array_cls = pd.Series
221+
else:
222+
import numpy as np
223+
array_cls = np.array
224224

225-
return out_ids, out
225+
return array_cls([row_ids[0]] * len(out[0][0])), out
226226

227227
else:
228228
if data_format == 'python':

singlestoredb/functions/signature.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,10 @@ def classify_dtype(dtype: Any) -> str:
312312
if is_int:
313313
return int_type_map.get(name, 'int64')
314314

315-
raise TypeError(f'unsupported type annotation: {dtype}')
315+
raise TypeError(
316+
f'unsupported type annotation: {dtype}; '
317+
'use `args`/`returns` on the @udf/@tvf decotator to specify the data type',
318+
)
316319

317320

318321
def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
@@ -449,6 +452,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
449452

450453
args_overrides = attrs.get('args', None)
451454
returns_overrides = attrs.get('returns', None)
455+
output_fields = attrs.get('output_fields', None)
452456

453457
spec_diff = set(arg_names).difference(set(annotations.keys()))
454458

@@ -499,6 +503,22 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[
499503
if isinstance(returns_overrides, str):
500504
sql = returns_overrides
501505
out_type = sql_to_dtype(sql)
506+
elif isinstance(returns_overrides, list):
507+
sqls = []
508+
out_types = []
509+
for i, item in enumerate(returns_overrides):
510+
if not isinstance(item, str):
511+
raise TypeError(f'unrecognized type for return value: {item}')
512+
if output_fields:
513+
sqls.append(f'`{output_fields[i]}` {item}')
514+
else:
515+
sqls.append(f'{string.ascii_letters[i]} {item}')
516+
out_types.append(sql_to_dtype(item))
517+
if function_type == 'tvf':
518+
sql = 'TABLE({})'.format(', '.join(sqls))
519+
else:
520+
sql = 'RECORD({})'.format(', '.join(sqls))
521+
out_type = 'tuple[{}]'.format(','.join(out_types))
502522
elif returns_overrides is not None and not isinstance(returns_overrides, str):
503523
raise TypeError(f'unrecognized type for return value: {returns_overrides}')
504524
else:

singlestoredb/tests/test_udf.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -391,16 +391,6 @@ def foo(x: int, y: float, z: str) -> int: ...
391391
'`y` DOUBLE NOT NULL, ' \
392392
'`z` CHAR(30) NULL) RETURNS SMALLINT NOT NULL'
393393

394-
# Override parameter with incorrect type
395-
with self.assertRaises(TypeError):
396-
@udf(args=dict(x=int))
397-
def foo(x: int, y: float, z: str) -> int: ...
398-
399-
# Override return value with incorrect type
400-
with self.assertRaises(TypeError):
401-
@udf(returns=int)
402-
def foo(x: int, y: float, z: str) -> int: ...
403-
404394
# Change function name
405395
@udf(name='hello_world')
406396
def foo(x: int) -> int: ...

0 commit comments

Comments
 (0)