Skip to content

Commit c7d1f99

Browse files
committed
Add object support for UDF parameters and returns
1 parent 8a109ad commit c7d1f99

File tree

10 files changed

+703
-303
lines changed

10 files changed

+703
-303
lines changed

accel.c

Lines changed: 57 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4108,6 +4108,7 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
41084108
PyObject *py_colspec = NULL;
41094109
PyObject *py_str = NULL;
41104110
PyObject *py_blob = NULL;
4111+
PyObject **py_transformers = NULL;
41114112
Py_ssize_t length = 0;
41124113
uint64_t row_id = 0;
41134114
uint8_t is_null = 0;
@@ -4138,13 +4139,23 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
41384139

41394140
colspec_l = PyObject_Length(py_colspec);
41404141
ctypes = malloc(sizeof(int) * colspec_l);
4142+
py_transformers = calloc(sizeof(PyObject*), colspec_l);
41414143

41424144
for (i = 0; i < colspec_l; i++) {
41434145
PyObject *py_cspec = PySequence_GetItem(py_colspec, i);
41444146
if (!py_cspec) goto error;
41454147
PyObject *py_ctype = PySequence_GetItem(py_cspec, 1);
41464148
if (!py_ctype) { Py_DECREF(py_cspec); goto error; }
41474149
ctypes[i] = (int)PyLong_AsLong(py_ctype);
4150+
py_transformers[i] = PySequence_GetItem(py_cspec, 2);
4151+
if (!py_transformers[i]) {
4152+
Py_DECREF(py_ctype);
4153+
Py_DECREF(py_cspec);
4154+
goto error;
4155+
}
4156+
if (py_transformers[i] == Py_None) {
4157+
py_transformers[i] = NULL;
4158+
}
41484159
Py_DECREF(py_ctype);
41494160
Py_DECREF(py_cspec);
41504161
if (PyErr_Occurred()) { goto error; }
@@ -4380,6 +4391,14 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
43804391
default:
43814392
goto error;
43824393
}
4394+
4395+
// Apply the optional per-column transformer to the value that was just
// decoded into py_row[i], replacing the raw value with the result.
if (py_transformers[i]) {
    // PyTuple_GetItem returns a BORROWED reference: py_row still owns the
    // item, so it must not be DECREF'd here.
    PyObject *py_item = PyTuple_GetItem(py_row, i);
    if (!py_item) goto error;

    // Pass the item as exactly one positional argument.
    // PyObject_CallFunction(f, "O", x) would unpack x if it were a tuple.
    PyObject *py_transformed = PyObject_CallFunctionObjArgs(py_transformers[i], py_item, NULL);
    if (!py_transformed) goto error;

    // PyTuple_SetItem steals the reference to py_transformed and releases
    // the reference the tuple held on the previous item, so no explicit
    // DECREF of the old item is needed (doing so would double-release it).
    CHECKRC(PyTuple_SetItem(py_row, i, py_transformed));
}
43834402
}
43844403

43854404
CHECKRC(PyList_Append(py_out_rows, py_row));
@@ -4389,6 +4408,12 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
43894408

43904409
exit:
43914410
if (ctypes) free(ctypes);
4411+
if (py_transformers) {
4412+
for (i = 0; i < colspec_l; i++) {
4413+
Py_XDECREF(py_transformers[i]);
4414+
}
4415+
free(py_transformers);
4416+
}
43924417

43934418
Py_XDECREF(py_row);
43944419

@@ -4412,6 +4437,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44124437
PyObject *py_row_ids = NULL;
44134438
PyObject *py_row_ids_iter = NULL;
44144439
PyObject *py_item = NULL;
4440+
PyObject **py_transformers = NULL;
44154441
uint64_t row_id = 0;
44164442
uint8_t is_null = 0;
44174443
int8_t i8 = 0;
@@ -4459,12 +4485,26 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44594485

44604486
returns = malloc(sizeof(int) * n_cols);
44614487
if (!returns) goto error;
4488+
py_transformers = calloc(sizeof(PyObject*), n_cols);
4489+
if (!py_transformers) goto error;
44624490

44634491
for (i = 0; i < n_cols; i++) {
4464-
PyObject *py_item = PySequence_GetItem(py_returns, i);
4465-
if (!py_item) goto error;
4466-
returns[i] = (int)PyLong_AsLong(py_item);
4467-
Py_DECREF(py_item);
4492+
PyObject *py_cspec = PySequence_GetItem(py_returns, i);
4493+
if (!py_cspec) goto error;
4494+
PyObject *py_ctype = PySequence_GetItem(py_cspec, 1);
4495+
if (!py_ctype) { Py_DECREF(py_cspec); goto error; }
4496+
returns[i] = (int)PyLong_AsLong(py_ctype);
4497+
py_transformers[i] = PySequence_GetItem(py_cspec, 2);
4498+
if (!py_transformers[i]) {
4499+
Py_DECREF(py_ctype);
4500+
Py_DECREF(py_cspec);
4501+
goto error;
4502+
}
4503+
if (py_transformers[i] == Py_None) {
4504+
py_transformers[i] = NULL;
4505+
}
4506+
Py_DECREF(py_ctype);
4507+
Py_DECREF(py_cspec);
44684508
if (PyErr_Occurred()) { goto error; }
44694509
}
44704510

@@ -4504,6 +4544,13 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
45044544
memcpy(out+out_idx, &is_null, 1);
45054545
out_idx += 1;
45064546

