@@ -506,7 +506,7 @@ class ModelInstanceState : public BackendModelInstance {
506506 // Get the naming convention for inputs/outputs from the model configuration
507507 TRITONSERVER_Error* GetNamingConvention (
508508 NamingConvention* naming_convention,
509- const std::set <std::string>& allowed_io);
509+ const std::vector <std::string>& allowed_io);
510510
511511 ModelState* model_state_;
512512
@@ -713,7 +713,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
713713{
714714 // Collect all the expected input tensor names and validate that the model
715715 // configuration specifies only those.
716- std::set <std::string> allowed_inputs;
716+ std::vector <std::string> allowed_inputs;
717717
718718 const torch::jit::Method& method = torch_model_->get_method (" forward" );
719719 const auto & schema = method.function ().getSchema ();
@@ -755,7 +755,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
755755 " Dict(str, Tensor) or input(s) of type Tensor are supported." )
756756 .c_str ());
757757 }
758- allowed_inputs.emplace (arguments.at (i).name ());
758+ allowed_inputs.emplace_back (arguments.at (i).name ());
759759 }
760760
761761 // If all inputs are tensors, match number of expected inputs between model
@@ -800,7 +800,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
800800 } else {
801801 switch (naming_convention) {
802802 case NamingConvention::FORWARD_ARGUMENT: {
803- auto itr = allowed_inputs.find ( io_name);
803+ auto itr = std::find ( allowed_inputs.begin (), allowed_inputs. end (), io_name);
804804 if (itr != allowed_inputs.end ()) {
805805 input_index_map_[io_name] =
806806 std::distance (allowed_inputs.begin (), itr);
@@ -1325,7 +1325,7 @@ ModelInstanceState::Execute(
13251325TRITONSERVER_Error*
13261326ModelInstanceState::GetNamingConvention (
13271327 NamingConvention* naming_convention,
1328- const std::set <std::string>& allowed_ios)
1328+ const std::vector <std::string>& allowed_ios)
13291329{
13301330 // Rules for (non-Dictionary) input tensor names:
13311331 // 1. Must be in 'allowed_inputs' (arguments in the forward function)
@@ -1358,7 +1358,7 @@ ModelInstanceState::GetNamingConvention(
13581358 // Validate name
13591359 std::string io_name;
13601360 RETURN_IF_ERROR (io.MemberAsString (" name" , &io_name));
1361- auto itr = allowed_ios.find ( io_name);
1361+ auto itr = std::find ( allowed_ios.begin (), allowed_ios. end (), io_name);
13621362 if (itr == allowed_ios.end ()) {
13631363 *naming_convention = NamingConvention::NAMED_INDEX;
13641364 break ;
0 commit comments