Skip to content

Commit ff103c4

Browse files
author
Hemant Jain
authored
Enforce ordering of I/O if naming convention is not followed (#63)
* Enforce ordering of I/O if naming convention is not followed * Enforce usage of consistent naming convention for inputs and outputs - Convention between inputs and outputs can differ * Use helper function GetNamingConvention - use switch case - use c++ style enum * Use class enum - cleanup unnecessary code blocks * Add clarifying comment about atoi usage * fix typo * Remove try catch for atoi and use checks for is digit instead
1 parent ab59a37 commit ff103c4

File tree

1 file changed

+181
-52
lines changed

1 file changed

+181
-52
lines changed

src/libtorch.cc

Lines changed: 181 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,14 @@ ModelState::ParseParameters()
424424
return nullptr;
425425
}
426426

427+
// The naming convention followed for inputs/outputs in the model configuration.
428+
// Outputs don't support FORWARD_ARGUMENT.
429+
enum class NamingConvention {
430+
NAMED_INDEX,
431+
FORWARD_ARGUMENT,
432+
STRICT_CONFIG_ORDERING
433+
};
434+
427435
//
428436
// ModelInstanceState
429437
//
@@ -476,6 +484,11 @@ class ModelInstanceState : public BackendModelInstance {
476484
std::vector<TRITONBACKEND_Response*>* responses,
477485
uint64_t* compute_end_ns);
478486

487+
// Get the naming convention for inputs/outputs from the model configuration
488+
TRITONSERVER_Error* GetNamingConvention(
489+
NamingConvention* naming_convention,
490+
const std::set<std::string>& allowed_io);
491+
479492
ModelState* model_state_;
480493

481494
// The full path to the TorchScript model file.
@@ -597,21 +610,29 @@ ModelInstanceState::ValidateBooleanSequenceControl(
597610
if (*have_control) {
598611
std::string deliminator = "__";
599612
int ip_index = 0;
600-
try {
601-
int start_pos = tensor_name.find(deliminator);
602-
if (start_pos == -1) {
603-
throw std::invalid_argument("input must follow naming convention");
604-
}
605-
ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str());
606-
input_index_map_[tensor_name] = ip_index;
607-
}
608-
catch (std::exception& ex) {
613+
int start_pos = tensor_name.find(deliminator);
614+
if (start_pos == -1) {
609615
return TRITONSERVER_ErrorNew(
610616
TRITONSERVER_ERROR_INTERNAL,
611617
("input '" + tensor_name +
612-
"' does not follow naming convention i.e. <name>__<index>.")
618+
"' does not follow <name>__<index> naming convention.")
613619
.c_str());
614620
}
621+
622+
// check if the index part of the name is not an integer
623+
std::string index_str = tensor_name.substr(start_pos + 2);
624+
for (auto itr = index_str.begin(); itr != index_str.end(); itr++) {
625+
if (std::isdigit(*itr) == 0) {
626+
return TRITONSERVER_ErrorNew(
627+
TRITONSERVER_ERROR_INTERNAL,
628+
("input '" + tensor_name +
629+
"' does not follow <name>__<index> naming convention.")
630+
.c_str());
631+
}
632+
}
633+
634+
ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str());
635+
input_index_map_[tensor_name] = ip_index;
615636
}
616637

617638
return nullptr; // success
@@ -631,21 +652,29 @@ ModelInstanceState::ValidateTypedSequenceControl(
631652
if (*have_control) {
632653
std::string deliminator = "__";
633654
int ip_index = 0;
634-
try {
635-
int start_pos = tensor_name.find(deliminator);
636-
if (start_pos == -1) {
637-
throw std::invalid_argument("input must follow naming convention");
638-
}
639-
ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str());
640-
input_index_map_[tensor_name] = ip_index;
641-
}
642-
catch (std::exception& ex) {
655+
int start_pos = tensor_name.find(deliminator);
656+
if (start_pos == -1) {
643657
return TRITONSERVER_ErrorNew(
644658
TRITONSERVER_ERROR_INTERNAL,
645659
("input '" + tensor_name +
646-
"' does not follow naming convention i.e. <name>__<index>.")
660+
"' does not follow <name>__<index> naming convention.")
647661
.c_str());
648662
}
663+
664+
// check if the index part of the name is not an integer
665+
std::string index_str = tensor_name.substr(start_pos + 2);
666+
for (auto itr = index_str.begin(); itr != index_str.end(); itr++) {
667+
if (std::isdigit(*itr) == 0) {
668+
return TRITONSERVER_ErrorNew(
669+
TRITONSERVER_ERROR_INTERNAL,
670+
("input '" + tensor_name +
671+
"' does not follow <name>__<index> naming convention.")
672+
.c_str());
673+
}
674+
}
675+
676+
ip_index = std::atoi(tensor_name.substr(start_pos + 2).c_str());
677+
input_index_map_[tensor_name] = ip_index;
649678
}
650679

