@@ -535,6 +535,9 @@ class ModelInstanceState : public BackendModelInstance {
 
   // If the input to the tensor is a dictionary of tensors.
   bool is_dict_input_;
+
+  // If the model supports batching.
+  bool supports_batching_;
 };
 
 TRITONSERVER_Error*
@@ -607,6 +610,7 @@ ModelInstanceState::ModelInstanceState(
       expected_input_cnt += 1;
     }
   }
+  supports_batching_ = model_state_->MaxBatchSize() > 0;
 
   THROW_IF_BACKEND_INSTANCE_ERROR(ValidateInputs(expected_input_cnt));
   THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs());
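
Note: `MaxBatchSize()` is fixed by the model configuration, so the new member can be computed once in the constructor and reused by every later check instead of being re-derived in each validation routine. A minimal standalone sketch of the cached-flag pattern (the class and accessor names here are illustrative, not the backend's):

    // Sketch only: cache a config-derived flag at construction so hot paths
    // test a member instead of re-querying model state on every request.
    class InstanceStateSketch {
     public:
      explicit InstanceStateSketch(int max_batch_size)
          : supports_batching_(max_batch_size > 0) {}
      bool SupportsBatching() const { return supports_batching_; }

     private:
      const bool supports_batching_;  // true when max_batch_size > 0
    };
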
@@ -787,7 +791,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         "specified.");
   }
 
-  bool supports_batching = model_state_->MaxBatchSize() > 0;
   NamingConvention naming_convention;
   RETURN_IF_ERROR(GetNamingConvention(&naming_convention, allowed_inputs));
 
@@ -837,8 +840,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
               .c_str());
     }
 
-    // Validate shape for String inputs. Only allow 1 dimension and no
-    // batching.
+    // Validate shape for String inputs. Only allow 1 dimension.
     if (io_dtype == "TYPE_STRING") {
       // If a reshape is provided for the input then use that when
       // validating the model shapes.
@@ -850,7 +852,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
       RETURN_IF_ERROR(ParseShape(io, "dims", &dims));
     }
 
-    if ((dims.size() > 1) || supports_batching) {
+    if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
       return TRITONSERVER_ErrorNew(
           TRITONSERVER_ERROR_INTERNAL,
           ("Triton only supports 1 dimensional List of String as input for "
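
The arithmetic change is the point of this hunk: the batch dimension is implicit in a Triton model configuration, so a string input's effective rank is `dims.size()` plus one when batching is enabled, whereas the old `|| supports_batching` clause rejected string I/O on batching models outright. A small standalone sketch of the assumed rule (`StringShapeOk` is a hypothetical helper, not part of the backend):

    #include <cstddef>

    // Effective rank = configured dims + implicit batch dim (if any);
    // Triton allows at most one effective dimension for List-of-string I/O.
    bool StringShapeOk(std::size_t config_dims, bool supports_batching) {
      return (config_dims + (supports_batching ? 1 : 0)) <= 1;
    }

    // StringShapeOk(1, false) -> true   dims: [-1] on a non-batching model
    // StringShapeOk(0, true)  -> true   batch dimension only
    // StringShapeOk(1, true)  -> false  batch dim plus one more -> rank 2
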
@@ -880,7 +882,6 @@ ModelInstanceState::ValidateOutputs()
         "specified.");
   }
 
-  const bool supports_batching = model_state_->MaxBatchSize() > 0;
   NamingConvention naming_convention;
   RETURN_IF_ERROR(GetNamingConvention(&naming_convention, {}));
 
@@ -917,8 +918,7 @@ ModelInstanceState::ValidateOutputs()
               .c_str());
     }
 
-    // Validate shape for String outputs. Only allow 1 dimension and no
-    // batching.
+    // Validate shape for String outputs. Only allow 1 dimension.
     if (io_dtype == "TYPE_STRING") {
       // If a reshape is provided for the output then use that when
       // validating the model shapes.
@@ -930,7 +930,7 @@ ModelInstanceState::ValidateOutputs()
       RETURN_IF_ERROR(ParseShape(io, "dims", &dims));
     }
 
-    if ((dims.size() > 1) || supports_batching) {
+    if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
       return TRITONSERVER_ErrorNew(
           TRITONSERVER_ERROR_INTERNAL,
           ("Triton only supports 1 dimensional List of String as output for "
@@ -1015,7 +1015,7 @@ ModelInstanceState::ProcessRequests(
   for (size_t i = 0; i < request_count; i++) {
     if (max_batch_size > 0) {
       // Retrieve the batch size from one of the inputs, if the model
-      // supports batching, the first dimension size is batch size
+      // supports batching, the first dimension size is batch size.
       TRITONBACKEND_Input* input;
       TRITONSERVER_Error* err =
           TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input);
@@ -1294,7 +1294,7 @@ ModelInstanceState::Execute(
       if (list_output.elementType()->kind() != c10::TypeKind::StringType) {
         throw std::invalid_argument(
             "output at index " + std::to_string(op_index) +
-            " must be of type Tensor or List[str], recieved List[" +
+            " must be of type Tensor or List[str], received List[" +
             list_output.elementType()->str() + "]");
       }
       output_tensors->push_back(m_op);
@@ -1310,7 +1310,7 @@ ModelInstanceState::Execute(
     auto list_output = model_outputs_.toList();
     if (list_output.elementType()->kind() != c10::TypeKind::StringType) {
       throw std::invalid_argument(
-          "output must be of type Tensor or List[str], recieved List[" +
+          "output must be of type Tensor or List[str], received List[" +
           list_output.elementType()->str() + "]");
     }
     output_tensors->push_back(model_outputs_);
@@ -1505,8 +1505,7 @@ GetContiguousInputContent(
 }
 
 void
-FillStringTensor(
-    torch::List<std::string>* input_list, const size_t idx, const size_t cnt)
+FillStringTensor(torch::List<std::string>* input_list, const size_t cnt)
 {
   for (size_t c = 0; c < cnt; ++c) {
     input_list->push_back("");
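
Dropping the start-index parameter works because `torch::List::push_back` always appends, so the write position is implicit in the list's current size; callers only need to say how many placeholder strings to add. The same append-only idea in a dependency-free sketch (`FillPlaceholders` is a hypothetical stand-in, using `std::vector` instead of `torch::List`):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Append cnt empty placeholder strings; no start index is needed.
    void FillPlaceholders(std::vector<std::string>* list, std::size_t cnt) {
      for (std::size_t c = 0; c < cnt; ++c) {
        list->push_back("");  // position is implicit in the current size
      }
    }
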
@@ -1517,9 +1516,8 @@ bool
 SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
-    const size_t request_element_cnt, const size_t tensor_offset,
-    TRITONBACKEND_Response** response, cudaStream_t stream,
-    const char* host_policy_name)
+    const size_t request_element_cnt, TRITONBACKEND_Response** response,
+    cudaStream_t stream, const char* host_policy_name)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1537,9 +1535,7 @@ SetStringInputTensor(
         stream, &cuda_copy);
     if (err != nullptr) {
       RESPOND_AND_SET_NULL_IF_ERROR(response, err);
-      FillStringTensor(
-          input_list, tensor_offset + element_idx,
-          request_element_cnt - element_idx);
+      FillStringTensor(input_list, request_element_cnt - element_idx);
       return cuda_copy;
     }
 
@@ -1564,9 +1560,6 @@ SetStringInputTensor(
               std::to_string(element_idx + 1) + " for inference input '" +
               name + "', expecting " + std::to_string(request_element_cnt))
               .c_str()));
-      FillStringTensor(
-          input_list, tensor_offset + element_idx,
-          request_element_cnt - element_idx);
       return cuda_copy;
     }
 
@@ -1585,9 +1578,7 @@ SetStringInputTensor(
               std::to_string(len) + " but only " +
               std::to_string(content_byte_size) + " bytes available")
               .c_str()));
-      FillStringTensor(
-          input_list, tensor_offset + element_idx,
-          request_element_cnt - element_idx);
+      FillStringTensor(input_list, request_element_cnt - element_idx);
       return cuda_copy;
     }
 
@@ -1608,9 +1599,9 @@ SetStringInputTensor(
             " strings for inference input '" + name + "', got " +
             std::to_string(element_idx))
             .c_str()));
-    FillStringTensor(
-        input_list, tensor_offset + element_idx,
-        request_element_cnt - element_idx);
+    if (element_idx < request_element_cnt) {
+      FillStringTensor(input_list, request_element_cnt - element_idx);
+    }
   }
 
   return cuda_copy;
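
The new guard is more than tidiness: `element_idx` and `request_element_cnt` are both `size_t`, so when a request delivers more strings than expected, the old unconditional `request_element_cnt - element_idx` would wrap around to a huge unsigned value. A runnable sketch of the failure mode being avoided (my reading of the hunk, not code from the backend):

    #include <cstddef>
    #include <cstdio>

    int main() {
      std::size_t request_element_cnt = 4;
      std::size_t element_idx = 6;  // more elements arrived than expected
      if (element_idx < request_element_cnt) {
        std::printf("pad %zu strings\n", request_element_cnt - element_idx);
      } else {
        std::printf("no padding\n");  // unsigned wraparound avoided
      }
      return 0;
    }
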
@@ -1620,7 +1611,7 @@ bool
 SetStringOutputBuffer(
     torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
     TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
-    const size_t tensor_offset, cudaStream_t stream, std::string* serialized)
+    cudaStream_t stream, std::string* serialized)
 {
   bool cuda_copy = false;
 
@@ -1677,8 +1668,6 @@ ModelInstanceState::SetInputTensors(
     std::vector<torch::jit::IValue>* input_tensors,
     std::vector<BackendMemory*>* input_memories, bool* cuda_copy)
 {
-  const int max_batch_size = model_state_->MaxBatchSize();
-
   // InferenceMode should be used to guard all tensor operations
   torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode());
 
@@ -1705,7 +1694,7 @@ ModelInstanceState::SetInputTensors(
     // The shape for the entire input batch, [total_batch_size, ...]
     std::vector<int64_t> batchn_shape(
         input_shape, input_shape + input_dims_count);
-    if (max_batch_size != 0) {
+    if (supports_batching_) {
       batchn_shape[0] = total_batch_size;
     }
   }
@@ -1735,20 +1724,10 @@ ModelInstanceState::SetInputTensors(
 
 
     if (input_datatype == TRITONSERVER_TYPE_BYTES) {
-      if (batchn_shape.size() != 1) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INTERNAL, ("Triton only supports 1 dimensional "
-                                          "List of string as input for '" +
-                                          std::string(input_name) + "'")
-                                             .c_str());
-      }
-
       // Create the PyTorch list to hold the strings.
       torch::List<std::string> input_list;
       input_list.reserve(batchn_shape[0]);
 
-      size_t tensor_offset = 0;
-
       for (size_t idx = 0; idx < request_count; idx++) {
         TRITONBACKEND_Input* input;
         RESPOND_AND_SET_NULL_IF_ERROR(
@@ -1767,9 +1746,7 @@ ModelInstanceState::SetInputTensors(
 
         *cuda_copy |= SetStringInputTensor(
             &input_list, input, input_name, buffer_count, batch_element_cnt,
-            tensor_offset, &((*responses)[idx]), CudaStream(),
-            HostPolicyName().c_str());
-        tensor_offset += batch_element_cnt;
+            &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
       }
 
       (*input_tensors)[input_index_map_[input_name]] = input_list;
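
With `tensor_offset` gone, each `SetStringInputTensor` call simply appends its request's strings, and processing the requests in order yields the concatenated batch. A toy illustration of that append-only batching (plain containers and hypothetical names, not backend code):

    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> batched;  // stand-in for the torch::List
      std::vector<std::vector<std::string>> requests = {{"a", "b"}, {"c"}};
      for (const auto& req : requests) {
        for (const auto& s : req) {
          batched.push_back(s);  // offsets are implicit in arrival order
        }
      }
      return batched.size() == 3 ? 0 : 1;
    }
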
@@ -1864,18 +1841,25 @@ ModelInstanceState::ReadOutputTensors(
 
   } else if (output_tensors[op_index].isList()) {
     // Custom handling for string/bytes tensor...
-
     torch::List<torch::jit::IValue> output_list =
         output_tensors[op_index].toList();
 
     // Get output shape
     std::vector<int64_t> batchn_shape{(int64_t)output_list.size()};
 
-    size_t tensor_offset = 0;
-
     for (size_t idx = 0; idx < responses->size(); idx++) {
+      auto& request = requests[idx];
       auto& response = (*responses)[idx];
 
+      if (supports_batching_) {
+        TRITONBACKEND_Input* input;
+        TRITONBACKEND_RequestInputByIndex(request, 0 /* index */, &input);
+        const int64_t* shape;
+        TRITONBACKEND_InputProperties(
+            input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr);
+        batchn_shape[0] = shape[0];
+      }
+
       const size_t tensor_element_cnt = GetElementCount(batchn_shape);
 
       // Only need a response tensor for requested outputs.
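
This added block is what lets a batched string output be sliced back per request: the response loop re-reads each request's batch size from the shape of its first input via the backend C API (the `nullptr` arguments skip the name, datatype, dims-count, byte-size, and buffer-count outputs). A hedged sketch of the same lookup, with error returns ignored exactly as in the diff (`RequestBatchSize` is a hypothetical helper):

    #include <cstdint>
    #include "triton/core/tritonbackend.h"

    // Assumes the model supports batching, so dimension 0 is the batch size.
    int64_t RequestBatchSize(TRITONBACKEND_Request* request) {
      TRITONBACKEND_Input* input = nullptr;
      TRITONBACKEND_RequestInputByIndex(request, 0 /* index */, &input);
      const int64_t* shape = nullptr;
      TRITONBACKEND_InputProperties(
          input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr);
      return shape[0];
    }
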
@@ -1889,10 +1873,8 @@ ModelInstanceState::ReadOutputTensors(
           string_buffer.emplace_back(new std::string());
           cuda_copy |= SetStringOutputBuffer(
               &output_list, &response, response_output, tensor_element_cnt,
-              tensor_offset, CudaStream(), string_buffer.back().get());
+              CudaStream(), string_buffer.back().get());
         }
-
-        tensor_offset += tensor_element_cnt;
       }
     } else {
       return TRITONSERVER_ErrorNew(