Skip to content

Commit 90c1ae5

Browse files
committed
Throw exception when calling as_numpy() on a BF16 tensor
1 parent dceb95a commit 90c1ae5

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/pb_tensor.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,12 +515,18 @@ PbTensor::Name() const
515515
const py::array*
516516
PbTensor::AsNumpy() const
517517
{
518-
if (IsCPU()) {
519-
return &numpy_array_;
520-
} else {
518+
if (!IsCPU()) {
521519
throw PythonBackendException(
522520
"Tensor is stored in GPU and cannot be converted to NumPy.");
523521
}
522+
523+
if (dtype_ == TRITONSERVER_TYPE_BF16) {
524+
throw PythonBackendException(
525+
"Tensor dtype is BF16 and cannot be converted to NumPy. Use "
526+
"to_dlpack() and from_dlpack() instead.");
527+
}
528+
529+
return &numpy_array_;
524530
}
525531
#endif // TRITON_PB_STUB
526532

0 commit comments

Comments
 (0)