Skip to content

Commit 3d0e7b8

Browse files
kesmit13claude
andauthored
Add float16 (half-precision) vector support (#117)
* Add float16 (half-precision) vector support This commit adds comprehensive support for float16 (F16) vectors to both MySQL and HTTP clients in the SingleStoreDB Python SDK. Changes: - Add FLOAT16 = 7 as the 7th vector type constant - Add FIELD_TYPE constants: FLOAT16_VECTOR_JSON (2007) and FLOAT16_VECTOR (3007) - Update protocol parser to recognize and handle FLOAT16 vector metadata - Add float16_vector_json_or_none() and float16_vector_or_none() converters - Register float16 converters in the converters dictionary (types 2007 and 3007) - Add FLOAT16 vector types to TEXT_TYPES set for proper type handling - Update C accelerator with full float16 support: - Add FLOAT16 constants and PyStrings struct member - Update type arrays with 'e' format (2 bytes) for struct.unpack - Add float16 to JSON and binary vector case statements - Initialize numpy dtype kwargs for float16 - Add comprehensive tests: - Create f16_vectors test table with 3 test rows - Implement test_f16_vectors() method following existing patterns - Use assert_array_almost_equal with decimal=2 for float16 precision Technical notes: - Float16 has ~3 decimal digits of precision (vs ~7 for float32) - Uses struct format 'e' for half-precision (2 bytes per element) - Supports both JSON and binary wire formats - All pre-commit hooks passed (flake8, autopep8, mypy) Co-Authored-By: Claude Sonnet 4.5 <[email protected]> * Add float16 (half-precision) support to UDF system Extends the UDF framework to handle float16/half-precision types: - Add F16 vector type constant and export - Map float16 to FLOAT SQL type in signatures - Add Float16Array/HalfArray type aliases - Implement F16 numpy and struct format conversions - Add comprehensive float16 UDF tests This complements protocol-level float16 vector support (5377083). Co-Authored-By: Claude Sonnet 4.5 <[email protected]> * Add server version checks and conditional loading for float16 tests This commit implements a generalized version-based SQL loading system for tests and adds server version checks for float16 vector tests. Changes: - Add server version check to test_f16_vectors() that skips the test if server version < 9.1 with a descriptive message - Create test_9_1.sql containing float16 vector test data - Implement generalized version-specific SQL file loading in utils.py: - get_server_version(): Extract server version as (major, minor) tuple - find_version_specific_sql_files(): Discover test_X_Y.sql files - load_version_specific_sql(): Conditionally load SQL based on version - Update load_sql() to automatically load version-specific SQL files - Remove f16_vectors table from test.sql (now in test_9_1.sql) The new system automatically discovers and loads SQL files matching the pattern test_X_Y.sql where X is major version and Y is minor version. Files are loaded only if the server version >= the file version. This makes it easy to add version-specific test data in the future (e.g., test_9_2.sql, test_10_0.sql) without modifying Python code. Co-Authored-By: Claude Sonnet 4.5 <[email protected]> --------- Co-authored-by: Claude Sonnet 4.5 <[email protected]>
1 parent 37f0e64 commit 3d0e7b8

File tree

15 files changed

+300
-6
lines changed

15 files changed

+300
-6
lines changed

accel.c

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,14 @@
9191
#define MYSQL_TYPE_INT16_VECTOR_JSON 2004
9292
#define MYSQL_TYPE_INT32_VECTOR_JSON 2005
9393
#define MYSQL_TYPE_INT64_VECTOR_JSON 2006
94+
#define MYSQL_TYPE_FLOAT16_VECTOR_JSON 2007
9495
#define MYSQL_TYPE_FLOAT32_VECTOR 3001
9596
#define MYSQL_TYPE_FLOAT64_VECTOR 3002
9697
#define MYSQL_TYPE_INT8_VECTOR 3003
9798
#define MYSQL_TYPE_INT16_VECTOR 3004
9899
#define MYSQL_TYPE_INT32_VECTOR 3005
99100
#define MYSQL_TYPE_INT64_VECTOR 3006
101+
#define MYSQL_TYPE_FLOAT16_VECTOR 3007
100102

101103
#define MYSQL_TYPE_CHAR MYSQL_TYPE_TINY
102104
#define MYSQL_TYPE_INTERVAL MYSQL_TYPE_ENUM
@@ -503,6 +505,7 @@ typedef struct {
503505
PyObject *int64;
504506
PyObject *float32;
505507
PyObject *float64;
508+
PyObject *float16;
506509
PyObject *unpack;
507510
PyObject *decode;
508511
PyObject *frombuffer;
@@ -541,7 +544,7 @@ typedef struct {
541544
PyObject *namedtuple_kwargs;
542545
PyObject *create_numpy_array_args;
543546
PyObject *create_numpy_array_kwargs;
544-
PyObject *create_numpy_array_kwargs_vector[7];
547+
PyObject *create_numpy_array_kwargs_vector[8];
545548
PyObject *struct_unpack_args;
546549
PyObject *bson_decode_args;
547550
} PyObjects;
@@ -1565,8 +1568,8 @@ static PyObject *read_row_from_packet(
15651568
PyObject *py_str = NULL;
15661569
PyObject *py_memview = NULL;
15671570
char end = '\0';
1568-
char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q"};
1569-
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8};
1571+
char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q", "e"};
1572+
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8, 2};
15701573

15711574
int sign = 1;
15721575
int year = 0;
@@ -1826,6 +1829,7 @@ static PyObject *read_row_from_packet(
18261829
case MYSQL_TYPE_INT16_VECTOR_JSON:
18271830
case MYSQL_TYPE_INT32_VECTOR_JSON:
18281831
case MYSQL_TYPE_INT64_VECTOR_JSON:
1832+
case MYSQL_TYPE_FLOAT16_VECTOR_JSON:
18291833
if (!py_state->encodings[i]) {
18301834
py_item = PyBytes_FromStringAndSize(out, out_l);
18311835
if (!py_item) goto error;
@@ -1847,7 +1851,7 @@ static PyObject *read_row_from_packet(
18471851
// Parse JSON string.
18481852
if ((py_state->type_codes[i] == MYSQL_TYPE_JSON && py_state->options.parse_json)
18491853
|| (py_state->type_codes[i] >= MYSQL_TYPE_FLOAT32_VECTOR_JSON
1850-
&& py_state->type_codes[i] <= MYSQL_TYPE_INT64_VECTOR_JSON)) {
1854+
&& py_state->type_codes[i] <= MYSQL_TYPE_FLOAT16_VECTOR_JSON)) {
18511855
py_str = py_item;
18521856
py_item = PyObject_CallFunctionObjArgs(PyFunc.json_loads, py_str, NULL);
18531857
Py_CLEAR(py_str);
@@ -1862,6 +1866,7 @@ static PyObject *read_row_from_packet(
18621866
case MYSQL_TYPE_INT16_VECTOR_JSON:
18631867
case MYSQL_TYPE_INT32_VECTOR_JSON:
18641868
case MYSQL_TYPE_INT64_VECTOR_JSON:
1869+
case MYSQL_TYPE_FLOAT16_VECTOR_JSON:
18651870
CHECKRC(PyTuple_SetItem(PyObj.create_numpy_array_args, 0, py_item));
18661871
py_item = PyObject_Call(
18671872
PyFunc.numpy_array,
@@ -1880,6 +1885,7 @@ static PyObject *read_row_from_packet(
18801885
case MYSQL_TYPE_INT16_VECTOR:
18811886
case MYSQL_TYPE_INT32_VECTOR:
18821887
case MYSQL_TYPE_INT64_VECTOR:
1888+
case MYSQL_TYPE_FLOAT16_VECTOR:
18831889
{
18841890
int type_idx = py_state->type_codes[i] % 1000;
18851891

@@ -4844,6 +4850,7 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
48444850
PyStr.int64 = PyUnicode_FromString("int64");
48454851
PyStr.float32 = PyUnicode_FromString("float32");
48464852
PyStr.float64 = PyUnicode_FromString("float64");
4853+
PyStr.float16 = PyUnicode_FromString("float16");
48474854
PyStr.unpack = PyUnicode_FromString("unpack");
48484855
PyStr.decode = PyUnicode_FromString("decode");
48494856
PyStr.frombuffer = PyUnicode_FromString("frombuffer");
@@ -4921,6 +4928,11 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
49214928
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[6], "dtype", PyStr.int64)) {
49224929
goto error;
49234930
}
4931+
PyObj.create_numpy_array_kwargs_vector[7] = PyDict_New();
4932+
if (!PyObj.create_numpy_array_kwargs_vector[7]) goto error;
4933+
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[7], "dtype", PyStr.float16)) {
4934+
goto error;
4935+
}
49244936

49254937
PyObj.struct_unpack_args = PyTuple_New(2);
49264938
if (!PyObj.struct_unpack_args) goto error;

singlestoredb/converters.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,62 @@ def float32_vector_or_none(x: Optional[bytes]) -> Optional[Any]:
597597
return struct.unpack(f'<{len(x)//4}f', x)
598598

599599

600+
def float16_vector_json_or_none(x: Optional[str]) -> Optional[Any]:
601+
"""
602+
Convert value to float16 array.
603+
604+
Parameters
605+
----------
606+
x : str or None
607+
JSON array
608+
609+
Returns
610+
-------
611+
float16 numpy array
612+
If input value is not None and numpy is installed
613+
float Python list
614+
If input value is not None and numpy is not installed
615+
None
616+
If input value is None
617+
618+
"""
619+
if x is None:
620+
return None
621+
622+
if has_numpy:
623+
return numpy.array(json_loads(x), dtype=numpy.float16)
624+
625+
return map(float, json_loads(x))
626+
627+
628+
def float16_vector_or_none(x: Optional[bytes]) -> Optional[Any]:
629+
"""
630+
Convert value to float16 array.
631+
632+
Parameters
633+
----------
634+
x : bytes or None
635+
Little-endian block of bytes.
636+
637+
Returns
638+
-------
639+
float16 numpy array
640+
If input value is not None and numpy is installed
641+
float Python list
642+
If input value is not None and numpy is not installed
643+
None
644+
If input value is None
645+
646+
"""
647+
if x is None:
648+
return None
649+
650+
if has_numpy:
651+
return numpy.frombuffer(x, dtype=numpy.float16)
652+
653+
return struct.unpack(f'<{len(x)//2}e', x)
654+
655+
600656
def float64_vector_json_or_none(x: Optional[str]) -> Optional[Any]:
601657
"""
602658
Covert value to float64 array.
@@ -941,10 +997,12 @@ def bson_or_none(x: Optional[bytes]) -> Optional[Any]:
941997
2004: int16_vector_json_or_none,
942998
2005: int32_vector_json_or_none,
943999
2006: int64_vector_json_or_none,
1000+
2007: float16_vector_json_or_none,
9441001
3001: float32_vector_or_none,
9451002
3002: float64_vector_or_none,
9461003
3003: int8_vector_or_none,
9471004
3004: int16_vector_or_none,
9481005
3005: int32_vector_or_none,
9491006
3006: int64_vector_or_none,
1007+
3007: float16_vector_or_none,
9501008
}

singlestoredb/functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .utils import VectorTypes
99

1010

11+
F16 = VectorTypes.F16
1112
F32 = VectorTypes.F32
1213
F64 = VectorTypes.F64
1314
I8 = VectorTypes.I8

singlestoredb/functions/signature.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class NoDefaultType:
135135
'uint16': 'SMALLINT UNSIGNED',
136136
'uint32': 'INT UNSIGNED',
137137
'uint64': 'BIGINT UNSIGNED',
138+
'float16': 'FLOAT',
138139
'float32': 'FLOAT',
139140
'float64': 'DOUBLE',
140141
'str': 'TEXT',

singlestoredb/functions/typing/numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
StringArray = StrArray = npt.NDArray[np.str_]
77
BytesArray = npt.NDArray[np.bytes_]
8+
Float16Array = HalfArray = npt.NDArray[np.float16]
89
Float32Array = FloatArray = npt.NDArray[np.float32]
910
Float64Array = DoubleArray = npt.NDArray[np.float64]
1011
BoolArray = npt.NDArray[np.bool_]

singlestoredb/functions/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def _vector_type_to_numpy_type(
198198
vector_type: VectorTypes,
199199
) -> str:
200200
"""Convert a vector type to a numpy type."""
201-
if vector_type == VectorTypes.F32:
201+
if vector_type == VectorTypes.F16:
202+
return 'f2'
203+
elif vector_type == VectorTypes.F32:
202204
return 'f4'
203205
elif vector_type == VectorTypes.F64:
204206
return 'f8'
@@ -219,7 +221,11 @@ def _vector_type_to_struct_format(
219221
) -> str:
220222
"""Convert a vector type to a struct format string."""
221223
n = len(vec)
222-
if vector_type == VectorTypes.F32:
224+
if vector_type == VectorTypes.F16:
225+
if isinstance(vec, (bytes, bytearray)):
226+
n = n // 2
227+
return f'<{n}e'
228+
elif vector_type == VectorTypes.F32:
223229
if isinstance(vec, (bytes, bytearray)):
224230
n = n // 4
225231
return f'<{n}f'

singlestoredb/mysql/connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@
110110
FIELD_TYPE.INT16_VECTOR_JSON,
111111
FIELD_TYPE.INT32_VECTOR_JSON,
112112
FIELD_TYPE.INT64_VECTOR_JSON,
113+
FIELD_TYPE.FLOAT16_VECTOR_JSON,
113114
FIELD_TYPE.FLOAT32_VECTOR,
114115
FIELD_TYPE.FLOAT64_VECTOR,
115116
FIELD_TYPE.INT8_VECTOR,
116117
FIELD_TYPE.INT16_VECTOR,
117118
FIELD_TYPE.INT32_VECTOR,
118119
FIELD_TYPE.INT64_VECTOR,
120+
FIELD_TYPE.FLOAT16_VECTOR,
119121
}
120122

121123
UNSET = 'unset'

singlestoredb/mysql/constants/FIELD_TYPE.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
INT16_VECTOR_JSON = 2004
4141
INT32_VECTOR_JSON = 2005
4242
INT64_VECTOR_JSON = 2006
43+
FLOAT16_VECTOR_JSON = 2007
4344
FLOAT32_VECTOR = 3001
4445
FLOAT64_VECTOR = 3002
4546
INT8_VECTOR = 3003
4647
INT16_VECTOR = 3004
4748
INT32_VECTOR = 3005
4849
INT64_VECTOR = 3006
50+
FLOAT16_VECTOR = 3007

singlestoredb/mysql/constants/VECTOR_TYPE.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
INT16 = 4
55
INT32 = 5
66
INT64 = 6
7+
FLOAT16 = 7

singlestoredb/mysql/protocol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ def _parse_field_descriptor(self, encoding):
318318
self.type_code = FIELD_TYPE.INT64_VECTOR
319319
else:
320320
self.type_code = FIELD_TYPE.INT64_VECTOR_JSON
321+
elif vec_type == VECTOR_TYPE.FLOAT16:
322+
if self.charsetnr == 63:
323+
self.type_code = FIELD_TYPE.FLOAT16_VECTOR
324+
else:
325+
self.type_code = FIELD_TYPE.FLOAT16_VECTOR_JSON
321326
else:
322327
raise TypeError(f'unrecognized vector data type: {vec_type}')
323328
else:

0 commit comments

Comments
 (0)