|
1 |
| -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | //
|
3 | 3 | // Redistribution and use in source and binary forms, with or without
|
4 | 4 | // modification, are permitted provided that the following conditions
|
@@ -152,7 +152,10 @@ PbTensor::PbTensor(
|
152 | 152 | #ifdef TRITON_PB_STUB
|
153 | 153 | if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
|
154 | 154 | memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
|
155 |
| - if (dtype != TRITONSERVER_TYPE_BYTES) { |
| 155 | + if (dtype == TRITONSERVER_TYPE_BF16) { |
| 156 | + // No native numpy representation for BF16. DLPack should be used instead. |
| 157 | + numpy_array_ = py::none(); |
| 158 | + } else if (dtype != TRITONSERVER_TYPE_BYTES) { |
156 | 159 | py::object numpy_array =
|
157 | 160 | py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
|
158 | 161 | numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
|
@@ -512,12 +515,18 @@ PbTensor::Name() const
|
512 | 515 | const py::array*
|
513 | 516 | PbTensor::AsNumpy() const
|
514 | 517 | {
|
515 |
| - if (IsCPU()) { |
516 |
| - return &numpy_array_; |
517 |
| - } else { |
| 518 | + if (!IsCPU()) { |
518 | 519 | throw PythonBackendException(
|
519 | 520 | "Tensor is stored in GPU and cannot be converted to NumPy.");
|
520 | 521 | }
|
| 522 | + |
| 523 | + if (dtype_ == TRITONSERVER_TYPE_BF16) { |
| 524 | + throw PythonBackendException( |
| 525 | + "Tensor dtype is BF16 and cannot be converted to NumPy. Use " |
| 526 | + "to_dlpack() and from_dlpack() instead."); |
| 527 | + } |
| 528 | + |
| 529 | + return &numpy_array_; |
521 | 530 | }
|
522 | 531 | #endif // TRITON_PB_STUB
|
523 | 532 |
|
@@ -643,7 +652,10 @@ PbTensor::PbTensor(
|
643 | 652 | #ifdef TRITON_PB_STUB
|
644 | 653 | if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
|
645 | 654 | memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
|
646 |
| - if (dtype_ != TRITONSERVER_TYPE_BYTES) { |
| 655 | + if (dtype_ == TRITONSERVER_TYPE_BF16) { |
| 656 | + // No native numpy representation for BF16. DLPack should be used instead. |
| 657 | + numpy_array_ = py::none(); |
| 658 | + } else if (dtype_ != TRITONSERVER_TYPE_BYTES) { |
647 | 659 | py::object numpy_array =
|
648 | 660 | py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
|
649 | 661 | numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
|
|
0 commit comments