|
1 | | -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 1 | +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | // |
3 | 3 | // Redistribution and use in source and binary forms, with or without |
4 | 4 | // modification, are permitted provided that the following conditions |
@@ -1001,28 +1001,7 @@ InferenceRequest::Normalize() |
1001 | 1001 | } |
1002 | 1002 | // Make sure that the request is providing the number of inputs |
1003 | 1003 | // as is expected by the model. |
1004 | | - if ((original_inputs_.size() > (size_t)model_config.input_size()) || |
1005 | | - (original_inputs_.size() < model_raw_->RequiredInputCount())) { |
1006 | | - // If no input is marked as optional, then use exact match error message |
1007 | | - // for consistency / backward compatibility |
1008 | | - if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) { |
1009 | | - return Status( |
1010 | | - Status::Code::INVALID_ARG, |
1011 | | - LogRequest() + "expected " + |
1012 | | - std::to_string(model_config.input_size()) + " inputs but got " + |
1013 | | - std::to_string(original_inputs_.size()) + " inputs for model '" + |
1014 | | - ModelName() + "'"); |
1015 | | - } else { |
1016 | | - return Status( |
1017 | | - Status::Code::INVALID_ARG, |
1018 | | - LogRequest() + "expected number of inputs between " + |
1019 | | - std::to_string(model_raw_->RequiredInputCount()) + " and " + |
1020 | | - std::to_string(model_config.input_size()) + " but got " + |
1021 | | - std::to_string(original_inputs_.size()) + " inputs for model '" + |
1022 | | - ModelName() + "'"); |
1023 | | - } |
1024 | | - } |
1025 | | - |
| 1004 | + RETURN_IF_ERROR(ValidateRequestInputs()); |
1026 | 1005 | // Determine the batch size and shape of each input. |
1027 | 1006 | if (model_config.max_batch_size() == 0) { |
1028 | 1007 | // Model does not support Triton-style batching so set as |
@@ -1195,6 +1174,67 @@ InferenceRequest::Normalize() |
1195 | 1174 | return Status::Success; |
1196 | 1175 | } |
1197 | 1176 |
|
| 1177 | +Status |
| 1178 | +InferenceRequest::ValidateRequestInputs() |
| 1179 | +{ |
| 1180 | + const inference::ModelConfig& model_config = model_raw_->Config(); |
| 1181 | + if ((original_inputs_.size() > (size_t)model_config.input_size()) || |
| 1182 | + (original_inputs_.size() < model_raw_->RequiredInputCount())) { |
| 1183 | + // If no input is marked as optional, then use exact match error message |
| 1184 | + // for consistency / backward compatibility |
| 1185 | + std::string missing_required_input_string = "["; |
| 1186 | + std::string original_input_string = "["; |
| 1187 | + |
| 1188 | + for (size_t i = 0; i < (size_t)model_config.input_size(); ++i) { |
| 1189 | + const inference::ModelInput& input = model_config.input(i); |
| 1190 | + if ((!input.optional()) && |
| 1191 | + (original_inputs_.find(input.name()) == original_inputs_.end())) { |
| 1192 | + missing_required_input_string = |
| 1193 | + missing_required_input_string + "'" + input.name() + "'" + ","; |
| 1194 | + } |
| 1195 | + } |
| 1196 | + // Removes the extra "," |
| 1197 | + missing_required_input_string.pop_back(); |
| 1198 | + missing_required_input_string = missing_required_input_string + "]"; |
| 1199 | + |
| 1200 | + for (const auto& pair : original_inputs_) { |
| 1201 | + original_input_string = |
| 1202 | + original_input_string + "'" + pair.first + "'" + ","; |
| 1203 | + } |
| 1204 | + // Removes the extra "," |
| 1205 | + original_input_string.pop_back(); |
| 1206 | + original_input_string = original_input_string + "]"; |
| 1207 | + if (original_inputs_.size() == 0) { |
| 1208 | + original_input_string = "[]"; |
| 1209 | + } |
| 1210 | + if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) { |
| 1211 | + // This is response ONLY when there are no optional parameters in the |
| 1212 | + // model |
| 1213 | + return Status( |
| 1214 | + Status::Code::INVALID_ARG, |
| 1215 | + LogRequest() + "expected " + |
| 1216 | + std::to_string(model_config.input_size()) + " inputs but got " + |
| 1217 | + std::to_string(original_inputs_.size()) + " inputs for model '" + |
| 1218 | + ModelName() + "'. Got input(s) " + original_input_string + |
| 1219 | + ", but missing required input(s) " + |
| 1220 | + missing_required_input_string + |
| 1221 | + ". Please provide all required input(s)."); |
| 1222 | + } else { |
| 1223 | + return Status( |
| 1224 | + Status::Code::INVALID_ARG, |
| 1225 | + LogRequest() + "expected number of inputs between " + |
| 1226 | + std::to_string(model_raw_->RequiredInputCount()) + " and " + |
| 1227 | + std::to_string(model_config.input_size()) + " but got " + |
| 1228 | + std::to_string(original_inputs_.size()) + " inputs for model '" + |
| 1229 | + ModelName() + "'. Got input(s) " + original_input_string + |
| 1230 | + ", but missing required input(s) " + |
| 1231 | + missing_required_input_string + |
| 1232 | + ". Please provide all required input(s)."); |
| 1233 | + } |
| 1234 | + } |
| 1235 | + return Status::Success; |
| 1236 | +} |
| 1237 | + |
1198 | 1238 | #ifdef TRITON_ENABLE_STATS |
1199 | 1239 | void |
1200 | 1240 | InferenceRequest::ReportStatistics( |
|
0 commit comments