@@ -48,7 +48,62 @@ typedef SSIZE_T ssize_t;
4848namespace triton { namespace backend { namespace python {
4949
5050#ifdef TRITON_PB_STUB
51- py::array deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size);
51+ py::array
52+ deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size)
53+ {
54+ if (data_size == 0 ) {
55+ py::module numpy = py::module::import (" numpy" );
56+ return numpy.attr (" empty" )(0 , py::dtype (" object" ));
57+ }
58+
59+ // First pass: count the number of strings and calculate total size
60+ size_t offset = 0 ;
61+ size_t num_strings = 0 ;
62+ size_t total_string_size = 0 ;
63+
64+ while (offset < data_size) {
65+ if (offset + 4 > data_size) {
66+ throw PythonBackendException (
67+ " Invalid bytes tensor data: incomplete length field" );
68+ }
69+
70+ // Read 4-byte length (little-endian)
71+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
72+ offset += 4 ;
73+
74+ if (offset + length > data_size) {
75+ throw PythonBackendException (
76+ " Invalid bytes tensor data: string extends beyond buffer" );
77+ }
78+
79+ num_strings++;
80+ total_string_size += length;
81+ offset += length;
82+ }
83+
84+ // Create numpy array of objects using pybind11's numpy module
85+ py::module numpy = py::module::import (" numpy" );
86+ py::array result = numpy.attr (" empty" )(num_strings, py::dtype (" object" ));
87+ auto result_ptr = static_cast <PyObject**>(result.request ().ptr );
88+
89+ // Second pass: extract strings
90+ offset = 0 ;
91+ size_t string_index = 0 ;
92+
93+ while (offset < data_size) {
94+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
95+ offset += 4 ;
96+
97+ // Create Python bytes object using pybind11
98+ py::bytes bytes_obj (reinterpret_cast <const char *>(data + offset), length);
99+ Py_INCREF (bytes_obj.ptr ()); // Increment reference count
100+ result_ptr[string_index] = bytes_obj.ptr ();
101+ string_index++;
102+ offset += length;
103+ }
104+
105+ return result;
106+ }
52107
53108PbTensor::PbTensor (const std::string& name, py::array& numpy_array)
54109 : name_(name)
@@ -166,9 +221,9 @@ PbTensor::PbTensor(
166221 py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
167222 numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
168223 } else {
169- numpy_array_ = deserialize_bytes_tensor_cpp (
170- static_cast <const uint8_t *>(memory_ptr_), byte_size)
171- .attr (" reshape" )(dims );
224+ py::object numpy_array = deserialize_bytes_tensor_cpp (
225+ static_cast <const uint8_t *>(memory_ptr_), byte_size_);
226+ numpy_array_ = numpy_array .attr (" reshape" )(dims_ );
172227 }
173228 } else {
174229 numpy_array_ = py::none ();
@@ -235,62 +290,6 @@ delete_unused_dltensor(PyObject* dlp)
235290 }
236291}
237292
238- py::array
239- deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size)
240- {
241- if (data_size == 0 ) {
242- py::module numpy = py::module::import (" numpy" );
243- return numpy.attr (" empty" )(0 , py::dtype (" object" ));
244- }
245-
246- // First pass: count the number of strings and calculate total size
247- size_t offset = 0 ;
248- size_t num_strings = 0 ;
249- size_t total_string_size = 0 ;
250-
251- while (offset < data_size) {
252- if (offset + 4 > data_size) {
253- throw PythonBackendException (
254- " Invalid bytes tensor data: incomplete length field" );
255- }
256-
257- // Read 4-byte length (little-endian)
258- uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
259- offset += 4 ;
260-
261- if (offset + length > data_size) {
262- throw PythonBackendException (
263- " Invalid bytes tensor data: string extends beyond buffer" );
264- }
265-
266- num_strings++;
267- total_string_size += length;
268- offset += length;
269- }
270-
271- // Create numpy array of objects using pybind11's numpy module
272- py::module numpy = py::module::import (" numpy" );
273- py::array result = numpy.attr (" empty" )(num_strings, py::dtype (" object" ));
274- auto result_ptr = static_cast <PyObject**>(result.request ().ptr );
275-
276- // Second pass: extract strings
277- offset = 0 ;
278- size_t string_index = 0 ;
279-
280- while (offset < data_size) {
281- uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
282- offset += 4 ;
283-
284- // Create Python bytes object using pybind11
285- py::bytes bytes_obj (reinterpret_cast <const char *>(data + offset), length);
286- Py_INCREF (bytes_obj.ptr ()); // Increment reference count
287- result_ptr[string_index] = bytes_obj.ptr ();
288- string_index++;
289- offset += length;
290- }
291-
292- return result;
293- }
294293
295294std::shared_ptr<PbTensor>
296295PbTensor::FromNumpy (const std::string& name, py::array& numpy_array)
@@ -726,9 +725,9 @@ PbTensor::PbTensor(
726725 py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
727726 numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
728727 } else {
729- numpy_array_ = deserialize_bytes_tensor_cpp (
730- static_cast <const uint8_t *>(memory_ptr_), byte_size_)
731- .attr (" reshape" )(dims_);
728+ py::object numpy_array = deserialize_bytes_tensor_cpp (
729+ static_cast <const uint8_t *>(memory_ptr_), byte_size_);
730+ numpy_array_ = numpy_array .attr (" reshape" )(dims_);
732731 }
733732 } else {
734733 numpy_array_ = py::none ();
0 commit comments