
Commit 3ecda56

Add reshape+batching and dynamic batching support for string I/O (#69)
* reshape+batching and dynamic batching support for string I/O
* Address comment
* Address comment
1 parent 663ee99 commit 3ecda56

2 files changed: +33 −53 lines changed


README.md

Lines changed: 0 additions & 2 deletions
@@ -231,5 +231,3 @@ a List of Strings as input(s) / produces a List of String as output(s). For thes
 Triton allows users to pass String input(s)/recieve String output(s) using the String
 datatype. As a limitation of using List instead of Tensor for String I/O, only for
 1-dimensional input(s)/output(s) are supported for I/O of String type.
-Batching is not allowed for PyTorch models with String I/O. For these models,
-the user must specify `max_batch_size: 0` in the configuration.

src/libtorch.cc

Lines changed: 33 additions & 51 deletions
@@ -535,6 +535,9 @@ class ModelInstanceState : public BackendModelInstance {
 
   // If the input to the tensor is a dictionary of tensors.
   bool is_dict_input_;
+
+  // If the model supports batching.
+  bool supports_batching_;
 };
 
 TRITONSERVER_Error*
@@ -607,6 +610,7 @@ ModelInstanceState::ModelInstanceState(
       expected_input_cnt += 1;
     }
   }
+  supports_batching_ = model_state_->MaxBatchSize() > 0;
 
   THROW_IF_BACKEND_INSTANCE_ERROR(ValidateInputs(expected_input_cnt));
   THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs());
@@ -787,7 +791,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         "specified.");
   }
 
-  bool supports_batching = model_state_->MaxBatchSize() > 0;
   NamingConvention naming_convention;
   RETURN_IF_ERROR(GetNamingConvention(&naming_convention, allowed_inputs));
 
@@ -837,8 +840,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
              .c_str());
     }
 
-    // Validate shape for String inputs. Only allow 1 dimension and no
-    // batching.
+    // Validate shape for String inputs. Only allow 1 dimension.
     if (io_dtype == "TYPE_STRING") {
       // If a reshape is provided for the input then use that when
       // validating the model shapes.
@@ -850,7 +852,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
         RETURN_IF_ERROR(ParseShape(io, "dims", &dims));
       }
 
-      if ((dims.size() > 1) || supports_batching) {
+      if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INTERNAL,
            ("Triton only supports 1 dimensional List of String as input for "
@@ -880,7 +882,6 @@ ModelInstanceState::ValidateOutputs()
         "specified.");
   }
 
-  const bool supports_batching = model_state_->MaxBatchSize() > 0;
   NamingConvention naming_convention;
   RETURN_IF_ERROR(GetNamingConvention(&naming_convention, {}));
 
@@ -917,8 +918,7 @@ ModelInstanceState::ValidateOutputs()
              .c_str());
     }
 
-    // Validate shape for String outputs. Only allow 1 dimension and no
-    // batching.
+    // Validate shape for String outputs. Only allow 1 dimension.
     if (io_dtype == "TYPE_STRING") {
       // If a reshape is provided for the output then use that when
       // validating the model shapes.
@@ -930,7 +930,7 @@ ModelInstanceState::ValidateOutputs()
         RETURN_IF_ERROR(ParseShape(io, "dims", &dims));
       }
 
-      if ((dims.size() > 1) || supports_batching) {
+      if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INTERNAL,
            ("Triton only supports 1 dimensional List of String as output for "
@@ -1015,7 +1015,7 @@ ModelInstanceState::ProcessRequests(
   for (size_t i = 0; i < request_count; i++) {
     if (max_batch_size > 0) {
       // Retrieve the batch size from one of the inputs, if the model
-      // supports batching, the first dimension size is batch size
+      // supports batching, the first dimension size is batch size.
       TRITONBACKEND_Input* input;
       TRITONSERVER_Error* err =
           TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input);
@@ -1294,7 +1294,7 @@ ModelInstanceState::Execute(
       if (list_output.elementType()->kind() != c10::TypeKind::StringType) {
         throw std::invalid_argument(
             "output at index " + std::to_string(op_index) +
-            " must be of type Tensor or List[str], recieved List[" +
+            " must be of type Tensor or List[str], received List[" +
             list_output.elementType()->str() + "]");
       }
       output_tensors->push_back(m_op);
@@ -1310,7 +1310,7 @@ ModelInstanceState::Execute(
     auto list_output = model_outputs_.toList();
     if (list_output.elementType()->kind() != c10::TypeKind::StringType) {
       throw std::invalid_argument(
-          "output must be of type Tensor or List[str], recieved List[" +
+          "output must be of type Tensor or List[str], received List[" +
          list_output.elementType()->str() + "]");
     }
     output_tensors->push_back(model_outputs_);
@@ -1505,8 +1505,7 @@ GetContiguousInputContent(
 }
 
 void
-FillStringTensor(
-    torch::List<std::string>* input_list, const size_t idx, const size_t cnt)
+FillStringTensor(torch::List<std::string>* input_list, const size_t cnt)
 {
   for (size_t c = 0; c < cnt; ++c) {
     input_list->push_back("");
@@ -1517,9 +1516,8 @@ bool
 SetStringInputTensor(
     torch::List<std::string>* input_list, TRITONBACKEND_Input* input,
     const char* name, const uint32_t buffer_count,
-    const size_t request_element_cnt, const size_t tensor_offset,
-    TRITONBACKEND_Response** response, cudaStream_t stream,
-    const char* host_policy_name)
+    const size_t request_element_cnt, TRITONBACKEND_Response** response,
+    cudaStream_t stream, const char* host_policy_name)
 {
   bool cuda_copy = false;
   size_t element_idx = 0;
@@ -1537,9 +1535,7 @@ SetStringInputTensor(
         stream, &cuda_copy);
     if (err != nullptr) {
       RESPOND_AND_SET_NULL_IF_ERROR(response, err);
-      FillStringTensor(
-          input_list, tensor_offset + element_idx,
-          request_element_cnt - element_idx);
+      FillStringTensor(input_list, request_element_cnt - element_idx);
       return cuda_copy;
     }
 
@@ -1564,9 +1560,6 @@ SetStringInputTensor(
               std::to_string(element_idx + 1) + " for inference input '" +
               name + "', expecting " + std::to_string(request_element_cnt))
              .c_str()));
-      FillStringTensor(
-          input_list, tensor_offset + element_idx,
-          request_element_cnt - element_idx);
      return cuda_copy;
     }
 
@@ -1585,9 +1578,7 @@ SetStringInputTensor(
               std::to_string(len) + " but only " +
               std::to_string(content_byte_size) + " bytes available")
              .c_str()));
-      FillStringTensor(
-          input_list, tensor_offset + element_idx,
-          request_element_cnt - element_idx);
+      FillStringTensor(input_list, request_element_cnt - element_idx);
      return cuda_copy;
     }
 
@@ -1608,9 +1599,9 @@ SetStringInputTensor(
           " strings for inference input '" + name + "', got " +
           std::to_string(element_idx))
          .c_str()));
-    FillStringTensor(
-        input_list, tensor_offset + element_idx,
-        request_element_cnt - element_idx);
+    if (element_idx < request_element_cnt) {
+      FillStringTensor(input_list, request_element_cnt - element_idx);
+    }
   }
 
   return cuda_copy;
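Note the new guard before the final padding call: both counts are size_t, so computing request_element_cnt - element_idx when more strings were parsed than expected would wrap around rather than go negative. A small standalone illustration of that hazard (not backend code) follows.

#include <cstddef>
#include <iostream>

int main()
{
  const size_t request_element_cnt = 4;  // elements the request should contain
  const size_t element_idx = 6;          // elements actually parsed
  // Unsigned subtraction wraps: this prints 18446744073709551614 on a
  // 64-bit platform, which is why the padding call above is made only
  // when element_idx < request_element_cnt.
  std::cout << request_element_cnt - element_idx << "\n";
  return 0;
}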
@@ -1620,7 +1611,7 @@ bool
 SetStringOutputBuffer(
     torch::List<torch::jit::IValue>* tensor, TRITONBACKEND_Response** response,
     TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
-    const size_t tensor_offset, cudaStream_t stream, std::string* serialized)
+    cudaStream_t stream, std::string* serialized)
 {
   bool cuda_copy = false;
 
@@ -1677,8 +1668,6 @@ ModelInstanceState::SetInputTensors(
     std::vector<torch::jit::IValue>* input_tensors,
     std::vector<BackendMemory*>* input_memories, bool* cuda_copy)
 {
-  const int max_batch_size = model_state_->MaxBatchSize();
-
   // InferenceMode should be used to guard all tensors operations
   torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode());
 
@@ -1705,7 +1694,7 @@ ModelInstanceState::SetInputTensors(
     // The shape for the entire input patch, [total_batch_size, ...]
     std::vector<int64_t> batchn_shape(
         input_shape, input_shape + input_dims_count);
-    if (max_batch_size != 0) {
+    if (supports_batching_) {
       batchn_shape[0] = total_batch_size;
     }
 
@@ -1735,20 +1724,10 @@ ModelInstanceState::SetInputTensors(
 
 
     if (input_datatype == TRITONSERVER_TYPE_BYTES) {
-      if (batchn_shape.size() != 1) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INTERNAL, ("Triton only supports 1 dimensional "
-                                          "List of string as input for '" +
-                                          std::string(input_name) + "'")
-                                             .c_str());
-      }
-
       // Create the PyTorch list to hold the strings.
       torch::List<std::string> input_list;
       input_list.reserve(batchn_shape[0]);
 
-      size_t tensor_offset = 0;
-
       for (size_t idx = 0; idx < request_count; idx++) {
         TRITONBACKEND_Input* input;
         RESPOND_AND_SET_NULL_IF_ERROR(
@@ -1767,9 +1746,7 @@ ModelInstanceState::SetInputTensors(
 
         *cuda_copy |= SetStringInputTensor(
             &input_list, input, input_name, buffer_count, batch_element_cnt,
-            tensor_offset, &((*responses)[idx]), CudaStream(),
-            HostPolicyName().c_str());
-        tensor_offset += batch_element_cnt;
+            &((*responses)[idx]), CudaStream(), HostPolicyName().c_str());
       }
 
       (*input_tensors)[input_index_map_[input_name]] = input_list;
@@ -1864,18 +1841,25 @@ ModelInstanceState::ReadOutputTensors(
 
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
-
       torch::List<torch::jit::IValue> output_list =
           output_tensors[op_index].toList();
 
       // Get output shape
       std::vector<int64_t> batchn_shape{(int64_t)output_list.size()};
 
-      size_t tensor_offset = 0;
-
       for (size_t idx = 0; idx < responses->size(); idx++) {
+        auto& request = requests[idx];
         auto& response = (*responses)[idx];
 
+        if (supports_batching_ != 0) {
+          TRITONBACKEND_Input* input;
+          TRITONBACKEND_RequestInputByIndex(request, 0 /* index*/, &input);
+          const int64_t* shape;
+          TRITONBACKEND_InputProperties(
+              input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr);
+          batchn_shape[0] = shape[0];
+        }
+
         const size_t tensor_element_cnt = GetElementCount(batchn_shape);
 
         // Only need an response tensor for requested outputs.
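With batching, the model returns one flattened List[str] covering every request in the execution, so the loop above recovers each request's element count from the batch dimension of that request's first input instead of assuming the whole list belongs to a single request. A rough sketch of the consistency this relies on, using a hypothetical helper and assumed per-request inputs (not backend code):

#include <cstddef>
#include <cstdint>
#include <vector>

// per_request_batch_dim[i] is shape[0] of request i's first input. Assuming
// a 1-D String output, each request contributes exactly that many elements,
// so the per-request slices should cover the flattened output list exactly.
static bool OutputListCoversBatch(
    const std::vector<int64_t>& per_request_batch_dim, size_t output_list_size)
{
  int64_t total = 0;
  for (const int64_t batch_dim : per_request_batch_dim) {
    total += batch_dim;
  }
  return total == static_cast<int64_t>(output_list_size);
}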
@@ -1889,10 +1873,8 @@ ModelInstanceState::ReadOutputTensors(
           string_buffer.emplace_back(new std::string());
           cuda_copy |= SetStringOutputBuffer(
               &output_list, &response, response_output, tensor_element_cnt,
-              tensor_offset, CudaStream(), string_buffer.back().get());
+              CudaStream(), string_buffer.back().get());
         }
-
-        tensor_offset += tensor_element_cnt;
       }
     } else {
       return TRITONSERVER_ErrorNew(
0 commit comments

Comments
 (0)