11import hashlib
22from typing import Any , Dict , List , Optional
33
4- import numpy as np
5- from ml_dtypes import bfloat16
4+ from redisvl .utils .utils import lazy_import
5+
6+ # Lazy import numpy
7+ np = lazy_import ("numpy" )
68
79from redisvl .schema .fields import VectorDataType
810
@@ -41,6 +43,13 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes:
4143 raise ValueError (
4244 f"Invalid data type: { dtype } . Supported types are: { [t .lower () for t in VectorDataType ]} "
4345 )
46+
47+ # Special handling for bfloat16 which requires explicit import from ml_dtypes
48+ if dtype .lower () == "bfloat16" :
49+ from ml_dtypes import bfloat16
50+
51+ return np .array (array , dtype = bfloat16 ).tobytes ()
52+
4453 return np .array (array , dtype = dtype .lower ()).tobytes ()
4554
4655
@@ -52,6 +61,14 @@ def buffer_to_array(buffer: bytes, dtype: str) -> List[Any]:
5261 raise ValueError (
5362 f"Invalid data type: { dtype } . Supported types are: { [t .lower () for t in VectorDataType ]} "
5463 )
64+
65+ # Special handling for bfloat16 which requires explicit import from ml_dtypes
66+ # because otherwise the (lazily imported) numpy is unaware of the type
67+ if dtype .lower () == "bfloat16" :
68+ from ml_dtypes import bfloat16
69+
70+ return np .frombuffer (buffer , dtype = bfloat16 ).tolist () # type: ignore[return-value]
71+
5572 return np .frombuffer (buffer , dtype = dtype .lower ()).tolist () # type: ignore[return-value]
5673
5774
0 commit comments