diff --git a/src/viam/sdk/services/private/mlmodel_server.cpp b/src/viam/sdk/services/private/mlmodel_server.cpp
index 550c54690..c0210f3ce 100644
--- a/src/viam/sdk/services/private/mlmodel_server.cpp
+++ b/src/viam/sdk/services/private/mlmodel_server.cpp
@@ -38,17 +38,13 @@ ::grpc::Status MLModelServiceServer::Infer(
     const auto md = mlms->metadata({});
 
     MLModelService::named_tensor_views inputs;
-    for (const auto& input : md.inputs) {
-        const auto where = request->input_tensors().tensors().find(input.name);
-        if (where == request->input_tensors().tensors().end()) {
-            // Ignore any inputs for which we don't have metadata, since
-            // we can't validate the type info.
-            //
-            // TODO: Should this be an error? For now we just don't decode
-            // those tensors.
-            continue;
-        }
-        auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
+
+    // Check if there's only one input tensor and metadata only expects one, too
+    if (request->input_tensors().tensors().size() == 1 && md.inputs.size() == 1) {
+        // Special case: just one tensor, add it without a name check
+        const MLModelService::tensor_info input = md.inputs[0];
+        const auto& tensor_pair = *request->input_tensors().tensors().begin();
+        auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(tensor_pair.second);
         const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
         if (tensor_type != input.data_type) {
             std::ostringstream message;
@@ -58,7 +54,35 @@ ::grpc::Status MLModelServiceServer::Infer(
                     << static_cast<ut>(tensor_type);
             return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
         }
-        inputs.emplace(std::move(input.name), std::move(tensor));
+        inputs.emplace(tensor_pair.first, std::move(tensor));
+    } else {
+        // Normal case: multiple tensors, do metadata checks.
+        // If there are extra tensors in the inputs that are not found in the metadata,
+        // they will not be passed on to the implementation.
+        for (const auto& input : md.inputs) {
+            const auto where = request->input_tensors().tensors().find(input.name);
+            if (where == request->input_tensors().tensors().end()) {
+                // If an input tensor with the expected name is not found, return an error.
+                std::ostringstream message;
+                message << "Expected tensor input `" << input.name
+                        << "` was not found; if you believe you have this tensor under a "
+                           "different name, rename it to the expected tensor name";
+                return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
+            }
+            auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
+            const auto tensor_type =
+                MLModelService::tensor_info::tensor_views_to_data_type(tensor);
+            if (tensor_type != input.data_type) {
+                std::ostringstream message;
+                using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
+                message << "Tensor input `" << input.name
+                        << "` was the wrong type; expected type "
+                        << static_cast<ut>(input.data_type) << " but got type "
+                        << static_cast<ut>(tensor_type);
+                return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
+            }
+            inputs.emplace(std::move(input.name), std::move(tensor));
+        }
     }
 
     const auto outputs = mlms->infer(inputs, helper.getExtra());
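
For review context, here is a minimal standalone sketch of the matching policy this change introduces. The Tensor, TensorInfo, and match_inputs names are hypothetical stand-ins, not SDK types; the sketch only illustrates the control flow (single-tensor special case versus strict name matching), not the SDK's actual named_tensor_views or protobuf request API.

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the SDK types; names are illustrative only.
struct Tensor {
    std::string data_type;  // e.g. "float32"
};

struct TensorInfo {
    std::string name;
    std::string data_type;
};

// Mirrors the diff's dispatch: with exactly one request tensor and exactly one
// metadata entry, accept the tensor regardless of its name; otherwise require
// every metadata name to be present (and type-correct) in the request.
std::optional<std::map<std::string, Tensor>> match_inputs(
    const std::vector<TensorInfo>& metadata,
    const std::map<std::string, Tensor>& request) {
    std::map<std::string, Tensor> inputs;
    if (request.size() == 1 && metadata.size() == 1) {
        const auto& pair = *request.begin();
        if (pair.second.data_type != metadata[0].data_type) {
            return std::nullopt;  // wrong type -> INVALID_ARGUMENT in the diff
        }
        inputs.emplace(pair.first, pair.second);  // keep the caller's name
        return inputs;
    }
    for (const auto& info : metadata) {
        const auto where = request.find(info.name);
        if (where == request.end()) {
            return std::nullopt;  // missing expected name -> INVALID_ARGUMENT
        }
        if (where->second.data_type != info.data_type) {
            return std::nullopt;  // wrong type -> INVALID_ARGUMENT
        }
        inputs.emplace(info.name, where->second);
    }
    return inputs;  // extra request tensors are silently dropped
}

int main() {
    const std::vector<TensorInfo> md = {{"image", "float32"}};
    // A single tensor under a non-matching name: accepted by the new special case.
    const std::map<std::string, Tensor> req = {{"input_0", Tensor{"float32"}}};
    std::cout << (match_inputs(md, req) ? "accepted\n" : "rejected\n");
}

This models the behavioral change: under the old code, a metadata name missing from the request was silently skipped; under the new code it is an INVALID_ARGUMENT error, except in the one-tensor/one-metadata-entry case, where the caller's tensor name is accepted as-is.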