Skip to content

Commit 68b9ae8

Browse files
authored
Add bfloat to ndarray_import conversion code (#1228)
1 parent dbe8a3c commit 68b9ae8

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/nb_ndarray.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,9 @@ ndarray_handle *ndarray_import(PyObject *src, const ndarray_config *c,
704704
case (uint8_t) dlpack::dtype_code::Float:
705705
prefix = "float";
706706
break;
707+
case (uint8_t) dlpack::dtype_code::Bfloat:
708+
prefix = "bfloat";
709+
break;
707710
case (uint8_t) dlpack::dtype_code::Complex:
708711
prefix = "complex";
709712
break;

0 commit comments

Comments
 (0)