4547+
// Apply the optional per-column transformer to the outgoing value before
// it is serialized by the type switch below.
if (py_transformers[i]) {
    // Pass py_item as exactly one positional argument.
    // PyObject_CallFunction(f, "O", x) treats a tuple (or tuple subclass,
    // e.g. a NamedTuple instance) as the full argument list and would call
    // f(*x) instead of f(x).
    PyObject *py_transformed = PyObject_CallFunctionObjArgs(py_transformers[i], py_item, NULL);
    if (!py_transformed) goto error;

    // NOTE(review): assumes py_item is an owned reference at this point
    // (the exit path XDECREFs it) -- confirm against the code that assigns it.
    Py_DECREF(py_item);
    py_item = py_transformed;
}
4553+
45074554
switch (returns[i]) {
45084555
case MYSQL_TYPE_BIT:
45094556
// TODO
@@ -4702,6 +4749,12 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
47024749

47034750
exit:
47044751
if (returns) free(returns);
4752+
if (py_transformers) {
4753+
for (i = 0; i < n_cols; i++) {
4754+
Py_XDECREF(py_transformers[i]);
4755+
}
4756+
free(py_transformers);
4757+
}
47054758

47064759
Py_XDECREF(py_item);
47074760
Py_XDECREF(py_row_iter);

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ exclude =
8686
docs/*
8787
resources/*
8888
licenses/*
89-
max-complexity = 45
89+
max-complexity = 50
9090
max-line-length = 90
9191
per-file-ignores =
9292
singlestoredb/__init__.py:F401

singlestoredb/functions/decorator.py

Lines changed: 18 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
UDFType = Callable[..., Any]
2424

2525

26-
def is_valid_type(obj: Any) -> bool:
26+
def is_valid_object_type(obj: Any) -> bool:
2727
"""Check if the object is a valid type for a schema definition."""
2828
if not inspect.isclass(obj):
2929
return False
@@ -52,48 +52,29 @@ def is_valid_callable(obj: Any) -> bool:
5252

5353
returns = utils.get_annotations(obj).get('return', None)
5454

55-
if inspect.isclass(returns) and issubclass(returns, str):
55+
if inspect.isclass(returns) and issubclass(returns, SQLString):
5656
return True
5757

58-
raise TypeError(
59-
f'callable {obj} must return a str, '
60-
f'but got {returns}',
61-
)
58+
return False
6259

6360

64-
def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
61+
def expand_types(args: Any) -> List[Any]:
6562
"""Expand the types for the function arguments / return values."""
6663
if args is None:
67-
return None
68-
69-
# SQL string
70-
if isinstance(args, str):
71-
return [args]
72-
73-
# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
74-
elif is_valid_type(args):
75-
return args
76-
77-
# List of SQL strings or callables
78-
elif isinstance(args, list):
79-
new_args = []
80-
for arg in args:
81-
if isinstance(arg, str):
82-
new_args.append(arg)
83-
elif callable(arg):
84-
new_args.append(arg())
85-
else:
86-
raise TypeError(f'unrecognized type for parameter: {arg}')
87-
return new_args
88-
89-
# Callable that returns a SQL string
90-
elif is_valid_callable(args):
91-
out = args()
92-
if not isinstance(out, str):
93-
raise TypeError(f'unrecognized type for parameter: {args}')
94-
return [out]
95-
96-
raise TypeError(f'unrecognized type for parameter: {args}')
64+
return []
65+
66+
if not isinstance(args, list):
67+
args = [args]
68+
69+
new_args = []
70+
for arg in args:
71+
if isinstance(arg, str):
72+
new_args.append(arg)
73+
elif is_valid_callable(arg):
74+
new_args.append(arg())
75+
else:
76+
new_args.append(arg)
77+
return new_args
9778

9879

9980
def _func(

singlestoredb/functions/ext/asgi.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,7 @@ async def to_thread(
126126
'float64': ft.DOUBLE,
127127
'str': ft.STRING,
128128
'bytes': -ft.STRING,
129+
'json': ft.STRING,
129130
}
130131

131132

@@ -586,7 +587,11 @@ def make_func(
586587
dtype = x['dtype'].replace('?', '')
587588
if dtype not in rowdat_1_type_map:
588589
raise TypeError(f'no data type mapping for {dtype}')
589-
colspec.append((x['name'], rowdat_1_type_map[dtype]))
590+
colspec.append((
591+
x['name'],
592+
rowdat_1_type_map[dtype],
593+
x.get('transformer', None),
594+
))
590595
info['colspec'] = colspec
591596

592597
# Setup return type
@@ -595,7 +600,11 @@ def make_func(
595600
dtype = x['dtype'].replace('?', '')
596601
if dtype not in rowdat_1_type_map:
597602
raise TypeError(f'no data type mapping for {dtype}')
598-
returns.append((x['name'], rowdat_1_type_map[dtype]))
603+
returns.append((
604+
x['name'],
605+
rowdat_1_type_map[dtype],
606+
x.get('transformer', None),
607+
))
599608
info['returns'] = returns
600609

601610
return do_func, info
@@ -1084,7 +1093,7 @@ async def __call__(
10841093

10851094
with timer('format_output'):
10861095
body = output_handler['dump'](
1087-
[x[1] for x in func_info['returns']], *result, # type: ignore
1096+
func_info['returns'], *result, # type: ignore
10881097
)
10891098

10901099
await send(output_handler['response'])

0 commit comments

Comments
 (0)