Skip to content

Commit 3421d0b

Browse files
authored
Fix pytorch forward argument naming convention (#72)
1 parent 1f89243 commit 3421d0b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/libtorch.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ class ModelInstanceState : public BackendModelInstance {
506506
// Get the naming convention for inputs/outputs from the model configuration
507507
TRITONSERVER_Error* GetNamingConvention(
508508
NamingConvention* naming_convention,
509-
const std::set<std::string>& allowed_io);
509+
const std::vector<std::string>& allowed_io);
510510

511511
ModelState* model_state_;
512512

@@ -713,7 +713,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
713713
{
714714
// Collect all the expected input tensor names and validate that the model
715715
// configuration specifies only those.
716-
std::set<std::string> allowed_inputs;
716+
std::vector<std::string> allowed_inputs;
717717

718718
const torch::jit::Method& method = torch_model_->get_method("forward");
719719
const auto& schema = method.function().getSchema();
@@ -755,7 +755,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
755755
"Dict(str, Tensor) or input(s) of type Tensor are supported.")
756756
.c_str());
757757
}
758-
allowed_inputs.emplace(arguments.at(i).name());
758+
allowed_inputs.emplace_back(arguments.at(i).name());
759759
}
760760

761761
// If all inputs are tensors, match number of expected inputs between model
@@ -800,7 +800,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
800800
} else {
801801
switch (naming_convention) {
802802
case NamingConvention::FORWARD_ARGUMENT: {
803-
auto itr = allowed_inputs.find(io_name);
803+
auto itr = std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
804804
if (itr != allowed_inputs.end()) {
805805
input_index_map_[io_name] =
806806
std::distance(allowed_inputs.begin(), itr);
@@ -1325,7 +1325,7 @@ ModelInstanceState::Execute(
13251325
TRITONSERVER_Error*
13261326
ModelInstanceState::GetNamingConvention(
13271327
NamingConvention* naming_convention,
1328-
const std::set<std::string>& allowed_ios)
1328+
const std::vector<std::string>& allowed_ios)
13291329
{
13301330
// Rules for (non-Dictionary) input tensor names:
13311331
// 1. Must be in 'allowed_inputs' (arguments in the forward function)
@@ -1358,7 +1358,7 @@ ModelInstanceState::GetNamingConvention(
13581358
// Validate name
13591359
std::string io_name;
13601360
RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
1361-
auto itr = allowed_ios.find(io_name);
1361+
auto itr = std::find(allowed_ios.begin(), allowed_ios.end(), io_name);
13621362
if (itr == allowed_ios.end()) {
13631363
*naming_convention = NamingConvention::NAMED_INDEX;
13641364
break;

0 commit comments

Comments
 (0)