diff --git a/accel.c b/accel.c index a785b9754..ee9a72f3f 100644 --- a/accel.c +++ b/accel.c @@ -91,12 +91,14 @@ #define MYSQL_TYPE_INT16_VECTOR_JSON 2004 #define MYSQL_TYPE_INT32_VECTOR_JSON 2005 #define MYSQL_TYPE_INT64_VECTOR_JSON 2006 +#define MYSQL_TYPE_FLOAT16_VECTOR_JSON 2007 #define MYSQL_TYPE_FLOAT32_VECTOR 3001 #define MYSQL_TYPE_FLOAT64_VECTOR 3002 #define MYSQL_TYPE_INT8_VECTOR 3003 #define MYSQL_TYPE_INT16_VECTOR 3004 #define MYSQL_TYPE_INT32_VECTOR 3005 #define MYSQL_TYPE_INT64_VECTOR 3006 +#define MYSQL_TYPE_FLOAT16_VECTOR 3007 #define MYSQL_TYPE_CHAR MYSQL_TYPE_TINY #define MYSQL_TYPE_INTERVAL MYSQL_TYPE_ENUM @@ -503,6 +505,7 @@ typedef struct { PyObject *int64; PyObject *float32; PyObject *float64; + PyObject *float16; PyObject *unpack; PyObject *decode; PyObject *frombuffer; @@ -541,7 +544,7 @@ typedef struct { PyObject *namedtuple_kwargs; PyObject *create_numpy_array_args; PyObject *create_numpy_array_kwargs; - PyObject *create_numpy_array_kwargs_vector[7]; + PyObject *create_numpy_array_kwargs_vector[8]; PyObject *struct_unpack_args; PyObject *bson_decode_args; } PyObjects; @@ -1565,8 +1568,8 @@ static PyObject *read_row_from_packet( PyObject *py_str = NULL; PyObject *py_memview = NULL; char end = '\0'; - char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q"}; - int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8}; + char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q", "e"}; + int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8, 2}; int sign = 1; int year = 0; @@ -1826,6 +1829,7 @@ static PyObject *read_row_from_packet( case MYSQL_TYPE_INT16_VECTOR_JSON: case MYSQL_TYPE_INT32_VECTOR_JSON: case MYSQL_TYPE_INT64_VECTOR_JSON: + case MYSQL_TYPE_FLOAT16_VECTOR_JSON: if (!py_state->encodings[i]) { py_item = PyBytes_FromStringAndSize(out, out_l); if (!py_item) goto error; @@ -1847,7 +1851,7 @@ static PyObject *read_row_from_packet( // Parse JSON string. if ((py_state->type_codes[i] == MYSQL_TYPE_JSON && py_state->options.parse_json) || (py_state->type_codes[i] >= MYSQL_TYPE_FLOAT32_VECTOR_JSON - && py_state->type_codes[i] <= MYSQL_TYPE_INT64_VECTOR_JSON)) { + && py_state->type_codes[i] <= MYSQL_TYPE_FLOAT16_VECTOR_JSON)) { py_str = py_item; py_item = PyObject_CallFunctionObjArgs(PyFunc.json_loads, py_str, NULL); Py_CLEAR(py_str); @@ -1862,6 +1866,7 @@ static PyObject *read_row_from_packet( case MYSQL_TYPE_INT16_VECTOR_JSON: case MYSQL_TYPE_INT32_VECTOR_JSON: case MYSQL_TYPE_INT64_VECTOR_JSON: + case MYSQL_TYPE_FLOAT16_VECTOR_JSON: CHECKRC(PyTuple_SetItem(PyObj.create_numpy_array_args, 0, py_item)); py_item = PyObject_Call( PyFunc.numpy_array, @@ -1880,6 +1885,7 @@ static PyObject *read_row_from_packet( case MYSQL_TYPE_INT16_VECTOR: case MYSQL_TYPE_INT32_VECTOR: case MYSQL_TYPE_INT64_VECTOR: + case MYSQL_TYPE_FLOAT16_VECTOR: { int type_idx = py_state->type_codes[i] % 1000; @@ -4844,6 +4850,7 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) { PyStr.int64 = PyUnicode_FromString("int64"); PyStr.float32 = PyUnicode_FromString("float32"); PyStr.float64 = PyUnicode_FromString("float64"); + PyStr.float16 = PyUnicode_FromString("float16"); PyStr.unpack = PyUnicode_FromString("unpack"); PyStr.decode = PyUnicode_FromString("decode"); PyStr.frombuffer = PyUnicode_FromString("frombuffer"); @@ -4921,6 +4928,11 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) { if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[6], "dtype", PyStr.int64)) { goto error; } + PyObj.create_numpy_array_kwargs_vector[7] = PyDict_New(); + if (!PyObj.create_numpy_array_kwargs_vector[7]) goto error; + if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[7], "dtype", PyStr.float16)) { + goto error; + } PyObj.struct_unpack_args = PyTuple_New(2); if (!PyObj.struct_unpack_args) goto error; diff --git a/singlestoredb/converters.py b/singlestoredb/converters.py index bd4b4255b..ec9b73580 100644 --- a/singlestoredb/converters.py +++ b/singlestoredb/converters.py @@ -597,6 +597,62 @@ def float32_vector_or_none(x: Optional[bytes]) -> Optional[Any]: return struct.unpack(f'<{len(x)//4}f', x) +def float16_vector_json_or_none(x: Optional[str]) -> Optional[Any]: + """ + Convert value to float16 array. + + Parameters + ---------- + x : str or None + JSON array + + Returns + ------- + float16 numpy array + If input value is not None and numpy is installed + float Python list + If input value is not None and numpy is not installed + None + If input value is None + + """ + if x is None: + return None + + if has_numpy: + return numpy.array(json_loads(x), dtype=numpy.float16) + + return map(float, json_loads(x)) + + +def float16_vector_or_none(x: Optional[bytes]) -> Optional[Any]: + """ + Convert value to float16 array. + + Parameters + ---------- + x : bytes or None + Little-endian block of bytes. + + Returns + ------- + float16 numpy array + If input value is not None and numpy is installed + float Python list + If input value is not None and numpy is not installed + None + If input value is None + + """ + if x is None: + return None + + if has_numpy: + return numpy.frombuffer(x, dtype=numpy.float16) + + return struct.unpack(f'<{len(x)//2}e', x) + + def float64_vector_json_or_none(x: Optional[str]) -> Optional[Any]: """ Covert value to float64 array. @@ -941,10 +997,12 @@ def bson_or_none(x: Optional[bytes]) -> Optional[Any]: 2004: int16_vector_json_or_none, 2005: int32_vector_json_or_none, 2006: int64_vector_json_or_none, + 2007: float16_vector_json_or_none, 3001: float32_vector_or_none, 3002: float64_vector_or_none, 3003: int8_vector_or_none, 3004: int16_vector_or_none, 3005: int32_vector_or_none, 3006: int64_vector_or_none, + 3007: float16_vector_or_none, } diff --git a/singlestoredb/functions/__init__.py b/singlestoredb/functions/__init__.py index dea2968a1..936831752 100644 --- a/singlestoredb/functions/__init__.py +++ b/singlestoredb/functions/__init__.py @@ -8,6 +8,7 @@ from .utils import VectorTypes +F16 = VectorTypes.F16 F32 = VectorTypes.F32 F64 = VectorTypes.F64 I8 = VectorTypes.I8 diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 7f058c457..cf2b5d017 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -135,6 +135,7 @@ class NoDefaultType: 'uint16': 'SMALLINT UNSIGNED', 'uint32': 'INT UNSIGNED', 'uint64': 'BIGINT UNSIGNED', + 'float16': 'FLOAT', 'float32': 'FLOAT', 'float64': 'DOUBLE', 'str': 'TEXT', diff --git a/singlestoredb/functions/typing/numpy.py b/singlestoredb/functions/typing/numpy.py index 789bc4efd..a8ea52a8a 100644 --- a/singlestoredb/functions/typing/numpy.py +++ b/singlestoredb/functions/typing/numpy.py @@ -5,6 +5,7 @@ StringArray = StrArray = npt.NDArray[np.str_] BytesArray = npt.NDArray[np.bytes_] +Float16Array = HalfArray = npt.NDArray[np.float16] Float32Array = FloatArray = npt.NDArray[np.float32] Float64Array = DoubleArray = npt.NDArray[np.float64] BoolArray = npt.NDArray[np.bool_] diff --git a/singlestoredb/functions/utils.py b/singlestoredb/functions/utils.py index 6e4aa6295..5b948e2c4 100644 --- a/singlestoredb/functions/utils.py +++ b/singlestoredb/functions/utils.py @@ -198,7 +198,9 @@ def _vector_type_to_numpy_type( vector_type: VectorTypes, ) -> str: """Convert a vector type to a numpy type.""" - if vector_type == VectorTypes.F32: + if vector_type == VectorTypes.F16: + return 'f2' + elif vector_type == VectorTypes.F32: return 'f4' elif vector_type == VectorTypes.F64: return 'f8' @@ -219,7 +221,11 @@ def _vector_type_to_struct_format( ) -> str: """Convert a vector type to a struct format string.""" n = len(vec) - if vector_type == VectorTypes.F32: + if vector_type == VectorTypes.F16: + if isinstance(vec, (bytes, bytearray)): + n = n // 2 + return f'<{n}e' + elif vector_type == VectorTypes.F32: if isinstance(vec, (bytes, bytearray)): n = n // 4 return f'<{n}f' diff --git a/singlestoredb/mysql/connection.py b/singlestoredb/mysql/connection.py index f96ebb0cd..4ff6f37bb 100644 --- a/singlestoredb/mysql/connection.py +++ b/singlestoredb/mysql/connection.py @@ -110,12 +110,14 @@ FIELD_TYPE.INT16_VECTOR_JSON, FIELD_TYPE.INT32_VECTOR_JSON, FIELD_TYPE.INT64_VECTOR_JSON, + FIELD_TYPE.FLOAT16_VECTOR_JSON, FIELD_TYPE.FLOAT32_VECTOR, FIELD_TYPE.FLOAT64_VECTOR, FIELD_TYPE.INT8_VECTOR, FIELD_TYPE.INT16_VECTOR, FIELD_TYPE.INT32_VECTOR, FIELD_TYPE.INT64_VECTOR, + FIELD_TYPE.FLOAT16_VECTOR, } UNSET = 'unset' diff --git a/singlestoredb/mysql/constants/FIELD_TYPE.py b/singlestoredb/mysql/constants/FIELD_TYPE.py index 95c17a4d3..f38953480 100644 --- a/singlestoredb/mysql/constants/FIELD_TYPE.py +++ b/singlestoredb/mysql/constants/FIELD_TYPE.py @@ -40,9 +40,11 @@ INT16_VECTOR_JSON = 2004 INT32_VECTOR_JSON = 2005 INT64_VECTOR_JSON = 2006 +FLOAT16_VECTOR_JSON = 2007 FLOAT32_VECTOR = 3001 FLOAT64_VECTOR = 3002 INT8_VECTOR = 3003 INT16_VECTOR = 3004 INT32_VECTOR = 3005 INT64_VECTOR = 3006 +FLOAT16_VECTOR = 3007 diff --git a/singlestoredb/mysql/constants/VECTOR_TYPE.py b/singlestoredb/mysql/constants/VECTOR_TYPE.py index 2f8b0fa47..65ab714ef 100644 --- a/singlestoredb/mysql/constants/VECTOR_TYPE.py +++ b/singlestoredb/mysql/constants/VECTOR_TYPE.py @@ -4,3 +4,4 @@ INT16 = 4 INT32 = 5 INT64 = 6 +FLOAT16 = 7 diff --git a/singlestoredb/mysql/protocol.py b/singlestoredb/mysql/protocol.py index 378e2bee4..7a7870b22 100644 --- a/singlestoredb/mysql/protocol.py +++ b/singlestoredb/mysql/protocol.py @@ -318,6 +318,11 @@ def _parse_field_descriptor(self, encoding): self.type_code = FIELD_TYPE.INT64_VECTOR else: self.type_code = FIELD_TYPE.INT64_VECTOR_JSON + elif vec_type == VECTOR_TYPE.FLOAT16: + if self.charsetnr == 63: + self.type_code = FIELD_TYPE.FLOAT16_VECTOR + else: + self.type_code = FIELD_TYPE.FLOAT16_VECTOR_JSON else: raise TypeError(f'unrecognized vector data type: {vec_type}') else: diff --git a/singlestoredb/tests/test.sql b/singlestoredb/tests/test.sql index f43259d13..25efc3f02 100644 --- a/singlestoredb/tests/test.sql +++ b/singlestoredb/tests/test.sql @@ -676,6 +676,9 @@ INSERT INTO i64_vectors VALUES(1, '[1, 2, 3]'); INSERT INTO i64_vectors VALUES(2, '[4, 5, 6]'); INSERT INTO i64_vectors VALUES(3, '[-1, -4, 8]'); +-- Float16 vectors require server version 9.1 or later +-- Table creation is handled conditionally in Python code (see utils.py or test setup) + -- -- Boolean test data for UDF testing diff --git a/singlestoredb/tests/test_9_1.sql b/singlestoredb/tests/test_9_1.sql new file mode 100644 index 000000000..7d7ed5988 --- /dev/null +++ b/singlestoredb/tests/test_9_1.sql @@ -0,0 +1,11 @@ +-- Test data for SingleStore 9.1+ features +-- This file is automatically loaded by utils.py if server version >= 9.1 + +-- Float16 (half-precision) vectors +CREATE TABLE IF NOT EXISTS `f16_vectors` ( + id INT(11), + a VECTOR(3, F16) +); +INSERT INTO f16_vectors VALUES(1, '[0.267, 0.535, 0.802]'); +INSERT INTO f16_vectors VALUES(2, '[0.371, 0.557, 0.743]'); +INSERT INTO f16_vectors VALUES(3, '[-0.424, -0.566, 0.707]'); diff --git a/singlestoredb/tests/test_connection.py b/singlestoredb/tests/test_connection.py index a02c4a9f2..7158ca7d7 100755 --- a/singlestoredb/tests/test_connection.py +++ b/singlestoredb/tests/test_connection.py @@ -3097,6 +3097,53 @@ def test_i64_vectors(self): np.array([-1, -4, 8], dtype=np.int64), ) + def test_f16_vectors(self): + if self.conn.driver in ['http', 'https']: + self.skipTest('Data API does not surface vector information') + + # Check server version - float16 requires 9.1 or later + self.cur.execute('select @@memsql_version') + version_str = list(self.cur)[0][0] + # Parse version string like "9.1.2" or "9.1.2-abc123" + version_parts = version_str.split('-')[0].split('.') + major = int(version_parts[0]) + minor = int(version_parts[1]) if len(version_parts) > 1 else 0 + if major < 9 or (major == 9 and minor < 1): + self.skipTest( + f'Float16 vectors require server version 9.1 or later ' + f'(found {version_str})', + ) + + self.cur.execute('show variables like "enable_extended_types_metadata"') + out = list(self.cur) + if not out or out[0][1].lower() == 'off': + self.skipTest('Database engine does not support extended types metadata') + + self.cur.execute('select a from f16_vectors order by id') + out = list(self.cur) + + if hasattr(out[0][0], 'dtype'): + assert out[0][0].dtype is np.dtype('float16') + assert out[1][0].dtype is np.dtype('float16') + assert out[2][0].dtype is np.dtype('float16') + + # Float16 has ~3 decimal digits precision, use lower tolerance + np.testing.assert_array_almost_equal( + out[0][0], + np.array([0.267, 0.535, 0.802], dtype=np.float16), + decimal=2, + ) + np.testing.assert_array_almost_equal( + out[1][0], + np.array([0.371, 0.557, 0.743], dtype=np.float16), + decimal=2, + ) + np.testing.assert_array_almost_equal( + out[2][0], + np.array([-0.424, -0.566, 0.707], dtype=np.float16), + decimal=2, + ) + if __name__ == '__main__': import nose2 diff --git a/singlestoredb/tests/test_udf.py b/singlestoredb/tests/test_udf.py index e806fce0d..ad1545596 100755 --- a/singlestoredb/tests/test_udf.py +++ b/singlestoredb/tests/test_udf.py @@ -321,12 +321,27 @@ def foo(x: np.uint64) -> None: ... def foo(x: float) -> None: ... assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL' + def foo(x: np.float16) -> None: ... + assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS TINYINT NULL' + def foo(x: np.float32) -> None: ... assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS TINYINT NULL' def foo(x: np.float64) -> None: ... assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL' + # Vector float16 (List) + def foo(x: List[np.float16]) -> List[np.float16]: ... + assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS FLOAT NOT NULL' + + # Vector float32 (List) + def foo(x: List[np.float32]) -> List[np.float32]: ... + assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS FLOAT NOT NULL' + + # Vector float64 (List) + def foo(x: List[np.float64]) -> List[np.float64]: ... + assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS DOUBLE NOT NULL' + # # Type collapsing # diff --git a/singlestoredb/tests/utils.py b/singlestoredb/tests/utils.py index 54cdad695..c7cba6808 100644 --- a/singlestoredb/tests/utils.py +++ b/singlestoredb/tests/utils.py @@ -1,11 +1,15 @@ #!/usr/bin/env python # type: ignore """Utilities for testing.""" +import glob import logging import os +import re import uuid from typing import Any from typing import Dict +from typing import List +from typing import Tuple from urllib.parse import urlparse import singlestoredb as s2 @@ -23,6 +27,116 @@ def apply_template(content: str, vars: Dict[str, Any]) -> str: return content +def get_server_version(cursor: Any) -> Tuple[int, int]: + """ + Get the server version as a (major, minor) tuple. + + Parameters + ---------- + cursor : Cursor + Database cursor to execute queries + + Returns + ------- + (int, int) + Tuple of (major_version, minor_version) + """ + cursor.execute('SELECT @@memsql_version') + version_str = cursor.fetchone()[0] + # Parse version string like "9.1.2" or "9.1.2-abc123" + version_parts = version_str.split('-')[0].split('.') + major = int(version_parts[0]) + minor = int(version_parts[1]) if len(version_parts) > 1 else 0 + logger.info(f'Detected server version: {major}.{minor} (full: {version_str})') + return (major, minor) + + +def find_version_specific_sql_files(base_dir: str) -> List[Tuple[int, int, str]]: + """ + Find all version-specific SQL files in the given directory. + + Looks for files matching the pattern test_X_Y.sql where X is major + version and Y is minor version. + + Parameters + ---------- + base_dir : str + Directory to search for SQL files + + Returns + ------- + List[Tuple[int, int, str]] + List of (major, minor, filepath) tuples sorted by version + """ + pattern = os.path.join(base_dir, 'test_*_*.sql') + files = [] + + for filepath in glob.glob(pattern): + filename = os.path.basename(filepath) + # Match pattern: test_X_Y.sql + match = re.match(r'test_(\d+)_(\d+)\.sql$', filename) + if match: + major = int(match.group(1)) + minor = int(match.group(2)) + files.append((major, minor, filepath)) + logger.debug( + f'Found version-specific SQL file: {filename} ' + f'(v{major}.{minor})', + ) + + # Sort by version (major, minor) + files.sort() + return files + + +def load_version_specific_sql( + cursor: Any, + base_dir: str, + server_version: Tuple[int, int], + template_vars: Dict[str, Any], +) -> None: + """ + Load version-specific SQL files based on server version. + + Parameters + ---------- + cursor : Cursor + Database cursor to execute queries + base_dir : str + Directory containing SQL files + server_version : Tuple[int, int] + Server version as (major, minor) + template_vars : Dict[str, Any] + Template variables to apply to SQL content + """ + sql_files = find_version_specific_sql_files(base_dir) + server_major, server_minor = server_version + + for file_major, file_minor, filepath in sql_files: + # Load if server version >= file version + if ( + server_major > file_major or + (server_major == file_major and server_minor >= file_minor) + ): + logger.info( + f'Loading version-specific SQL: {os.path.basename(filepath)} ' + f'(requires {file_major}.{file_minor}, ' + f'server is {server_major}.{server_minor})', + ) + with open(filepath, 'r') as sql_file: + for cmd in sql_file.read().split(';\n'): + cmd = apply_template(cmd.strip(), template_vars) + if cmd: + cmd += ';' + cursor.execute(cmd) + else: + logger.info( + f'Skipping version-specific SQL: {os.path.basename(filepath)} ' + f'(requires {file_major}.{file_minor}, ' + f'server is {server_major}.{server_minor})', + ) + + def load_sql(sql_file: str) -> str: """ Load a file containing SQL code. @@ -111,6 +225,21 @@ def load_sql(sql_file: str) -> str: cur.execute('SET GLOBAL HTTP_API=ON;') cur.execute('RESTART PROXY;') + # Load version-specific SQL files (e.g., test_9_1.sql for 9.1+) + try: + server_version = get_server_version(cur) + sql_dir = os.path.dirname(sql_file) + load_version_specific_sql( + cur, + sql_dir, + server_version, + template_vars, + ) + except Exception as e: + logger.warning( + f'Failed to load version-specific SQL files: {e}', + ) + return dbname, dbexisted