Skip to content

Commit 2b12abe

Browse files
authored
feat: Add BF16 tensor support via dlpack (#371)
1 parent c8b188f commit 2b12abe

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,6 +1557,10 @@ input0 = pb_utils.Tensor.from_dlpack("INPUT0", pytorch_tensor)
15571557
This method only supports contiguous Tensors that are in C-order. If the tensor
15581558
is not C-order contiguous an exception will be raised.
15591559

1560+
For Python models with input or output tensors of type BFloat16 (BF16), the
1561+
`as_numpy()` method is not supported, and the `from_dlpack` and `to_dlpack`
1562+
methods must be used instead.
1563+
15601564
## `pb_utils.Tensor.is_cpu() -> bool`
15611565

15621566
This function can be used to check whether a tensor is placed in CPU or not.

src/pb_stub_utils.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -168,6 +168,8 @@ triton_to_pybind_dtype(TRITONSERVER_DataType data_type)
168168
dtype_numpy = py::dtype(py::format_descriptor<uint8_t>::format());
169169
break;
170170
case TRITONSERVER_TYPE_BF16:
171+
// NOTE: Currently skipping this call via `if (BF16)` check, but may
172+
// want to better handle this or set some default/invalid dtype.
171173
throw PythonBackendException("TYPE_BF16 not currently supported.");
172174
case TRITONSERVER_TYPE_INVALID:
173175
throw PythonBackendException("Dtype is invalid.");
@@ -240,6 +242,10 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype)
240242
case TRITONSERVER_TYPE_BYTES:
241243
throw PythonBackendException(
242244
"TYPE_BYTES tensors cannot be converted to DLPack.");
245+
case TRITONSERVER_TYPE_BF16:
246+
dl_code = DLDataTypeCode::kDLBfloat;
247+
dt_size = 16;
248+
break;
243249

244250
default:
245251
throw PythonBackendException(
@@ -301,6 +307,15 @@ dlpack_to_triton_type(const DLDataType& data_type)
301307
}
302308
}
303309

310+
if (data_type.code == DLDataTypeCode::kDLBfloat) {
311+
if (data_type.bits != 16) {
312+
throw PythonBackendException(
313+
"Expected BF16 tensor to have 16 bits, but had: " +
314+
std::to_string(data_type.bits));
315+
}
316+
return TRITONSERVER_TYPE_BF16;
317+
}
318+
304319
return TRITONSERVER_TYPE_INVALID;
305320
}
306321
}}} // namespace triton::backend::python

src/pb_tensor.cc

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -152,7 +152,10 @@ PbTensor::PbTensor(
152152
#ifdef TRITON_PB_STUB
153153
if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
154154
memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
155-
if (dtype != TRITONSERVER_TYPE_BYTES) {
155+
if (dtype == TRITONSERVER_TYPE_BF16) {
156+
// No native numpy representation for BF16. DLPack should be used instead.
157+
numpy_array_ = py::none();
158+
} else if (dtype != TRITONSERVER_TYPE_BYTES) {
156159
py::object numpy_array =
157160
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
158161
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
@@ -512,12 +515,18 @@ PbTensor::Name() const
512515
const py::array*
513516
PbTensor::AsNumpy() const
514517
{
515-
if (IsCPU()) {
516-
return &numpy_array_;
517-
} else {
518+
if (!IsCPU()) {
518519
throw PythonBackendException(
519520
"Tensor is stored in GPU and cannot be converted to NumPy.");
520521
}
522+
523+
if (dtype_ == TRITONSERVER_TYPE_BF16) {
524+
throw PythonBackendException(
525+
"Tensor dtype is BF16 and cannot be converted to NumPy. Use "
526+
"to_dlpack() and from_dlpack() instead.");
527+
}
528+
529+
return &numpy_array_;
521530
}
522531
#endif // TRITON_PB_STUB
523532

@@ -643,7 +652,10 @@ PbTensor::PbTensor(
643652
#ifdef TRITON_PB_STUB
644653
if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
645654
memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
646-
if (dtype_ != TRITONSERVER_TYPE_BYTES) {
655+
if (dtype_ == TRITONSERVER_TYPE_BF16) {
656+
// No native numpy representation for BF16. DLPack should be used instead.
657+
numpy_array_ = py::none();
658+
} else if (dtype_ != TRITONSERVER_TYPE_BYTES) {
647659
py::object numpy_array =
648660
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
649661
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));

0 commit comments

Comments
 (0)