@@ -1741,11 +1741,32 @@ ModelInstanceState::SetInputTensors(
 
     input_names->emplace_back(input_name);
 
-    // The shape for the entire input patch, [total_batch_size, ...]
-    std::vector<int64_t> batchn_shape(
-        input_shape, input_shape + input_dims_count);
-    if (supports_batching_) {
-      batchn_shape[0] = total_batch_size;
+    // The shape for the entire input patch,
+    // [total_batch_size, ...] for non-ragged input and
+    // [total_element_count] for ragged input (non-nested tensor)
+    std::vector<int64_t> batchn_shape;
+    if (StateForModel()->IsInputRagged(input_name)) {
+      batchn_shape = std::vector<int64_t>{0};
+      for (size_t idx = 0; idx < request_count; idx++) {
+        TRITONBACKEND_Input* input;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]),
+            TRITONBACKEND_RequestInput(requests[idx], input_name, &input));
+        const int64_t* input_shape;
+        uint32_t input_dims_count;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]), TRITONBACKEND_InputProperties(
+                                      input, nullptr, nullptr, &input_shape,
+                                      &input_dims_count, nullptr, nullptr));
+
+        batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
+      }
+    }
+    else {
+      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+      if (supports_batching_) {
+        batchn_shape[0] = total_batch_size;
+      }
     }
 
     // The input must be in contiguous CPU/GPU memory.
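
Note on the hunk above (editorial annotation, not part of the commit): for an input that allows ragged batching, per-request shapes may differ, so the new branch flattens the batch to a single dimension holding the total element count summed over all requests. A minimal C++ sketch of that arithmetic, where `ElementCount()` is a hypothetical stand-in for the backend's `GetElementCount()` helper:

```cpp
// Illustrative sketch only, not part of this commit: how the ragged branch
// derives batchn_shape. ElementCount() is a hypothetical stand-in for the
// backend's GetElementCount() helper.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t ElementCount(const std::vector<int64_t>& shape) {
  int64_t count = 1;
  for (int64_t d : shape) {
    count *= d;  // product of all dimensions
  }
  return count;
}

int main() {
  // Two requests with mismatched first dimensions, e.g. [3, 4] and [5, 4].
  const std::vector<std::vector<int64_t>> request_shapes = {{3, 4}, {5, 4}};

  // Ragged path: accumulate a single 1-D dimension across requests.
  std::vector<int64_t> batchn_shape{0};
  for (const auto& shape : request_shapes) {
    batchn_shape[0] += ElementCount(shape);
  }

  // The model sees one flat tensor holding all elements from all requests.
  std::cout << batchn_shape[0] << std::endl;  // prints 32 (= 12 + 20)
  return 0;
}
```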
@@ -1866,28 +1887,36 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
+      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
+                                                                     : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+
+      // Batch output doesn't support string data type yet, as it is not trivial
+      // to parse string output
+      const BatchOutput* batch_output = StateForModel()->FindBatchOutput(name);
+      if (batch_output == nullptr) {
+        // Get output shape
+        std::vector<int64_t> batchn_shape;
+        auto shape = output_tensors[op_index].toTensor().sizes();
+        for (auto itr = shape.begin(); itr != shape.end(); itr++) {
+          batchn_shape.push_back(*itr);
+        }
 
-      // Get output shape
-      std::vector<int64_t> batchn_shape;
-      auto shape = output_tensors[op_index].toTensor().sizes();
-      for (auto itr = shape.begin(); itr != shape.end(); itr++) {
-        batchn_shape.push_back(*itr);
-      }
+        if (batchn_shape.size() == 0) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              (std::string("output '") + name +
+               "' is a scalar which is not supported.")
+                  .c_str());
+        }
 
-      if (batchn_shape.size() == 0) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INVALID_ARG,
-            (std::string("output '") + name +
-             "' is a scalar which is not supported.")
-                .c_str());
+        responder.ProcessTensor(
+            name, output_dtype, batchn_shape, output_buffer,
+            memory_type, memory_id);
+      } else {
+        responder.ProcessBatchOutput(
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
-
-      responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                : TRITONSERVER_MEMORY_GPU,
-          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index());
-
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
       torch::List<torch::jit::IValue> output_list =
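
Note on the hunk above (editorial annotation, not part of the commit): the memory type and device ID are now computed once from the output tensor's device, and the responder call is chosen by whether the output has a batch-output entry in the model configuration. A minimal sketch of that dispatch, with `BatchOutput` and `ModelState` as hypothetical stand-ins for the Triton backend-utils types:

```cpp
// Illustrative sketch only, not part of this commit: the dispatch between
// the regular-output path and the batch-output path. BatchOutput and
// ModelState here are hypothetical stand-ins for the backend-utils types.
#include <iostream>
#include <string>
#include <unordered_map>

struct BatchOutput {};  // stand-in; the real type carries the output's config

struct ModelState {
  std::unordered_map<std::string, BatchOutput> batch_outputs;
  // Mirrors the assumed behavior of FindBatchOutput(): nullptr when the
  // output has no batch-output entry in the model configuration.
  const BatchOutput* FindBatchOutput(const std::string& name) const {
    auto it = batch_outputs.find(name);
    return (it == batch_outputs.end()) ? nullptr : &it->second;
  }
};

int main() {
  ModelState state;
  state.batch_outputs.emplace("scattered_output", BatchOutput{});

  for (const std::string& name : {"plain_output", "scattered_output"}) {
    if (state.FindBatchOutput(name) == nullptr) {
      // Regular output: the shape comes from the tensor itself and the
      // responder slices it per request (ProcessTensor in the diff).
      std::cout << name << ": ProcessTensor path" << std::endl;
    } else {
      // Configured batch output: the responder redistributes elements
      // per request (ProcessBatchOutput in the diff).
      std::cout << name << ": ProcessBatchOutput path" << std::endl;
    }
  }
  return 0;
}
```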