|
1 | | -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | // |
3 | 3 | // Redistribution and use in source and binary forms, with or without |
4 | 4 | // modification, are permitted provided that the following conditions |
@@ -152,7 +152,10 @@ PbTensor::PbTensor( |
152 | 152 | #ifdef TRITON_PB_STUB |
153 | 153 | if (memory_type_ == TRITONSERVER_MEMORY_CPU || |
154 | 154 | memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) { |
155 | | - if (dtype != TRITONSERVER_TYPE_BYTES) { |
| 155 | + if (dtype == TRITONSERVER_TYPE_BF16) { |
| 156 | + // No native numpy representation for BF16. DLPack should be used instead. |
| 157 | + numpy_array_ = py::none(); |
| 158 | + } else if (dtype != TRITONSERVER_TYPE_BYTES) { |
156 | 159 | py::object numpy_array = |
157 | 160 | py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_); |
158 | 161 | numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_)); |
@@ -512,12 +515,18 @@ PbTensor::Name() const |
512 | 515 | const py::array* |
513 | 516 | PbTensor::AsNumpy() const |
514 | 517 | { |
515 | | - if (IsCPU()) { |
516 | | - return &numpy_array_; |
517 | | - } else { |
| 518 | + if (!IsCPU()) { |
518 | 519 | throw PythonBackendException( |
519 | 520 | "Tensor is stored in GPU and cannot be converted to NumPy."); |
520 | 521 | } |
| 522 | + |
| 523 | + if (dtype_ == TRITONSERVER_TYPE_BF16) { |
| 524 | + throw PythonBackendException( |
| 525 | + "Tensor dtype is BF16 and cannot be converted to NumPy. Use " |
| 526 | + "to_dlpack() and from_dlpack() instead."); |
| 527 | + } |
| 528 | + |
| 529 | + return &numpy_array_; |
521 | 530 | } |
522 | 531 | #endif // TRITON_PB_STUB |
523 | 532 |
|
@@ -643,7 +652,10 @@ PbTensor::PbTensor( |
643 | 652 | #ifdef TRITON_PB_STUB |
644 | 653 | if (memory_type_ == TRITONSERVER_MEMORY_CPU || |
645 | 654 | memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) { |
646 | | - if (dtype_ != TRITONSERVER_TYPE_BYTES) { |
| 655 | + if (dtype_ == TRITONSERVER_TYPE_BF16) { |
| 656 | + // No native numpy representation for BF16. DLPack should be used instead. |
| 657 | + numpy_array_ = py::none(); |
| 658 | + } else if (dtype_ != TRITONSERVER_TYPE_BYTES) { |
647 | 659 | py::object numpy_array = |
648 | 660 | py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_); |
649 | 661 | numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_)); |
|
0 commit comments