11import hashlib
22from typing import Any , Dict , List , Optional
33
4- import numpy as np
5- from ml_dtypes import bfloat16
4+ from redisvl .utils .utils import lazy_import
5+
6+ # Lazy import numpy
7+ np = lazy_import ("numpy" )
8+ # Lazy import ml_dtypes
9+ bfloat16 = lazy_import ("ml_dtypes.bfloat16" )
610
711from redisvl .schema .fields import VectorDataType
812
@@ -41,6 +45,14 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes:
4145 raise ValueError (
4246 f"Invalid data type: { dtype } . Supported types are: { [t .lower () for t in VectorDataType ]} "
4347 )
48+
49+ # Special handling for bfloat16 which requires explicit import from ml_dtypes
50+ if dtype .lower () == "bfloat16" :
51+ # Import ml_dtypes.bfloat16 directly to ensure it's registered with numpy
52+ from ml_dtypes import bfloat16 as bf16
53+
54+ return np .array (array , dtype = bf16 ).tobytes ()
55+
4456 return np .array (array , dtype = dtype .lower ()).tobytes ()
4557
4658
@@ -52,6 +64,14 @@ def buffer_to_array(buffer: bytes, dtype: str) -> List[Any]:
5264 raise ValueError (
5365 f"Invalid data type: { dtype } . Supported types are: { [t .lower () for t in VectorDataType ]} "
5466 )
67+
68+ # Special handling for bfloat16 which requires explicit import from ml_dtypes
69+ if dtype .lower () == "bfloat16" :
70+ # Import ml_dtypes.bfloat16 directly to ensure it's registered with numpy
71+ from ml_dtypes import bfloat16 as bf16
72+
73+ return np .frombuffer (buffer , dtype = bf16 ).tolist () # type: ignore[return-value]
74+
5575 return np .frombuffer (buffer , dtype = dtype .lower ()).tolist () # type: ignore[return-value]
5676
5777
0 commit comments