
Commit dceb95a

Add proof of concept BF16 support via dlpack

1 parent c8b188f

File tree:
src/pb_stub_utils.cc
src/pb_tensor.cc

2 files changed: +23, −2 lines


src/pb_stub_utils.cc (15 additions, 0 deletions)
@@ -168,6 +168,8 @@ triton_to_pybind_dtype(TRITONSERVER_DataType data_type)
       dtype_numpy = py::dtype(py::format_descriptor<uint8_t>::format());
       break;
     case TRITONSERVER_TYPE_BF16:
+      // Currently skipping this call via `if (BF16)` check, but probably
+      // need to handle this or set some default/invalid dtype.
       throw PythonBackendException("TYPE_BF16 not currently supported.");
     case TRITONSERVER_TYPE_INVALID:
       throw PythonBackendException("Dtype is invalid.");

@@ -240,6 +242,10 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype)
     case TRITONSERVER_TYPE_BYTES:
       throw PythonBackendException(
           "TYPE_BYTES tensors cannot be converted to DLPack.");
+    case TRITONSERVER_TYPE_BF16:
+      dl_code = DLDataTypeCode::kDLBfloat;
+      dt_size = 16;
+      break;

     default:
       throw PythonBackendException(

@@ -301,6 +307,15 @@ dlpack_to_triton_type(const DLDataType& data_type)
     }
   }

+  if (data_type.code == DLDataTypeCode::kDLBfloat) {
+    if (data_type.bits != 16) {
+      throw PythonBackendException(
+          "Expected BF16 tensor to have 16 bits, but had: " +
+          std::to_string(data_type.bits));
+    }
+    return TRITONSERVER_TYPE_BF16;
+  }
+
   return TRITONSERVER_TYPE_INVALID;
 }
 }}} // namespace triton::backend::python
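
For context, here is a minimal standalone sketch of the type mapping this commit introduces, assuming only dlpack.h is available; the function names are hypothetical and std::runtime_error stands in for PythonBackendException. Triton's BF16 maps to DLPack's kDLBfloat code at 16 bits and one lane, and only a 16-bit kDLBfloat maps back.

// Hypothetical standalone sketch of the BF16 mapping above; assumes dlpack.h.
#include <dlpack/dlpack.h>

#include <stdexcept>
#include <string>

// Triton BF16 -> DLPack: bfloat16 type code, 16 bits per element, one lane.
DLDataType Bf16ToDlpack() {
  DLDataType dt;
  dt.code = kDLBfloat;
  dt.bits = 16;
  dt.lanes = 1;
  return dt;
}

// DLPack -> Triton: accept kDLBfloat only at exactly 16 bits, mirroring the
// width check added to dlpack_to_triton_type().
bool DlpackIsBf16(const DLDataType& dt) {
  if (dt.code != kDLBfloat) {
    return false;
  }
  if (dt.bits != 16) {
    throw std::runtime_error(
        "Expected BF16 tensor to have 16 bits, but had: " +
        std::to_string(dt.bits));
  }
  return true;
}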

src/pb_tensor.cc (8 additions, 2 deletions)
@@ -152,7 +152,10 @@ PbTensor::PbTensor(
 #ifdef TRITON_PB_STUB
   if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
       memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
-    if (dtype != TRITONSERVER_TYPE_BYTES) {
+    if (dtype == TRITONSERVER_TYPE_BF16) {
+      // No native numpy representation for BF16. DLPack should be used instead.
+      numpy_array_ = py::none();
+    } else if (dtype != TRITONSERVER_TYPE_BYTES) {
       py::object numpy_array =
           py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
       numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));

@@ -643,7 +646,10 @@ PbTensor::PbTensor(
 #ifdef TRITON_PB_STUB
   if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
       memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
-    if (dtype_ != TRITONSERVER_TYPE_BYTES) {
+    if (dtype_ == TRITONSERVER_TYPE_BF16) {
+      // No native numpy representation for BF16. DLPack should be used instead.
+      numpy_array_ = py::none();
+    } else if (dtype_ != TRITONSERVER_TYPE_BYTES) {
       py::object numpy_array =
           py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
       numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
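
To illustrate the branching above outside the backend, here is a hypothetical standalone pybind11 program (assumes pybind11 with the embedded-interpreter and numpy headers; the names and the float stand-in dtype are illustrative, not the backend's code): a numpy-representable CPU dtype gets wrapped in a py::array, while BF16 yields py::none(), so consumers must go through the DLPack path instead.

// Hypothetical sketch of the CPU-path branching; not the backend's code.
#include <pybind11/embed.h>
#include <pybind11/numpy.h>

#include <vector>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard;  // start/stop an embedded Python interpreter

  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  const bool is_bf16 = false;  // flip to true to exercise the BF16 branch

  py::object numpy_array;
  if (is_bf16) {
    // No native numpy dtype for bfloat16, so no numpy view is exposed at all;
    // consumers are expected to exchange the tensor via DLPack instead.
    numpy_array = py::none();
  } else {
    // Wrap the buffer in a numpy array. Note that pybind11 copies the data
    // here because no base object is supplied.
    numpy_array = py::array(
        py::dtype(py::format_descriptor<float>::format()), {2, 2},
        data.data());
  }
  py::print(numpy_array);  // 2x2 float array, or None on the BF16 branch
  return 0;
}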
