Commit b112ba0

Fix tests

1 parent 56ab367 commit b112ba0

3 files changed, +54 -57 lines changed

singlestoredb/functions/signature.py

Lines changed: 5 additions & 5 deletions
@@ -1139,19 +1139,19 @@ def get_schema(
         for i, x in enumerate(typing.get_args(spec)):
             params, out_data_format, _ = get_schema(
                 unpack_masked_type(x),
-                overrides=overrides if overrides else None,
+                overrides=[overrides[i]] if overrides else None,
                 # Always pass UDF mode for individual items
                 mode=mode,
             )

             # Use the name from the overrides if specified
             if overrides:
-                if overrides[i] and not params[i].name:
-                    params[i].name = overrides[i].name
+                if overrides[i] and not params[0].name:
+                    params[0].name = overrides[i].name
                 elif not overrides[i].name:
-                    params[i].name = f'{string.ascii_letters[i]}'
+                    params[0].name = f'{string.ascii_letters[i]}'

-            colspec.append(params[i])
+            colspec.append(params[0])
             out_data_formats.append(out_data_format)

             # Make sure that all the data formats are the same
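
Note on this fix: each recursive get_schema call in the loop is made on a single unpacked item, so it returns a one-element parameter list. The old code passed the whole overrides list down and indexed the result with the loop counter i, which only works for the first item. A minimal sketch of the corrected indexing pattern, self-contained and using a stand-in for the recursive call rather than the real get_schema:

# Simplified illustration of the indexing change (not the real get_schema).
overrides = ['a', 'b', 'c']
colspec = []
for i, item_type in enumerate([int, float, str]):
    # Stand-in for the per-item recursive call: it returns a list describing
    # only this one item, so the list always has exactly one entry.
    params = [{'name': overrides[i], 'type': item_type}]
    # Index 0, not i: params[i] would raise IndexError for every item after the first.
    colspec.append(params[0])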

singlestoredb/functions/transformers.py

Lines changed: 10 additions & 10 deletions
@@ -20,8 +20,8 @@ def json_to_dict(cls: Type[T], json_value: str) -> Dict[str, Any]:
 
     Parameters
     ----------
-    json_value : str or dict
-        The JSON string or dictionary representing the object.
+    json_value : str
+        The JSON string representing the object.
 
     Returns
     -------
@@ -40,8 +40,8 @@ def json_to_pydantic(cls: Type[T], json_value: str) -> T:
     ----------
     cls : Type[T]
         The Pydantic model type to instantiate.
-    json_value : str or dict
-        The JSON string or dictionary representing the object.
+    json_value : str
+        The JSON string representing the object.
 
     Returns
     -------
@@ -60,8 +60,8 @@ def json_to_namedtuple(cls: Type[T], json_value: str) -> T:
     ----------
     cls : Type[T]
         The namedtuple type to instantiate.
-    json_value : str or dict
-        The JSON string or dictionary representing the object.
+    json_value : str
+        The JSON string representing the object.
 
     Returns
     -------
@@ -92,8 +92,8 @@ def json_to_typeddict(cls: Type[T], json_value: str) -> Dict[str, Any]:
     ----------
     cls : Type[T]
         The TypedDict type to instantiate.
-    json_value : str or dict
-        The JSON string or dictionary representing the object.
+    json_value : str
+        The JSON string representing the object.
 
     Returns
     -------
@@ -159,8 +159,8 @@ def json_to_pandas_dataframe(cls: Type[T], json_value: str) -> T:
     ----------
     cls : Type[T]
         The DataFrame type to instantiate.
-    json_value : str or dict
-        The JSON string or dictionary representing the object.
+    json_value : str
+        The JSON string representing the object.
 
     Returns
     -------
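
Note: these are documentation-only changes. The docstrings now agree with the existing json_value: str annotations, so callers pass a JSON string, not an already-parsed dict. A hedged usage sketch (Point is a hypothetical class used only for illustration; the exact return value depends on the converter's implementation, which is not shown in this diff):

import dataclasses

from singlestoredb.functions.transformers import json_to_dict  # module path inferred from the file path above

@dataclasses.dataclass
class Point:          # hypothetical target type, illustration only
    x: int
    y: int

# A JSON *string* is expected, not a dict:
row = json_to_dict(Point, '{"x": 1, "y": 2}')
# Presumably row == {'x': 1, 'y': 2}; passing a dict is no longer documented.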

singlestoredb/tests/test_ext_func_data.py

Lines changed: 39 additions & 42 deletions
@@ -31,25 +31,22 @@
 BINARY = -254
 
 col_spec = [
-    ('tiny', TINYINT),
-    ('unsigned_tiny', UNSIGNED_TINYINT),
-    ('short', SMALLINT),
-    ('unsigned_short', UNSIGNED_SMALLINT),
-    ('long', INT),
-    ('unsigned_long', UNSIGNED_INT),
-    ('float', FLOAT),
-    ('double', DOUBLE),
-    ('longlong', BIGINT),
-    ('unsigned_longlong', UNSIGNED_BIGINT),
-    ('int24', MEDIUMINT),
-    ('unsigned_int24', UNSIGNED_MEDIUMINT),
-    ('string', STRING),
-    ('binary', BINARY),
+    ('tiny', TINYINT, None),
+    ('unsigned_tiny', UNSIGNED_TINYINT, None),
+    ('short', SMALLINT, None),
+    ('unsigned_short', UNSIGNED_SMALLINT, None),
+    ('long', INT, None),
+    ('unsigned_long', UNSIGNED_INT, None),
+    ('float', FLOAT, None),
+    ('double', DOUBLE, None),
+    ('longlong', BIGINT, None),
+    ('unsigned_longlong', UNSIGNED_BIGINT, None),
+    ('int24', MEDIUMINT, None),
+    ('unsigned_int24', UNSIGNED_MEDIUMINT, None),
+    ('string', STRING, None),
+    ('binary', BINARY, None),
 ]
 
-col_types = [x[1] for x in col_spec]
-col_names = [x[0] for x in col_spec]
-
 numpy_row_ids = np.array([1, 2, 3, 4])
 numpy_nulls = np.array([False, False, False, True])
 
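
Note: each column in the fixture is now a three-element tuple, and the derived col_types / col_names lists were removed because the dump and load functions under test take the full spec directly. The meaning of the third element is not spelled out in this diff; it is None for every column here. For reference, the shape of one entry:

# (column name, type code, extra); TINYINT is defined at the top of this test module.
name, type_code, extra = ('tiny', TINYINT, None)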

@@ -268,7 +265,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_numpy_accel(self):
         dump_res = rowdat_1._dump_numpy_accel(
-            col_types, numpy_row_ids, numpy_data,
+            col_spec, numpy_row_ids, numpy_data,
         ).tobytes()
         load_res = rowdat_1._load_numpy_accel(col_spec, dump_res)
 
@@ -293,7 +290,7 @@ def test_numpy_accel(self):
 
     def test_numpy(self):
         dump_res = rowdat_1._dump_numpy(
-            col_types, numpy_row_ids, numpy_data,
+            col_spec, numpy_row_ids, numpy_data,
         ).tobytes()
         load_res = rowdat_1._load_numpy(col_spec, dump_res)
 
@@ -386,7 +383,7 @@ def test_numpy_accel_limits(self, name, dtype, data, res):
             # Accelerated
             with self.assertRaises(res, msg=f'Expected {res} for {data} in {dtype}'):
                 rowdat_1._dump_numpy_accel(
-                    [dtype], numpy_row_ids, [(arr, None)],
+                    [('x', dtype, None)], numpy_row_ids, [(arr, None)],
                 ).tobytes()
 
             # Pure Python
@@ -395,23 +392,23 @@ def test_numpy_accel_limits(self, name, dtype, data, res):
             else:
                 with self.assertRaises(res, msg=f'Expected {res} for {data} in {dtype}'):
                     rowdat_1._dump_numpy(
-                        [dtype], numpy_row_ids, [(arr, None)],
+                        [('x', dtype, None)], numpy_row_ids, [(arr, None)],
                     ).tobytes()
 
         else:
             # Accelerated
             dump_res = rowdat_1._dump_numpy_accel(
-                [dtype], numpy_row_ids, [(arr, None)],
+                [('x', dtype, None)], numpy_row_ids, [(arr, None)],
             ).tobytes()
-            load_res = rowdat_1._load_numpy_accel([('x', dtype)], dump_res)
+            load_res = rowdat_1._load_numpy_accel([('x', dtype, None)], dump_res)
             assert load_res[1][0][0] == res, \
                 f'Expected {res} for {data}, but got {load_res[1][0][0]} in {dtype}'
 
             # Pure Python
             dump_res = rowdat_1._dump_numpy(
-                [dtype], numpy_row_ids, [(arr, None)],
+                [('x', dtype, None)], numpy_row_ids, [(arr, None)],
             ).tobytes()
-            load_res = rowdat_1._load_numpy([('x', dtype)], dump_res)
+            load_res = rowdat_1._load_numpy([('x', dtype, None)], dump_res)
             assert load_res[1][0][0] == res, \
                 f'Expected {res} for {data}, but got {load_res[1][0][0]} in {dtype}'
 
@@ -787,9 +784,9 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
 
         # Accelerated
         dump_res = rowdat_1._dump_numpy_accel(
-            [dtype], numpy_row_ids, [(data, None)],
+            [('x', dtype, None)], numpy_row_ids, [(data, None)],
         ).tobytes()
-        load_res = rowdat_1._load_numpy_accel([('x', dtype)], dump_res)
+        load_res = rowdat_1._load_numpy_accel([('x', dtype, None)], dump_res)
 
         if name == 'double from float32':
             assert load_res[1][0][0].dtype is res.dtype
@@ -799,9 +796,9 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
 
         # Pure Python
         dump_res = rowdat_1._dump_numpy(
-            [dtype], numpy_row_ids, [(data, None)],
+            [('x', dtype, None)], numpy_row_ids, [(data, None)],
        ).tobytes()
-        load_res = rowdat_1._load_numpy([('x', dtype)], dump_res)
+        load_res = rowdat_1._load_numpy([('x', dtype, None)], dump_res)
 
         if name == 'double from float32':
             assert load_res[1][0][0].dtype is res.dtype
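
Note: the parameterized limit and cast tests build a one-off single-column spec inline. Where they previously passed a bare [dtype] to the dump side (and ('x', dtype) pairs to the load side), both sides now receive the same (name, type code, extra) triple. For illustration only:

single_col_spec = [('x', dtype, None)]   # replaces the old bare [dtype]; the column name 'x' is arbitrary
# The same spec is passed to both the _dump_numpy*/_load_numpy* calls above.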
@@ -811,7 +808,7 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
 
     def test_python(self):
         dump_res = rowdat_1._dump(
-            col_types, py_row_ids, py_col_data,
+            col_spec, py_row_ids, py_col_data,
         ).tobytes()
         load_res = rowdat_1._load(col_spec, dump_res)
 
@@ -823,7 +820,7 @@ def test_python(self):
 
     def test_python_accel(self):
         dump_res = rowdat_1._dump_accel(
-            col_types, py_row_ids, py_col_data,
+            col_spec, py_row_ids, py_col_data,
         ).tobytes()
         load_res = rowdat_1._load_accel(col_spec, dump_res)
 
@@ -835,7 +832,7 @@ def test_python_accel(self):
 
     def test_polars(self):
         dump_res = rowdat_1._dump_polars(
-            col_types, polars_row_ids, polars_data,
+            col_spec, polars_row_ids, polars_data,
         ).tobytes()
         load_res = rowdat_1._load_polars(col_spec, dump_res)
 
@@ -860,7 +857,7 @@ def test_polars(self):
 
     def test_polars_accel(self):
         dump_res = rowdat_1._dump_polars_accel(
-            col_types, polars_row_ids, polars_data,
+            col_spec, polars_row_ids, polars_data,
         ).tobytes()
         load_res = rowdat_1._load_polars_accel(col_spec, dump_res)
 
@@ -885,7 +882,7 @@ def test_polars_accel(self):
 
     def test_pandas(self):
         dump_res = rowdat_1._dump_pandas(
-            col_types, pandas_row_ids, pandas_data,
+            col_spec, pandas_row_ids, pandas_data,
         ).tobytes()
         load_res = rowdat_1._load_pandas(col_spec, dump_res)
 
@@ -910,7 +907,7 @@ def test_pandas(self):
 
     def test_pandas_accel(self):
         dump_res = rowdat_1._dump_pandas_accel(
-            col_types, pandas_row_ids, pandas_data,
+            col_spec, pandas_row_ids, pandas_data,
         ).tobytes()
         load_res = rowdat_1._load_pandas_accel(col_spec, dump_res)
 
@@ -935,7 +932,7 @@ def test_pandas_accel(self):
 
     def test_pyarrow(self):
         dump_res = rowdat_1._dump_arrow(
-            col_types, pyarrow_row_ids, pyarrow_data,
+            col_spec, pyarrow_row_ids, pyarrow_data,
         ).tobytes()
         load_res = rowdat_1._load_arrow(col_spec, dump_res)
 
@@ -960,7 +957,7 @@ def test_pyarrow(self):
 
     def test_pyarrow_accel(self):
         dump_res = rowdat_1._dump_arrow_accel(
-            col_types, pyarrow_row_ids, pyarrow_data,
+            col_spec, pyarrow_row_ids, pyarrow_data,
         ).tobytes()
         load_res = rowdat_1._load_arrow_accel(col_spec, dump_res)
 
@@ -988,7 +985,7 @@ class TestJSON(unittest.TestCase):
 
     def test_numpy(self):
         dump_res = jsonx.dump_numpy(
-            col_types, numpy_row_ids, numpy_data,
+            col_spec, numpy_row_ids, numpy_data,
         )
         import pprint
         pprint.pprint(json.loads(dump_res))
@@ -1015,7 +1012,7 @@ def test_numpy(self):
 
     def test_python(self):
         dump_res = jsonx.dump(
-            col_types, py_row_ids, py_col_data,
+            col_spec, py_row_ids, py_col_data,
         )
         load_res = jsonx.load(col_spec, dump_res)
 
@@ -1027,7 +1024,7 @@ def test_python(self):
 
     def test_polars(self):
         dump_res = jsonx.dump_polars(
-            col_types, polars_row_ids, polars_data,
+            col_spec, polars_row_ids, polars_data,
         )
         load_res = jsonx.load_polars(col_spec, dump_res)
 
@@ -1052,7 +1049,7 @@ def test_polars(self):
 
     def test_pandas(self):
         dump_res = rowdat_1._dump_pandas(
-            col_types, pandas_row_ids, pandas_data,
+            col_spec, pandas_row_ids, pandas_data,
         ).tobytes()
         load_res = rowdat_1._load_pandas(col_spec, dump_res)
 
@@ -1077,7 +1074,7 @@ def test_pandas(self):
 
     def test_pyarrow(self):
         dump_res = rowdat_1._dump_arrow(
-            col_types, pyarrow_row_ids, pyarrow_data,
+            col_spec, pyarrow_row_ids, pyarrow_data,
         ).tobytes()
         load_res = rowdat_1._load_arrow(col_spec, dump_res)
 
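
Across all of these tests the round-trip shape is the same: the dump side now receives the full column spec rather than a bare list of type codes. A minimal sketch of that pattern for the NumPy path, assuming rowdat_1 and DOUBLE are in scope as they are in this test module (the import is not shown in this diff), and using placeholder data rather than the module's fixtures:

import numpy as np

# Illustrative single-column round-trip; values are placeholders.
spec = [('x', DOUBLE, None)]                       # new-style (name, type code, extra) spec
row_ids = np.array([1, 2, 3, 4])
data = [(np.array([1.0, 2.0, 3.0, 4.0]), None)]    # one (values, mask) pair per column

payload = rowdat_1._dump_numpy(spec, row_ids, data).tobytes()
loaded = rowdat_1._load_numpy(spec, payload)
# loaded[1][0][0] is the first value of the first column, mirroring the asserts above.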
