Commit b336ecc

Update core checks
1 parent 86a5573 commit b336ecc

File tree

  src/infer_request.cc
  src/infer_request.h
  src/test/input_byte_size_test.cc

3 files changed: +67 -33 lines changed


src/infer_request.cc

Lines changed: 36 additions & 18 deletions
@@ -906,6 +906,7 @@ Status
 InferenceRequest::Normalize()
 {
   const inference::ModelConfig& model_config = model_raw_->Config();
+  const std::string& model_name = ModelName();
 
   // Fill metadata for raw input
   if (!raw_input_name_.empty()) {
@@ -918,7 +919,7 @@ InferenceRequest::Normalize()
           std::to_string(original_inputs_.size()) +
           ") to be deduced but got " +
           std::to_string(model_config.input_size()) + " inputs in '" +
-          ModelName() + "' model configuration");
+          model_name + "' model configuration");
     }
     auto it = original_inputs_.begin();
     if (raw_input_name_ != it->first) {
@@ -1036,7 +1037,7 @@ InferenceRequest::Normalize()
           Status::Code::INVALID_ARG,
           LogRequest() + "input '" + input.Name() +
               "' has no shape but model requires batch dimension for '" +
-              ModelName() + "'");
+              model_name + "'");
     }
 
     if (batch_size_ == 0) {
@@ -1045,7 +1046,7 @@ InferenceRequest::Normalize()
       return Status(
           Status::Code::INVALID_ARG,
           LogRequest() + "input '" + input.Name() +
-              "' batch size does not match other inputs for '" + ModelName() +
+              "' batch size does not match other inputs for '" + model_name +
               "'");
     }
 
@@ -1061,7 +1062,7 @@ InferenceRequest::Normalize()
         Status::Code::INVALID_ARG,
         LogRequest() + "inference request batch-size must be <= " +
             std::to_string(model_config.max_batch_size()) + " for '" +
-            ModelName() + "'");
+            model_name + "'");
   }
 
   // Verify that each input shape is valid for the model, make
@@ -1070,17 +1071,17 @@ InferenceRequest::Normalize()
     const inference::ModelInput* input_config;
     RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));
 
-    auto& input_id = pr.first;
+    auto& input_name = pr.first;
     auto& input = pr.second;
     auto shape = input.MutableShape();
 
     if (input.DType() != input_config->data_type()) {
       return Status(
           Status::Code::INVALID_ARG,
-          LogRequest() + "inference input '" + input_id + "' data-type is '" +
+          LogRequest() + "inference input '" + input_name + "' data-type is '" +
               std::string(
                   triton::common::DataTypeToProtocolString(input.DType())) +
-              "', but model '" + ModelName() + "' expects '" +
+              "', but model '" + model_name + "' expects '" +
               std::string(triton::common::DataTypeToProtocolString(
                   input_config->data_type())) +
               "'");
@@ -1100,7 +1101,7 @@ InferenceRequest::Normalize()
             Status::Code::INVALID_ARG,
             LogRequest() +
                 "All input dimensions should be specified for input '" +
-                input_id + "' for model '" + ModelName() + "', got " +
+                input_name + "' for model '" + model_name + "', got " +
                 triton::common::DimsListToString(input.OriginalShape()));
       } else if (
           (config_dims[i] != triton::common::WILDCARD_DIM) &&
@@ -1129,8 +1130,8 @@ InferenceRequest::Normalize()
       }
       return Status(
           Status::Code::INVALID_ARG,
-          LogRequest() + "unexpected shape for input '" + input_id +
-              "' for model '" + ModelName() + "'. Expected " +
+          LogRequest() + "unexpected shape for input '" + input_name +
+              "' for model '" + model_name + "'. Expected " +
              triton::common::DimsListToString(full_dims) + ", got " +
              triton::common::DimsListToString(input.OriginalShape()) + ". " +
              implicit_batch_note);
@@ -1192,8 +1193,8 @@ InferenceRequest::Normalize()
     // (prepend 4 bytes to specify string length), so need to add all the
     // first 4 bytes for each element to find expected byte size
     if (data_type == inference::DataType::TYPE_STRING) {
-      RETURN_IF_ERROR(
-          ValidateBytesInputs(input_id, input, &input_memory_type));
+      RETURN_IF_ERROR(ValidateBytesInputs(
+          input_name, input, model_name, &input_memory_type));
       // FIXME: Temporarily skips byte size checks for GPU tensors. See
       // DLIS-6820.
       skip_byte_size_check |=
@@ -1209,7 +1210,7 @@ InferenceRequest::Normalize()
         return Status(
             Status::Code::INVALID_ARG,
             LogRequest() + "input byte size mismatch for input '" +
-                input_id + "' for model '" + ModelName() + "'. Expected " +
+                input_name + "' for model '" + model_name + "'. Expected " +
                 std::to_string(expected_byte_size) + ", got " +
                 std::to_string(byte_size));
       }
@@ -1283,7 +1284,8 @@ InferenceRequest::ValidateRequestInputs()
 
 Status
 InferenceRequest::ValidateBytesInputs(
-    const std::string& input_id, const Input& input,
+    const std::string& input_name, const Input& input,
+    const std::string& model_name,
     TRITONSERVER_MemoryType* buffer_memory_type) const
 {
   const auto& input_dims = input.ShapeWithBatchDim();
@@ -1322,13 +1324,28 @@ InferenceRequest::ValidateBytesInputs(
         return Status(
             Status::Code::INVALID_ARG,
             LogRequest() +
-                "element byte size indicator exceeds the end of the buffer.");
+                "incomplete string length indicator for inference input '" +
+                input_name + "' for model '" + model_name + "', expecting " +
+                std::to_string(sizeof(uint32_t)) + " bytes but only " +
+                std::to_string(remaining_buffer_size) +
+                " bytes available. Please make sure the string length "
+                "indicator is in one buffer.");
       }
 
       // Start the next element and reset the remaining element size.
       remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
       element_checked++;
 
+      // Early stop
+      if (element_checked > element_count) {
+        return Status(
+            Status::Code::INVALID_ARG,
+            LogRequest() + "unexpected number of string elements " +
+                std::to_string(element_checked) + " for inference input '" +
+                input_name + "' for model '" + model_name + "', expecting " +
+                std::to_string(element_count));
+      }
+
       // Advance pointer and remainder by the indicator size.
       buffer += kElementSizeIndicator;
       remaining_buffer_size -= kElementSizeIndicator;
@@ -1354,16 +1371,17 @@ InferenceRequest::ValidateBytesInputs(
     return Status(
         Status::Code::INVALID_ARG,
         LogRequest() + "expected " + std::to_string(buffer_count) +
-            " buffers for inference input '" + input_id + "', got " +
-            std::to_string(buffer_next_idx));
+            " buffers for inference input '" + input_name + "' for model '" +
+            model_name + "', got " + std::to_string(buffer_next_idx));
   }
 
   // Validate the number of processed elements exactly match expectations.
   if (element_checked != element_count) {
     return Status(
         Status::Code::INVALID_ARG,
         LogRequest() + "expected " + std::to_string(element_count) +
-            " string elements for inference input '" + input_id + "', got " +
+            " string elements for inference input '" + input_name +
+            "' for model '" + model_name + "', got " +
            std::to_string(element_checked));
  }
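Note on the checks above: Triton serializes a TYPE_STRING tensor as a sequence of elements, each a 4-byte length indicator followed by that many raw bytes, which is why ValidateBytesInputs walks the input indicator by indicator. Below is a minimal, self-contained sketch of that scan restricted to one contiguous CPU buffer; the name ValidateStringBuffer and the flat-buffer signature are illustrative only, since the real method iterates over an input's (possibly non-contiguous, possibly GPU-resident) buffers and returns a Status.

// A simplified sketch (not the server's actual helper) of the TYPE_STRING
// scan that ValidateBytesInputs performs. ValidateStringBuffer is a
// hypothetical name. Returns an empty string on success, otherwise an
// error message shaped like the ones this commit introduces.
#include <cstdint>
#include <cstring>
#include <string>

std::string
ValidateStringBuffer(
    const char* buffer, size_t buffer_size, const size_t element_count,
    const std::string& input_name, const std::string& model_name)
{
  constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
  size_t element_checked = 0;

  while (buffer_size > 0) {
    // Each element starts with a 4-byte length indicator that must be
    // fully contained in the buffer.
    if (buffer_size < kElementSizeIndicator) {
      return "incomplete string length indicator for inference input '" +
             input_name + "' for model '" + model_name + "', expecting " +
             std::to_string(kElementSizeIndicator) + " bytes but only " +
             std::to_string(buffer_size) + " bytes available";
    }
    uint32_t element_size;  // host byte order, as in the server code
    std::memcpy(&element_size, buffer, kElementSizeIndicator);
    element_checked++;

    // Early stop: more elements than the shape implies is an error.
    if (element_checked > element_count) {
      return "unexpected number of string elements " +
             std::to_string(element_checked) + " for inference input '" +
             input_name + "' for model '" + model_name + "', expecting " +
             std::to_string(element_count);
    }

    // The payload must also fit; then skip the indicator plus payload.
    if (buffer_size - kElementSizeIndicator < element_size) {
      return "string element extends past the end of the buffer";
    }
    buffer += kElementSizeIndicator + element_size;
    buffer_size -= kElementSizeIndicator + element_size;
  }

  if (element_checked != element_count) {
    return "expected " + std::to_string(element_count) +
           " string elements for inference input '" + input_name +
           "' for model '" + model_name + "', got " +
           std::to_string(element_checked);
  }
  return "";
}

For example, an 8-byte buffer {2, 0, 0, 0, 'h', 'i', 2, 0} declared to hold two elements fails with the incomplete-indicator message: the first element consumes 6 bytes, leaving only 2 of the 4 bytes the next indicator needs.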

src/infer_request.h

Lines changed: 1 addition & 0 deletions
@@ -749,6 +749,7 @@ class InferenceRequest {
 
   Status ValidateBytesInputs(
       const std::string& input_id, const Input& input,
+      const std::string& model_name,
       TRITONSERVER_MemoryType* buffer_memory_type) const;
 
   // Helpers for pending request metrics

src/test/input_byte_size_test.cc

Lines changed: 30 additions & 15 deletions
@@ -258,10 +258,11 @@ char InputByteSizeTest::input_data_string_
 
 TEST_F(InputByteSizeTest, ValidInputByteSize)
 {
+  const char* model_name = "savedmodel_zero_1_float32";
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple", -1 /* model_version */),
+          &irequest_, server_, model_name, -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -312,10 +313,11 @@ TEST_F(InputByteSizeTest, ValidInputByteSize)
 
 TEST_F(InputByteSizeTest, InputByteSizeMismatch)
 {
+  const char* model_name = "savedmodel_zero_1_float32";
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple", -1 /* model_version */),
+          &irequest_, server_, model_name, -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -353,8 +355,8 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)
   FAIL_TEST_IF_SUCCESS(
       TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
       "expect error with inference request",
-      "input byte size mismatch for input 'INPUT0' for model 'simple'. "
-      "Expected 64, got 68");
+      "input byte size mismatch for input 'INPUT0' for model '" +
+          std::string{model_name} + "'. Expected 64, got 68");
 
   // Need to manually delete request, otherwise server will not shut down.
   FAIL_TEST_IF_ERR(
@@ -364,10 +366,11 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)
 
 TEST_F(InputByteSizeTest, ValidStringInputByteSize)
 {
+  const char* model_name = "savedmodel_zero_1_object";
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple_identity", -1 /* model_version */),
+          &irequest_, server_, model_name, -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -424,10 +427,11 @@ TEST_F(InputByteSizeTest, ValidStringInputByteSize)
 
 TEST_F(InputByteSizeTest, StringCountMismatch)
 {
+  const char* model_name = "savedmodel_zero_1_object";
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple_identity", -1 /* model_version */),
+          &irequest_, server_, model_name, -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -457,7 +461,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
   FAIL_TEST_IF_SUCCESS(
       TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
       "expect error with inference request",
-      "expected 3 string elements for inference input 'INPUT0', got 2");
+      "expected 3 string elements for inference input 'INPUT0' for model '" +
+          std::string{model_name} + "', got 2");
 
   // Need to manually delete request, otherwise server will not shut down.
   FAIL_TEST_IF_ERR(
@@ -467,7 +472,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple_identity", -1 /* model_version */),
+          &irequest_, server_, "savedmodel_zero_1_object",
+          -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -495,7 +501,9 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
   FAIL_TEST_IF_SUCCESS(
       TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
       "expect error with inference request",
-      "expected 1 string elements for inference input 'INPUT0', got 2");
+      "unexpected number of string elements 2 for inference input 'INPUT0' for "
+      "model '" +
+          std::string{model_name} + "', expecting 1");
 
   // Need to manually delete request, otherwise server will not shut down.
   FAIL_TEST_IF_ERR(
@@ -505,10 +513,11 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
 
 TEST_F(InputByteSizeTest, StringSizeMisalign)
 {
+  const char* model_name = "savedmodel_zero_1_object";
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple_identity", -1 /* model_version */),
+          &irequest_, server_, model_name, -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -542,9 +551,13 @@ TEST_F(InputByteSizeTest, StringSizeMisalign)
 
   // Run inference
   FAIL_TEST_IF_SUCCESS(
-      TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace
-      */), "expect error with inference request",
-      "element byte size indicator exceeds the end of the buffer");
+      TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
+      "expect error with inference request",
+      "incomplete string length indicator for inference input 'INPUT0' for "
+      "model '" +
+          std::string{model_name} +
+          "', expecting 4 bytes but only 2 bytes available. Please make sure "
+          "the string length indicator is in one buffer.");
 
   // Need to manually delete request, otherwise server will not shut down.
   FAIL_TEST_IF_ERR(
@@ -573,7 +586,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple_identity", -1 /* model_version */),
+          &irequest_, server_, "savedmodel_zero_1_object",
+          -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -629,7 +643,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
   // Create an inference request
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestNew(
-          &irequest_, server_, "simple_identity", -1 /* model_version */),
+          &irequest_, server_, "savedmodel_zero_1_object",
+          -1 /* model_version */),
       "creating inference request");
   FAIL_TEST_IF_ERR(
       TRITONSERVER_InferenceRequestSetReleaseCallback(
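For context on the StringSizeMisalign expectation, the "4 bytes but only 2 bytes available" figure comes from splitting a length indicator across two appended buffers. A hypothetical fragment in the style of these tests — the shape, chunk contents, and reuse of the fixture's irequest_ member are illustrative assumptions, not part of this diff:

// Two string elements are declared, but the second element's 4-byte length
// indicator is split across two appended buffers, so validation stops at
// the buffer boundary with only 2 of the 4 indicator bytes available.
const int64_t shape[] = {1, 2};  // assumed shape: batch 1, two elements
FAIL_TEST_IF_ERR(
    TRITONSERVER_InferenceRequestAddInput(
        irequest_, "INPUT0", TRITONSERVER_TYPE_BYTES, shape, 2),
    "adding input");

// Wire format per element: 4-byte length indicator, then the raw bytes.
static const char chunk0[] = {2, 0, 0, 0, 'a', 'b', 2, 0};  // element 1 + half an indicator
static const char chunk1[] = {0, 0, 'c', 'd'};              // rest of indicator + element 2
FAIL_TEST_IF_ERR(
    TRITONSERVER_InferenceRequestAppendInputData(
        irequest_, "INPUT0", chunk0, sizeof(chunk0), TRITONSERVER_MEMORY_CPU,
        0 /* memory_type_id */),
    "appending first chunk");
FAIL_TEST_IF_ERR(
    TRITONSERVER_InferenceRequestAppendInputData(
        irequest_, "INPUT0", chunk1, sizeof(chunk1), TRITONSERVER_MEMORY_CPU,
        0 /* memory_type_id */),
    "appending second chunk");
// ValidateBytesInputs now rejects the request with "incomplete string
// length indicator ... expecting 4 bytes but only 2 bytes available".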