651680
return nullptr; // success
@@ -727,6 +756,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
727756
}
728757

729758
bool supports_batching = model_state_->MaxBatchSize() > 0;
759+
NamingConvention naming_convention;
760+
RETURN_IF_ERROR(GetNamingConvention(&naming_convention, allowed_inputs));
730761

731762
for (size_t i = 0; i < ios.ArraySize(); i++) {
732763
triton::common::TritonJson::Value io;
@@ -740,27 +771,24 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
740771
// input names since they are the keys for the dictionary
741772
input_index_map_[io_name] = i;
742773
} else {
743-
// input tensor name must be in 'allowed_inputs' or must follow the naming
744-
// convention
745-
auto itr = allowed_inputs.find(io_name);
746-
if (itr != allowed_inputs.end()) {
747-
input_index_map_[io_name] = std::distance(allowed_inputs.begin(), itr);
748-
} else {
749-
try {
750-
int start_pos = io_name.find(deliminator);
751-
if (start_pos == -1) {
752-
throw std::invalid_argument("input must follow naming convention");
774+
switch (naming_convention) {
775+
case NamingConvention::FORWARD_ARGUMENT: {
776+
auto itr = allowed_inputs.find(io_name);
777+
if (itr != allowed_inputs.end()) {
778+
input_index_map_[io_name] =
779+
std::distance(allowed_inputs.begin(), itr);
753780
}
781+
break;
782+
}
783+
case NamingConvention::NAMED_INDEX: {
784+
int start_pos = io_name.find(deliminator);
754785
ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
755786
input_index_map_[io_name] = ip_index;
787+
break;
756788
}
757-
catch (std::exception& ex) {
758-
return TRITONSERVER_ErrorNew(
759-
TRITONSERVER_ERROR_INTERNAL,
760-
("input '" + io_name +
761-
"' is neither an input argument to the model nor does it "
762-
"follow the naming convention i.e. <name>__<index>.")
763-
.c_str());
789+
case NamingConvention::STRICT_CONFIG_ORDERING: {
790+
input_index_map_[io_name] = i;
791+
break;
764792
}
765793
}
766794
}
@@ -821,6 +849,8 @@ ModelInstanceState::ValidateOutputs()
821849
}
822850

823851
const bool supports_batching = model_state_->MaxBatchSize() > 0;
852+
NamingConvention naming_convention;
853+
RETURN_IF_ERROR(GetNamingConvention(&naming_convention, {}));
824854

825855
for (size_t i = 0; i < ios.ArraySize(); i++) {
826856
triton::common::TritonJson::Value io;
@@ -829,19 +859,18 @@ ModelInstanceState::ValidateOutputs()
829859
// Validate name
830860
std::string io_name;
831861
RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
832-
try {
833-
int start_pos = io_name.find(deliminator);
834-
if (start_pos == -1) {
835-
throw std::invalid_argument("output must follow naming convention");
862+
switch (naming_convention) {
863+
case NamingConvention::NAMED_INDEX: {
864+
int start_pos = io_name.find(deliminator);
865+
op_index = std::atoi(io_name.substr(start_pos + 2).c_str());
866+
break;
836867
}
837-
op_index = std::atoi(io_name.substr(start_pos + 2).c_str());
838-
}
839-
catch (std::exception& ex) {
840-
return TRITONSERVER_ErrorNew(
841-
TRITONSERVER_ERROR_INTERNAL,
842-
("output '" + io_name +
843-
"' does not follow naming convention i.e. <name>__<index>.")
844-
.c_str());
868+
case NamingConvention::STRICT_CONFIG_ORDERING: {
869+
op_index = i;
870+
break;
871+
}
872+
default:
873+
break;
845874
}
846875

847876
// Validate data type
@@ -1251,9 +1280,9 @@ ModelInstanceState::Execute(
12511280
output_tensors->push_back(model_outputs_);
12521281
} else {
12531282
throw std::invalid_argument(
1254-
"output must be of type Tensor, List[str] or Tuple "
1255-
"containing one of these two types. It should not be a List / "
1256-
"Dictionary of Tensors or a Scalar");
1283+
"output must be of type Tensor, List[str] or Tuple containing one of "
1284+
"these two types. It should not be a List / Dictionary of Tensors or "
1285+
"a Scalar");
12571286
}
12581287
}
12591288
catch (std::exception& ex) {
@@ -1265,6 +1294,106 @@ ModelInstanceState::Execute(
12651294
}
12661295
}
12671296

1297+
TRITONSERVER_Error*
1298+
ModelInstanceState::GetNamingConvention(
1299+
NamingConvention* naming_convention,
1300+
const std::set<std::string>& allowed_ios)
1301+
{
1302+
// Rules for (non-Dictionary) input tensor names:
1303+
// 1. Must be in 'allowed_inputs' (arguments in the forward function)
1304+
// 2. Must follow the naming convention i.e. <name>__<index>
1305+
// 3. If neither of the above conditions are satisfied, enforce strict
1306+
// ordering of model inputs.
1307+
//
1308+
// Rules for output tensor names:
1309+
// 1. Must follow the naming convention i.e. <name>__<index>
1310+
// 2. If not, we enforce strict ordering of model outputs.
1311+
std::string deliminator = "__";
1312+
std::string io_kind = "input";
1313+
*naming_convention = NamingConvention::FORWARD_ARGUMENT;
1314+
1315+
// symbolizes output
1316+
if (allowed_ios.size() == 0) {
1317+
io_kind = "output";
1318+
*naming_convention = NamingConvention::NAMED_INDEX;
1319+
}
1320+
1321+
triton::common::TritonJson::Value ios;
1322+
RETURN_IF_ERROR(
1323+
model_state_->ModelConfig().MemberAsArray(io_kind.c_str(), &ios));
1324+
1325+
if (io_kind == "input") {
1326+
for (size_t i = 0; i < ios.ArraySize(); i++) {
1327+
triton::common::TritonJson::Value io;
1328+
RETURN_IF_ERROR(ios.IndexAsObject(i, &io));
1329+
1330+
// Validate name
1331+
std::string io_name;
1332+
RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
1333+
auto itr = allowed_ios.find(io_name);
1334+
if (itr == allowed_ios.end()) {
1335+
*naming_convention = NamingConvention::NAMED_INDEX;
1336+
break;
1337+
}
1338+
}
1339+
}
1340+
1341+
// If not, check if inputs follow INDEX
1342+
if (*naming_convention == NamingConvention::NAMED_INDEX) {
1343+
for (size_t i = 0; i < ios.ArraySize(); i++) {
1344+
triton::common::TritonJson::Value io;
1345+
RETURN_IF_ERROR(ios.IndexAsObject(i, &io));
1346+
1347+
// Validate name
1348+
std::string io_name;
1349+
RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
1350+
int start_pos = io_name.find(deliminator);
1351+
if (start_pos == -1) {
1352+
*naming_convention = NamingConvention::STRICT_CONFIG_ORDERING;
1353+
break;
1354+
} else {
1355+
// check if the index part of the name is not an integer
1356+
std::string index_str = io_name.substr(start_pos + 2);
1357+
bool is_int = true;
1358+
for (auto itr = index_str.begin(); itr != index_str.end(); itr++) {
1359+
if (std::isdigit(*itr) == 0) {
1360+
is_int = false;
1361+
}
1362+
}
1363+
1364+
if (!is_int) {
1365+
if (io_kind == "input") {
1366+
LOG_MESSAGE(
1367+
TRITONSERVER_LOG_WARN,
1368+
("input '" + io_name +
1369+
"' or previous input(s) are neither an input argument to the "
1370+
"model '" +
1371+
model_state_->Name() +
1372+
"' nor do they follow the <name>__<index> naming convention. "
1373+
"Falling back to enforcing strict ordering from model "
1374+
"configuration.")
1375+
.c_str());
1376+
} else {
1377+
LOG_MESSAGE(
1378+
TRITONSERVER_LOG_WARN,
1379+
("output '" + io_name +
1380+
"' or previous output(s) of the model '" +
1381+
model_state_->Name() +
1382+
"' do not follow the <name>__<index> naming convention. "
1383+
"Falling back to enforcing strict ordering from model "
1384+
"configuration.")
1385+
.c_str());
1386+
}
1387+
*naming_convention = NamingConvention::STRICT_CONFIG_ORDERING;
1388+
break;
1389+
}
1390+
}
1391+
}
1392+
}
1393+
1394+
return nullptr; // success
1395+
}
1396+
12681397
// This function will return a tensor's contents as a contiguous
12691398
// chunk in system memory. In some cases this will require copying the data.
12701399
// If that happens, 'contiguous_buffer' will be set to hold the contiguous

0 commit comments

Comments
 (0)