Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions accel.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,14 @@
#define MYSQL_TYPE_INT16_VECTOR_JSON 2004
#define MYSQL_TYPE_INT32_VECTOR_JSON 2005
#define MYSQL_TYPE_INT64_VECTOR_JSON 2006
#define MYSQL_TYPE_FLOAT16_VECTOR_JSON 2007
#define MYSQL_TYPE_FLOAT32_VECTOR 3001
#define MYSQL_TYPE_FLOAT64_VECTOR 3002
#define MYSQL_TYPE_INT8_VECTOR 3003
#define MYSQL_TYPE_INT16_VECTOR 3004
#define MYSQL_TYPE_INT32_VECTOR 3005
#define MYSQL_TYPE_INT64_VECTOR 3006
#define MYSQL_TYPE_FLOAT16_VECTOR 3007

#define MYSQL_TYPE_CHAR MYSQL_TYPE_TINY
#define MYSQL_TYPE_INTERVAL MYSQL_TYPE_ENUM
Expand Down Expand Up @@ -503,6 +505,7 @@ typedef struct {
PyObject *int64;
PyObject *float32;
PyObject *float64;
PyObject *float16;
PyObject *unpack;
PyObject *decode;
PyObject *frombuffer;
Expand Down Expand Up @@ -541,7 +544,7 @@ typedef struct {
PyObject *namedtuple_kwargs;
PyObject *create_numpy_array_args;
PyObject *create_numpy_array_kwargs;
PyObject *create_numpy_array_kwargs_vector[7];
PyObject *create_numpy_array_kwargs_vector[8];
PyObject *struct_unpack_args;
PyObject *bson_decode_args;
} PyObjects;
Expand Down Expand Up @@ -1565,8 +1568,8 @@ static PyObject *read_row_from_packet(
PyObject *py_str = NULL;
PyObject *py_memview = NULL;
char end = '\0';
char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q"};
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8};
char *cast_type_codes[] = {"", "f", "d", "b", "h", "i", "q", "e"};
int item_type_lengths[] = {0, 4, 8, 1, 2, 4, 8, 2};

int sign = 1;
int year = 0;
Expand Down Expand Up @@ -1826,6 +1829,7 @@ static PyObject *read_row_from_packet(
case MYSQL_TYPE_INT16_VECTOR_JSON:
case MYSQL_TYPE_INT32_VECTOR_JSON:
case MYSQL_TYPE_INT64_VECTOR_JSON:
case MYSQL_TYPE_FLOAT16_VECTOR_JSON:
if (!py_state->encodings[i]) {
py_item = PyBytes_FromStringAndSize(out, out_l);
if (!py_item) goto error;
Expand All @@ -1847,7 +1851,7 @@ static PyObject *read_row_from_packet(
// Parse JSON string.
if ((py_state->type_codes[i] == MYSQL_TYPE_JSON && py_state->options.parse_json)
|| (py_state->type_codes[i] >= MYSQL_TYPE_FLOAT32_VECTOR_JSON
&& py_state->type_codes[i] <= MYSQL_TYPE_INT64_VECTOR_JSON)) {
&& py_state->type_codes[i] <= MYSQL_TYPE_FLOAT16_VECTOR_JSON)) {
py_str = py_item;
py_item = PyObject_CallFunctionObjArgs(PyFunc.json_loads, py_str, NULL);
Py_CLEAR(py_str);
Expand All @@ -1862,6 +1866,7 @@ static PyObject *read_row_from_packet(
case MYSQL_TYPE_INT16_VECTOR_JSON:
case MYSQL_TYPE_INT32_VECTOR_JSON:
case MYSQL_TYPE_INT64_VECTOR_JSON:
case MYSQL_TYPE_FLOAT16_VECTOR_JSON:
CHECKRC(PyTuple_SetItem(PyObj.create_numpy_array_args, 0, py_item));
py_item = PyObject_Call(
PyFunc.numpy_array,
Expand All @@ -1880,6 +1885,7 @@ static PyObject *read_row_from_packet(
case MYSQL_TYPE_INT16_VECTOR:
case MYSQL_TYPE_INT32_VECTOR:
case MYSQL_TYPE_INT64_VECTOR:
case MYSQL_TYPE_FLOAT16_VECTOR:
{
int type_idx = py_state->type_codes[i] % 1000;

Expand Down Expand Up @@ -4844,6 +4850,7 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
PyStr.int64 = PyUnicode_FromString("int64");
PyStr.float32 = PyUnicode_FromString("float32");
PyStr.float64 = PyUnicode_FromString("float64");
PyStr.float16 = PyUnicode_FromString("float16");
PyStr.unpack = PyUnicode_FromString("unpack");
PyStr.decode = PyUnicode_FromString("decode");
PyStr.frombuffer = PyUnicode_FromString("frombuffer");
Expand Down Expand Up @@ -4921,6 +4928,11 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[6], "dtype", PyStr.int64)) {
goto error;
}
PyObj.create_numpy_array_kwargs_vector[7] = PyDict_New();
if (!PyObj.create_numpy_array_kwargs_vector[7]) goto error;
if (PyDict_SetItemString(PyObj.create_numpy_array_kwargs_vector[7], "dtype", PyStr.float16)) {
goto error;
}

PyObj.struct_unpack_args = PyTuple_New(2);
if (!PyObj.struct_unpack_args) goto error;
Expand Down
58 changes: 58 additions & 0 deletions singlestoredb/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,62 @@ def float32_vector_or_none(x: Optional[bytes]) -> Optional[Any]:
return struct.unpack(f'<{len(x)//4}f', x)


def float16_vector_json_or_none(x: Optional[str]) -> Optional[Any]:
    """
    Convert value to float16 array.

    Parameters
    ----------
    x : str or None
        JSON array

    Returns
    -------
    float16 numpy array
        If input value is not None and numpy is installed
    float Python list
        If input value is not None and numpy is not installed
    None
        If input value is None

    """
    if x is None:
        return None

    if has_numpy:
        return numpy.array(json_loads(x), dtype=numpy.float16)

    # Build a real list: a bare ``map`` object is a one-shot iterator that
    # cannot be indexed, passed to ``len`` or iterated twice, and the
    # documented return type for the non-numpy path is a list.
    return [float(item) for item in json_loads(x)]


def float16_vector_or_none(x: Optional[bytes]) -> Optional[Any]:
    """
    Decode a packed block of half-precision floats.

    Parameters
    ----------
    x : bytes or None
        Little-endian block of bytes, two bytes per value.

    Returns
    -------
    float16 numpy array
        If input value is not None and numpy is installed
    tuple of Python floats
        If input value is not None and numpy is not installed
    None
        If input value is None

    """
    if x is not None:
        if not has_numpy:
            # '<' = little-endian, 'e' = IEEE 754 half precision (2 bytes).
            fmt = f'<{len(x)//2}e'
            return struct.unpack(fmt, x)
        return numpy.frombuffer(x, dtype=numpy.float16)

    return None


def float64_vector_json_or_none(x: Optional[str]) -> Optional[Any]:
"""
Convert value to float64 array.
Expand Down Expand Up @@ -941,10 +997,12 @@ def bson_or_none(x: Optional[bytes]) -> Optional[Any]:
2004: int16_vector_json_or_none,
2005: int32_vector_json_or_none,
2006: int64_vector_json_or_none,
2007: float16_vector_json_or_none,
3001: float32_vector_or_none,
3002: float64_vector_or_none,
3003: int8_vector_or_none,
3004: int16_vector_or_none,
3005: int32_vector_or_none,
3006: int64_vector_or_none,
3007: float16_vector_or_none,
}
1 change: 1 addition & 0 deletions singlestoredb/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .utils import VectorTypes


F16 = VectorTypes.F16
F32 = VectorTypes.F32
F64 = VectorTypes.F64
I8 = VectorTypes.I8
Expand Down
1 change: 1 addition & 0 deletions singlestoredb/functions/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class NoDefaultType:
'uint16': 'SMALLINT UNSIGNED',
'uint32': 'INT UNSIGNED',
'uint64': 'BIGINT UNSIGNED',
'float16': 'FLOAT',
'float32': 'FLOAT',
'float64': 'DOUBLE',
'str': 'TEXT',
Expand Down
1 change: 1 addition & 0 deletions singlestoredb/functions/typing/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

StringArray = StrArray = npt.NDArray[np.str_]
BytesArray = npt.NDArray[np.bytes_]
Float16Array = HalfArray = npt.NDArray[np.float16]
Float32Array = FloatArray = npt.NDArray[np.float32]
Float64Array = DoubleArray = npt.NDArray[np.float64]
BoolArray = npt.NDArray[np.bool_]
Expand Down
10 changes: 8 additions & 2 deletions singlestoredb/functions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def _vector_type_to_numpy_type(
vector_type: VectorTypes,
) -> str:
"""Convert a vector type to a numpy type."""
if vector_type == VectorTypes.F32:
if vector_type == VectorTypes.F16:
return 'f2'
elif vector_type == VectorTypes.F32:
return 'f4'
elif vector_type == VectorTypes.F64:
return 'f8'
Expand All @@ -219,7 +221,11 @@ def _vector_type_to_struct_format(
) -> str:
"""Convert a vector type to a struct format string."""
n = len(vec)
if vector_type == VectorTypes.F32:
if vector_type == VectorTypes.F16:
if isinstance(vec, (bytes, bytearray)):
n = n // 2
return f'<{n}e'
elif vector_type == VectorTypes.F32:
if isinstance(vec, (bytes, bytearray)):
n = n // 4
return f'<{n}f'
Expand Down
2 changes: 2 additions & 0 deletions singlestoredb/mysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@
FIELD_TYPE.INT16_VECTOR_JSON,
FIELD_TYPE.INT32_VECTOR_JSON,
FIELD_TYPE.INT64_VECTOR_JSON,
FIELD_TYPE.FLOAT16_VECTOR_JSON,
FIELD_TYPE.FLOAT32_VECTOR,
FIELD_TYPE.FLOAT64_VECTOR,
FIELD_TYPE.INT8_VECTOR,
FIELD_TYPE.INT16_VECTOR,
FIELD_TYPE.INT32_VECTOR,
FIELD_TYPE.INT64_VECTOR,
FIELD_TYPE.FLOAT16_VECTOR,
}

UNSET = 'unset'
Expand Down
2 changes: 2 additions & 0 deletions singlestoredb/mysql/constants/FIELD_TYPE.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@
INT16_VECTOR_JSON = 2004
INT32_VECTOR_JSON = 2005
INT64_VECTOR_JSON = 2006
FLOAT16_VECTOR_JSON = 2007
FLOAT32_VECTOR = 3001
FLOAT64_VECTOR = 3002
INT8_VECTOR = 3003
INT16_VECTOR = 3004
INT32_VECTOR = 3005
INT64_VECTOR = 3006
FLOAT16_VECTOR = 3007
1 change: 1 addition & 0 deletions singlestoredb/mysql/constants/VECTOR_TYPE.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
INT16 = 4
INT32 = 5
INT64 = 6
FLOAT16 = 7
5 changes: 5 additions & 0 deletions singlestoredb/mysql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def _parse_field_descriptor(self, encoding):
self.type_code = FIELD_TYPE.INT64_VECTOR
else:
self.type_code = FIELD_TYPE.INT64_VECTOR_JSON
elif vec_type == VECTOR_TYPE.FLOAT16:
if self.charsetnr == 63:
self.type_code = FIELD_TYPE.FLOAT16_VECTOR
else:
self.type_code = FIELD_TYPE.FLOAT16_VECTOR_JSON
else:
raise TypeError(f'unrecognized vector data type: {vec_type}')
else:
Expand Down
3 changes: 3 additions & 0 deletions singlestoredb/tests/test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,9 @@ INSERT INTO i64_vectors VALUES(1, '[1, 2, 3]');
INSERT INTO i64_vectors VALUES(2, '[4, 5, 6]');
INSERT INTO i64_vectors VALUES(3, '[-1, -4, 8]');

-- Float16 vectors require server version 9.1 or later
-- The f16_vectors table is created by test_9_1.sql, which utils.py loads only when the server version is 9.1 or later


--
-- Boolean test data for UDF testing
Expand Down
11 changes: 11 additions & 0 deletions singlestoredb/tests/test_9_1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Test data for SingleStore 9.1+ features
-- This file is automatically loaded by utils.py if server version >= 9.1

-- Float16 (half-precision) vectors
-- IF NOT EXISTS keeps the CREATE from failing when the table is already present.
CREATE TABLE IF NOT EXISTS `f16_vectors` (
    id INT(11),
    -- Three-element vector column with F16 (half-precision) elements.
    a VECTOR(3, F16)
);
-- One row per id; each vector is written as a JSON-style array literal.
INSERT INTO f16_vectors VALUES(1, '[0.267, 0.535, 0.802]');
INSERT INTO f16_vectors VALUES(2, '[0.371, 0.557, 0.743]');
INSERT INTO f16_vectors VALUES(3, '[-0.424, -0.566, 0.707]');
47 changes: 47 additions & 0 deletions singlestoredb/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3097,6 +3097,53 @@ def test_i64_vectors(self):
np.array([-1, -4, 8], dtype=np.int64),
)

def test_f16_vectors(self):
    """Verify that rows of ``f16_vectors`` come back as float16 vectors."""
    if self.conn.driver in ['http', 'https']:
        self.skipTest('Data API does not surface vector information')

    # Float16 vectors only exist on server 9.1 and later, so look at the
    # reported version before touching the table.
    self.cur.execute('select @@memsql_version')
    version_str = list(self.cur)[0][0]
    # Versions look like "9.1.2" or "9.1.2-abc123": drop any "-suffix",
    # then split the dotted release number.
    release = version_str.partition('-')[0]
    numbers = release.split('.')
    major = int(numbers[0])
    minor = int(numbers[1]) if len(numbers) > 1 else 0
    if (major, minor) < (9, 1):
        self.skipTest(
            f'Float16 vectors require server version 9.1 or later '
            f'(found {version_str})',
        )

    self.cur.execute('show variables like "enable_extended_types_metadata"')
    settings = list(self.cur)
    if not settings or settings[0][1].lower() == 'off':
        self.skipTest('Database engine does not support extended types metadata')

    self.cur.execute('select a from f16_vectors order by id')
    rows = list(self.cur)

    # The dtype is only checked when the first value exposes one.
    if hasattr(rows[0][0], 'dtype'):
        for idx in range(3):
            assert rows[idx][0].dtype is np.dtype('float16')

    # Float16 has ~3 decimal digits precision, use lower tolerance
    expected = (
        [0.267, 0.535, 0.802],
        [0.371, 0.557, 0.743],
        [-0.424, -0.566, 0.707],
    )
    for idx, values in enumerate(expected):
        np.testing.assert_array_almost_equal(
            rows[idx][0],
            np.array(values, dtype=np.float16),
            decimal=2,
        )


if __name__ == '__main__':
import nose2
Expand Down
15 changes: 15 additions & 0 deletions singlestoredb/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,27 @@ def foo(x: np.uint64) -> None: ...
def foo(x: float) -> None: ...
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL'

def foo(x: np.float16) -> None: ...
assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS TINYINT NULL'

def foo(x: np.float32) -> None: ...
assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS TINYINT NULL'

def foo(x: np.float64) -> None: ...
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL'

# Vector float16 (List)
def foo(x: List[np.float16]) -> List[np.float16]: ...
assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS FLOAT NOT NULL'

# Vector float32 (List)
def foo(x: List[np.float32]) -> List[np.float32]: ...
assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS FLOAT NOT NULL'

# Vector float64 (List)
def foo(x: List[np.float64]) -> List[np.float64]: ...
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS DOUBLE NOT NULL'

#
# Type collapsing
#
Expand Down
Loading
Loading