@@ -424,6 +424,14 @@ ModelState::ParseParameters()
424424 return nullptr ;
425425}
426426
427+ // The naming convention followed for inputs/outputs in the model configuration.
428+ // Outputs don't support FORWARD_ARGUMENT.
429+ enum class NamingConvention {
430+ NAMED_INDEX,
431+ FORWARD_ARGUMENT,
432+ STRICT_CONFIG_ORDERING
433+ };
434+
427435//
428436// ModelInstanceState
429437//
@@ -476,6 +484,11 @@ class ModelInstanceState : public BackendModelInstance {
476484 std::vector<TRITONBACKEND_Response*>* responses,
477485 uint64_t * compute_end_ns);
478486
487+ // Get the naming convention for inputs/outputs from the model configuration
488+ TRITONSERVER_Error* GetNamingConvention (
489+ NamingConvention* naming_convention,
490+ const std::set<std::string>& allowed_io);
491+
479492 ModelState* model_state_;
480493
481494 // The full path to the TorchScript model file.
@@ -597,21 +610,29 @@ ModelInstanceState::ValidateBooleanSequenceControl(
597610 if (*have_control) {
598611 std::string deliminator = " __" ;
599612 int ip_index = 0 ;
600- try {
601- int start_pos = tensor_name.find (deliminator);
602- if (start_pos == -1 ) {
603- throw std::invalid_argument (" input must follow naming convention" );
604- }
605- ip_index = std::atoi (tensor_name.substr (start_pos + 2 ).c_str ());
606- input_index_map_[tensor_name] = ip_index;
607- }
608- catch (std::exception& ex) {
613+ int start_pos = tensor_name.find (deliminator);
614+ if (start_pos == -1 ) {
609615 return TRITONSERVER_ErrorNew (
610616 TRITONSERVER_ERROR_INTERNAL,
611617 (" input '" + tensor_name +
612- " ' does not follow naming convention i.e. <name>__<index>." )
618+ " ' does not follow <name>__<index> naming convention ." )
613619 .c_str ());
614620 }
621+
622+ // check if the index part of the name is not an integer
623+ std::string index_str = tensor_name.substr (start_pos + 2 );
624+ for (auto itr = index_str.begin (); itr != index_str.end (); itr++) {
625+ if (std::isdigit (*itr) == 0 ) {
626+ return TRITONSERVER_ErrorNew (
627+ TRITONSERVER_ERROR_INTERNAL,
628+ (" input '" + tensor_name +
629+ " ' does not follow <name>__<index> naming convention." )
630+ .c_str ());
631+ }
632+ }
633+
634+ ip_index = std::atoi (tensor_name.substr (start_pos + 2 ).c_str ());
635+ input_index_map_[tensor_name] = ip_index;
615636 }
616637
617638 return nullptr ; // success
@@ -631,21 +652,29 @@ ModelInstanceState::ValidateTypedSequenceControl(
631652 if (*have_control) {
632653 std::string deliminator = " __" ;
633654 int ip_index = 0 ;
634- try {
635- int start_pos = tensor_name.find (deliminator);
636- if (start_pos == -1 ) {
637- throw std::invalid_argument (" input must follow naming convention" );
638- }
639- ip_index = std::atoi (tensor_name.substr (start_pos + 2 ).c_str ());
640- input_index_map_[tensor_name] = ip_index;
641- }
642- catch (std::exception& ex) {
655+ int start_pos = tensor_name.find (deliminator);
656+ if (start_pos == -1 ) {
643657 return TRITONSERVER_ErrorNew (
644658 TRITONSERVER_ERROR_INTERNAL,
645659 (" input '" + tensor_name +
646- " ' does not follow naming convention i.e. <name>__<index>." )
660+ " ' does not follow <name>__<index> naming convention ." )
647661 .c_str ());
648662 }
663+
664+ // check if the index part of the name is not an integer
665+ std::string index_str = tensor_name.substr (start_pos + 2 );
666+ for (auto itr = index_str.begin (); itr != index_str.end (); itr++) {
667+ if (std::isdigit (*itr) == 0 ) {
668+ return TRITONSERVER_ErrorNew (
669+ TRITONSERVER_ERROR_INTERNAL,
670+ (" input '" + tensor_name +
671+ " ' does not follow <name>__<index> naming convention." )
672+ .c_str ());
673+ }
674+ }
675+
676+ ip_index = std::atoi (tensor_name.substr (start_pos + 2 ).c_str ());
677+ input_index_map_[tensor_name] = ip_index;
649678 }
650679
651680 return nullptr ; // success
@@ -727,6 +756,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
727756 }
728757
729758 bool supports_batching = model_state_->MaxBatchSize () > 0 ;
759+ NamingConvention naming_convention;
760+ RETURN_IF_ERROR (GetNamingConvention (&naming_convention, allowed_inputs));
730761
731762 for (size_t i = 0 ; i < ios.ArraySize (); i++) {
732763 triton::common::TritonJson::Value io;
@@ -740,27 +771,24 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
740771 // input names since they are the keys for the dictionary
741772 input_index_map_[io_name] = i;
742773 } else {
743- // input tensor name must be in 'allowed_inputs' or must follow the naming
744- // convention
745- auto itr = allowed_inputs.find (io_name);
746- if (itr != allowed_inputs.end ()) {
747- input_index_map_[io_name] = std::distance (allowed_inputs.begin (), itr);
748- } else {
749- try {
750- int start_pos = io_name.find (deliminator);
751- if (start_pos == -1 ) {
752- throw std::invalid_argument (" input must follow naming convention" );
774+ switch (naming_convention) {
775+ case NamingConvention::FORWARD_ARGUMENT: {
776+ auto itr = allowed_inputs.find (io_name);
777+ if (itr != allowed_inputs.end ()) {
778+ input_index_map_[io_name] =
779+ std::distance (allowed_inputs.begin (), itr);
753780 }
781+ break ;
782+ }
783+ case NamingConvention::NAMED_INDEX: {
784+ int start_pos = io_name.find (deliminator);
754785 ip_index = std::atoi (io_name.substr (start_pos + 2 ).c_str ());
755786 input_index_map_[io_name] = ip_index;
787+ break ;
756788 }
757- catch (std::exception& ex) {
758- return TRITONSERVER_ErrorNew (
759- TRITONSERVER_ERROR_INTERNAL,
760- (" input '" + io_name +
761- " ' is neither an input argument to the model nor does it "
762- " follow the naming convention i.e. <name>__<index>." )
763- .c_str ());
789+ case NamingConvention::STRICT_CONFIG_ORDERING: {
790+ input_index_map_[io_name] = i;
791+ break ;
764792 }
765793 }
766794 }
@@ -821,6 +849,8 @@ ModelInstanceState::ValidateOutputs()
821849 }
822850
823851 const bool supports_batching = model_state_->MaxBatchSize () > 0 ;
852+ NamingConvention naming_convention;
853+ RETURN_IF_ERROR (GetNamingConvention (&naming_convention, {}));
824854
825855 for (size_t i = 0 ; i < ios.ArraySize (); i++) {
826856 triton::common::TritonJson::Value io;
@@ -829,19 +859,18 @@ ModelInstanceState::ValidateOutputs()
829859 // Validate name
830860 std::string io_name;
831861 RETURN_IF_ERROR (io.MemberAsString (" name" , &io_name));
832- try {
833- int start_pos = io_name.find (deliminator);
834- if (start_pos == -1 ) {
835- throw std::invalid_argument (" output must follow naming convention" );
862+ switch (naming_convention) {
863+ case NamingConvention::NAMED_INDEX: {
864+ int start_pos = io_name.find (deliminator);
865+ op_index = std::atoi (io_name.substr (start_pos + 2 ).c_str ());
866+ break ;
836867 }
837- op_index = std::atoi (io_name.substr (start_pos + 2 ).c_str ());
838- }
839- catch (std::exception& ex) {
840- return TRITONSERVER_ErrorNew (
841- TRITONSERVER_ERROR_INTERNAL,
842- (" output '" + io_name +
843- " ' does not follow naming convention i.e. <name>__<index>." )
844- .c_str ());
868+ case NamingConvention::STRICT_CONFIG_ORDERING: {
869+ op_index = i;
870+ break ;
871+ }
872+ default :
873+ break ;
845874 }
846875
847876 // Validate data type
@@ -1251,9 +1280,9 @@ ModelInstanceState::Execute(
12511280 output_tensors->push_back (model_outputs_);
12521281 } else {
12531282 throw std::invalid_argument (
1254- " output must be of type Tensor, List[str] or Tuple "
1255- " containing one of these two types. It should not be a List / "
1256- " Dictionary of Tensors or a Scalar" );
1283+ " output must be of type Tensor, List[str] or Tuple containing one of "
1284+ " these two types. It should not be a List / Dictionary of Tensors or "
1285+ " a Scalar" );
12571286 }
12581287 }
12591288 catch (std::exception& ex) {
@@ -1265,6 +1294,106 @@ ModelInstanceState::Execute(
12651294 }
12661295}
12671296
1297+ TRITONSERVER_Error*
1298+ ModelInstanceState::GetNamingConvention (
1299+ NamingConvention* naming_convention,
1300+ const std::set<std::string>& allowed_ios)
1301+ {
1302+ // Rules for (non-Dictionary) input tensor names:
1303+ // 1. Must be in 'allowed_inputs' (arguments in the forward function)
1304+ // 2. Must follow the naming convention i.e. <name>__<index>
1305+ // 3. If neither of the above conditions are satisfied, enforce strict
1306+ // ordering of model inputs.
1307+ //
1308+ // Rules for output tensor names:
1309+ // 1. Must follow the naming convention i.e. <name>__<index>
1310+ // 2. If not, we enforce strict ordering of model outputs.
1311+ std::string deliminator = " __" ;
1312+ std::string io_kind = " input" ;
1313+ *naming_convention = NamingConvention::FORWARD_ARGUMENT;
1314+
1315+ // symbolizes output
1316+ if (allowed_ios.size () == 0 ) {
1317+ io_kind = " output" ;
1318+ *naming_convention = NamingConvention::NAMED_INDEX;
1319+ }
1320+
1321+ triton::common::TritonJson::Value ios;
1322+ RETURN_IF_ERROR (
1323+ model_state_->ModelConfig ().MemberAsArray (io_kind.c_str (), &ios));
1324+
1325+ if (io_kind == " input" ) {
1326+ for (size_t i = 0 ; i < ios.ArraySize (); i++) {
1327+ triton::common::TritonJson::Value io;
1328+ RETURN_IF_ERROR (ios.IndexAsObject (i, &io));
1329+
1330+ // Validate name
1331+ std::string io_name;
1332+ RETURN_IF_ERROR (io.MemberAsString (" name" , &io_name));
1333+ auto itr = allowed_ios.find (io_name);
1334+ if (itr == allowed_ios.end ()) {
1335+ *naming_convention = NamingConvention::NAMED_INDEX;
1336+ break ;
1337+ }
1338+ }
1339+ }
1340+
1341+ // If not, check if inputs follow INDEX
1342+ if (*naming_convention == NamingConvention::NAMED_INDEX) {
1343+ for (size_t i = 0 ; i < ios.ArraySize (); i++) {
1344+ triton::common::TritonJson::Value io;
1345+ RETURN_IF_ERROR (ios.IndexAsObject (i, &io));
1346+
1347+ // Validate name
1348+ std::string io_name;
1349+ RETURN_IF_ERROR (io.MemberAsString (" name" , &io_name));
1350+ int start_pos = io_name.find (deliminator);
1351+ if (start_pos == -1 ) {
1352+ *naming_convention = NamingConvention::STRICT_CONFIG_ORDERING;
1353+ break ;
1354+ } else {
1355+ // check if the index part of the name is not an integer
1356+ std::string index_str = io_name.substr (start_pos + 2 );
1357+ bool is_int = true ;
1358+ for (auto itr = index_str.begin (); itr != index_str.end (); itr++) {
1359+ if (std::isdigit (*itr) == 0 ) {
1360+ is_int = false ;
1361+ }
1362+ }
1363+
1364+ if (!is_int) {
1365+ if (io_kind == " input" ) {
1366+ LOG_MESSAGE (
1367+ TRITONSERVER_LOG_WARN,
1368+ (" input '" + io_name +
1369+ " ' or previous input(s) are neither an input argument to the "
1370+ " model '" +
1371+ model_state_->Name () +
1372+ " ' nor do they follow the <name>__<index> naming convention. "
1373+ " Falling back to enforcing strict ordering from model "
1374+ " configuration." )
1375+ .c_str ());
1376+ } else {
1377+ LOG_MESSAGE (
1378+ TRITONSERVER_LOG_WARN,
1379+ (" output '" + io_name +
1380+ " ' or previous output(s) of the model '" +
1381+ model_state_->Name () +
1382+ " ' do not follow the <name>__<index> naming convention. "
1383+ " Falling back to enforcing strict ordering from model "
1384+ " configuration." )
1385+ .c_str ());
1386+ }
1387+ *naming_convention = NamingConvention::STRICT_CONFIG_ORDERING;
1388+ break ;
1389+ }
1390+ }
1391+ }
1392+ }
1393+
1394+ return nullptr ; // success
1395+ }
1396+
12681397// This function will return a tensor's contents as a contiguous
12691398// chunk in system memory. In some cases this will require copying the data.
12701399// If that happens, 'contiguous_buffer' will be set to hold the contiguous
0 commit comments