Merged
54 changes: 36 additions & 18 deletions src/infer_request.cc
@@ -910,6 +910,7 @@ Status
InferenceRequest::Normalize()
{
const inference::ModelConfig& model_config = model_raw_->Config();
const std::string& model_name = ModelName();

// Fill metadata for raw input
if (!raw_input_name_.empty()) {
@@ -922,7 +923,7 @@ InferenceRequest::Normalize()
std::to_string(original_inputs_.size()) +
") to be deduced but got " +
std::to_string(model_config.input_size()) + " inputs in '" +
ModelName() + "' model configuration");
model_name + "' model configuration");
}
auto it = original_inputs_.begin();
if (raw_input_name_ != it->first) {
@@ -1040,7 +1041,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' has no shape but model requires batch dimension for '" +
ModelName() + "'");
model_name + "'");
}

if (batch_size_ == 0) {
@@ -1049,7 +1050,7 @@
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' batch size does not match other inputs for '" + ModelName() +
"' batch size does not match other inputs for '" + model_name +
"'");
}

@@ -1065,7 +1066,7 @@
Status::Code::INVALID_ARG,
LogRequest() + "inference request batch-size must be <= " +
std::to_string(model_config.max_batch_size()) + " for '" +
ModelName() + "'");
model_name + "'");
}

// Verify that each input shape is valid for the model, make
@@ -1074,17 +1075,17 @@
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));

auto& input_id = pr.first;
auto& input_name = pr.first;
auto& input = pr.second;
auto shape = input.MutableShape();

if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input '" + input_id + "' data-type is '" +
LogRequest() + "inference input '" + input_name + "' data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', but model '" + ModelName() + "' expects '" +
"', but model '" + model_name + "' expects '" +
std::string(triton::common::DataTypeToProtocolString(
input_config->data_type())) +
"'");
@@ -1104,7 +1105,7 @@
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
input_id + "' for model '" + ModelName() + "', got " +
input_name + "' for model '" + model_name + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
@@ -1133,8 +1134,8 @@
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
LogRequest() + "unexpected shape for input '" + input_name +
"' for model '" + model_name + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()) + ". " +
implicit_batch_note);
@@ -1196,8 +1197,8 @@ InferenceRequest::Normalize()
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &input_memory_type));
RETURN_IF_ERROR(ValidateBytesInputs(
input_name, input, model_name, &input_memory_type));
// FIXME: Temporarily skips byte size checks for GPU tensors. See
// DLIS-6820.
skip_byte_size_check |=
@@ -1213,7 +1214,7 @@
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" +
input_id + "' for model '" + ModelName() + "'. Expected " +
input_name + "' for model '" + model_name + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
@@ -1287,7 +1288,8 @@ InferenceRequest::ValidateRequestInputs()

Status
InferenceRequest::ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& input_name, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const
{
const auto& input_dims = input.ShapeWithBatchDim();
@@ -1326,13 +1328,28 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"element byte size indicator exceeds the end of the buffer.");
"incomplete string length indicator for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(sizeof(uint32_t)) + " bytes but only " +
std::to_string(remaining_buffer_size) +
" bytes available. Please make sure the string length "
"indicator is in one buffer.");
}

// Start the next element and reset the remaining element size.
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
element_checked++;

// Early stop
if (element_checked > element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected number of string elements " +
std::to_string(element_checked) + " for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(element_count));
}

// Advance pointer and remainder by the indicator size.
buffer += kElementSizeIndicator;
remaining_buffer_size -= kElementSizeIndicator;
@@ -1358,16 +1375,17 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(buffer_count) +
" buffers for inference input '" + input_id + "', got " +
std::to_string(buffer_next_idx));
" buffers for inference input '" + input_name + "' for model '" +
model_name + "', got " + std::to_string(buffer_next_idx));
}

// Validate the number of processed elements exactly match expectations.
if (element_checked != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" string elements for inference input '" + input_id + "', got " +
" string elements for inference input '" + input_name +
"' for model '" + model_name + "', got " +
std::to_string(element_checked));
}

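Note on the format being validated: Triton serializes TYPE_STRING (BYTES) tensors as a flat byte stream in which every element is preceded by a 4-byte length indicator, which is what the ValidateBytesInputs() loop above walks buffer by buffer. The sketch below is illustrative only and not part of this change; the helper name SerializeStringElements is invented here, and it assumes host byte order for the uint32 indicator, as the validation code does.

```cpp
// Illustrative sketch (not from this PR): build the flat layout that
// ValidateBytesInputs() expects -- a 4-byte length indicator followed by
// that many bytes of payload, repeated once per element.
#include <cstdint>
#include <string>
#include <vector>

std::vector<char>
SerializeStringElements(const std::vector<std::string>& elements)
{
  std::vector<char> buffer;
  for (const auto& s : elements) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    // Append the 4-byte length indicator, then the element's bytes.
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(uint32_t));
    buffer.insert(buffer.end(), s.begin(), s.end());
  }
  return buffer;
}

// For a {1, 2} tensor holding {"ab", "c"} this yields 11 bytes on a
// little-endian host: 02 00 00 00 'a' 'b' 01 00 00 00 'c'. The validation
// above requires each indicator to sit entirely within a single buffer and
// the element count to match the tensor shape.
```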
1 change: 1 addition & 0 deletions src/infer_request.h
@@ -761,6 +761,7 @@ class InferenceRequest {

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helpers for pending request metrics
53 changes: 34 additions & 19 deletions src/test/input_byte_size_test.cc
@@ -258,19 +258,20 @@ char InputByteSizeTest::input_data_string_

TEST_F(InputByteSizeTest, ValidInputByteSize)
{
const char* model_name = "savedmodel_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "pt_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest_, InferRequestComplete, nullptr /* request_release_userp */),
"setting request release callback");

// Define input shape and data
std::vector<int64_t> shape{1, 8};
std::vector<float> input_data(8, 1);
std::vector<int64_t> shape{1, 16};
std::vector<float> input_data(16, 1);
const auto input0_byte_size = sizeof(input_data[0]) * input_data.size();

// Set input for the request
@@ -312,19 +313,20 @@ TEST_F(InputByteSizeTest, ValidInputByteSize)

TEST_F(InputByteSizeTest, InputByteSizeMismatch)
{
const char* model_name = "savedmodel_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "pt_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest_, InferRequestComplete, nullptr /* request_release_userp */),
"setting request release callback");

// Define input shape and data
std::vector<int64_t> shape{1, 8};
std::vector<float> input_data(10, 1);
std::vector<int64_t> shape{1, 16};
std::vector<float> input_data(17, 1);
const auto input0_byte_size = sizeof(input_data[0]) * input_data.size();

// Set input for the request
@@ -353,8 +355,8 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"input byte size mismatch for input 'INPUT0' for model 'pt_identity'. "
"Expected 32, got 40");
"input byte size mismatch for input 'INPUT0' for model '" +
std::string{model_name} + "'. Expected 64, got 68");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -364,10 +366,11 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)

TEST_F(InputByteSizeTest, ValidStringInputByteSize)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -424,10 +427,11 @@ TEST_F(InputByteSizeTest, ValidStringInputByteSize)

TEST_F(InputByteSizeTest, StringCountMismatch)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -457,7 +461,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"expected 3 string elements for inference input 'INPUT0', got 2");
"expected 3 string elements for inference input 'INPUT0' for model '" +
std::string{model_name} + "', got 2");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -467,7 +472,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -495,7 +501,9 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"expected 1 string elements for inference input 'INPUT0', got 2");
"unexpected number of string elements 2 for inference input 'INPUT0' for "
"model '" +
std::string{model_name} + "', expecting 1");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -505,10 +513,11 @@ TEST_F(InputByteSizeTest, StringCountMismatch)

TEST_F(InputByteSizeTest, StringSizeMisalign)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -542,9 +551,13 @@ TEST_F(InputByteSizeTest, StringSizeMisalign)

// Run inference
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace
*/), "expect error with inference request",
"element byte size indicator exceeds the end of the buffer");
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace*/),
"expect error with inference request",
"incomplete string length indicator for inference input 'INPUT0' for "
"model '" +
std::string{model_name} +
"', expecting 4 bytes but only 2 bytes available. Please make sure "
"the string length indicator is in one buffer.");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
Expand Down Expand Up @@ -573,7 +586,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -629,7 +643,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
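For completeness, the StringSizeMisalign case above relies on the new rule that a 4-byte length indicator may not straddle the end of a buffer. A minimal standalone sketch of that rule follows; it is a simplified stand-in for illustration (the function name CheckStringBuffers is invented here), not the server's actual implementation, and it assumes the same host-order uint32 indicators.

```cpp
// Simplified stand-in (not from this PR) for the buffer walk performed by
// ValidateBytesInputs(): reports an error if a 4-byte length indicator is
// split across buffers or if the element count disagrees with the shape.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

std::string
CheckStringBuffers(
    const std::vector<std::vector<char>>& buffers, size_t expected_elements)
{
  constexpr size_t kIndicatorSize = sizeof(uint32_t);
  size_t elements = 0;
  size_t remaining_element_size = 0;  // payload bytes of current element left
  for (const auto& buf : buffers) {
    size_t offset = 0;
    while (offset < buf.size()) {
      if (remaining_element_size == 0) {
        // A new element starts here; its indicator must fit in this buffer.
        if (buf.size() - offset < kIndicatorSize) {
          return "incomplete string length indicator";
        }
        uint32_t len;
        std::memcpy(&len, buf.data() + offset, kIndicatorSize);
        offset += kIndicatorSize;
        remaining_element_size = len;
        if (++elements > expected_elements) {
          return "unexpected number of string elements";
        }
      } else {
        // Consume as much of the current element as this buffer holds.
        const size_t n = std::min(remaining_element_size, buf.size() - offset);
        offset += n;
        remaining_element_size -= n;
      }
    }
  }
  return (elements == expected_elements) ? "" : "string element count mismatch";
}
```

With expected_elements equal to 1 and a buffer whose last two bytes are the start of a second indicator, this returns the incomplete-indicator case, mirroring the error message the test now asserts.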