@@ -38,17 +38,13 @@ ::grpc::Status MLModelServiceServer::Infer(
3838
3939 const auto md = mlms->metadata ({});
4040 MLModelService::named_tensor_views inputs;
41- for (const auto & input : md.inputs ) {
42- const auto where = request->input_tensors ().tensors ().find (input.name );
43- if (where == request->input_tensors ().tensors ().end ()) {
44- // Ignore any inputs for which we don't have metadata, since
45- // we can't validate the type info.
46- //
47- // TODO: Should this be an error? For now we just don't decode
48- // those tensors.
49- continue ;
50- }
51- auto tensor = mlmodel::make_sdk_tensor_from_api_tensor (where->second );
41+
42+ // Check if there's only one input tensor and metadata only expects one, too
43+ if (request->input_tensors ().tensors ().size () == 1 && md.inputs .size () == 1 ) {
44+ // Special case: just one tensor, add it without name check
45+ const MLModelService::tensor_info input = md.inputs [0 ];
46+ const auto & tensor_pair = *request->input_tensors ().tensors ().begin ();
47+ auto tensor = mlmodel::make_sdk_tensor_from_api_tensor (tensor_pair.second );
5248 const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type (tensor);
5349 if (tensor_type != input.data_type ) {
5450 std::ostringstream message;
@@ -58,7 +54,35 @@ ::grpc::Status MLModelServiceServer::Infer(
5854 << static_cast <ut>(tensor_type);
5955 return helper.fail (::grpc::INVALID_ARGUMENT, message.str ().c_str ());
6056 }
61- inputs.emplace (std::move (input.name ), std::move (tensor));
57+ inputs.emplace (tensor_pair.first , std::move (tensor));
58+ } else {
59+ // Normal case: multiple tensors, do metadata checks
60+ // If there are extra tensors in the inputs that not found in the metadata,
61+ // they will not be passed on to the implementation.
62+ for (const auto & input : md.inputs ) {
63+ const auto where = request->input_tensors ().tensors ().find (input.name );
64+ if (where == request->input_tensors ().tensors ().end ()) {
65+ // if the input vector of the expected name is not found, return an error
66+ std::ostringstream message;
67+ message << " Expected tensor input `" << input.name
68+ << " ` was not found; if you believe you have this tensor under a "
69+ " different name, rename it to the expected tensor name" ;
70+ return helper.fail (::grpc::INVALID_ARGUMENT, message.str ().c_str ());
71+ }
72+ auto tensor = mlmodel::make_sdk_tensor_from_api_tensor (where->second );
73+ const auto tensor_type =
74+ MLModelService::tensor_info::tensor_views_to_data_type (tensor);
75+ if (tensor_type != input.data_type ) {
76+ std::ostringstream message;
77+ using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
78+ message << " Tensor input `" << input.name
79+ << " ` was the wrong type; expected type "
80+ << static_cast <ut>(input.data_type ) << " but got type "
81+ << static_cast <ut>(tensor_type);
82+ return helper.fail (::grpc::INVALID_ARGUMENT, message.str ().c_str ());
83+ }
84+ inputs.emplace (std::move (input.name ), std::move (tensor));
85+ }
6286 }
6387
6488 const auto outputs = mlms->infer (inputs, helper.getExtra ());
0 commit comments