Merged
49 changes: 36 additions & 13 deletions src/viam/sdk/services/private/mlmodel_server.cpp
@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <viam/sdk/services/private/mlmodel_server.hpp>

#include <viam/sdk/common/private/service_helper.hpp>
@@ -38,17 +37,13 @@ ::grpc::Status MLModelServiceServer::Infer(

const auto md = mlms->metadata({});
MLModelService::named_tensor_views inputs;
for (const auto& input : md.inputs) {
const auto where = request->input_tensors().tensors().find(input.name);
if (where == request->input_tensors().tensors().end()) {
// Ignore any inputs for which we don't have metadata, since
// we can't validate the type info.
//
// TODO: Should this be an error? For now we just don't decode
// those tensors.
continue;
}
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);

// Check if there's only one input tensor and metadata only expects one, too
if (request->input_tensors().tensors().size() == 1 && md.inputs.size() == 1) {
// Special case: just one tensor, add it without a name check
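// With one tensor on each side the pairing is unambiguous; the type
// check below still runs, and the tensor is stored under the name the
// client supplied.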
const MLModelService::tensor_info input = md.inputs[0];
const auto& tensor_pair = *request->input_tensors().tensors().begin();
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(tensor_pair.second);
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
if (tensor_type != input.data_type) {
std::ostringstream message;
@@ -58,7 +53,35 @@ ::grpc::Status MLModelServiceServer::Infer(
<< static_cast<ut>(tensor_type);
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
}
inputs.emplace(std::move(input.name), std::move(tensor));
inputs.emplace(tensor_pair.first, std::move(tensor));
} else {
// Normal case: multiple tensors, do metadata checks
// If there are extra tensors in the inputs that are not found in the metadata,
// they will not be passed on to the implementation.
Member Author commented:
Added a comment to say that if the user provides more inputs than are necessary, the extra inputs will be skipped over
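A minimal sketch of that behavior, as a toy model using std::map in place of the protobuf tensor map (names here are illustrative, not SDK API): the loop only ever looks up the names listed in the metadata, so extra request tensors are never visited, while a missing expected name now fails instead of being skipped.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
    // Metadata expects "image" and "mask"; the request also carries an
    // unexpected "debug" tensor (ints stand in for tensor payloads).
    const std::vector<std::string> metadata_inputs = {"image", "mask"};
    const std::map<std::string, int> request_tensors = {
        {"image", 1}, {"mask", 2}, {"debug", 3}};

    std::map<std::string, int> inputs;
    for (const auto& name : metadata_inputs) {
        const auto where = request_tensors.find(name);
        if (where == request_tensors.end()) {
            // With this change, a missing expected tensor is an
            // INVALID_ARGUMENT error rather than being silently skipped.
            std::cerr << "expected tensor `" << name << "` not found\n";
            return 1;
        }
        inputs.emplace(name, where->second);
    }
    // "debug" was never looked up, so it is not passed on.
    std::cout << inputs.size() << " tensor(s) decoded\n";  // prints: 2
}
```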

for (const auto& input : md.inputs) {
const auto where = request->input_tensors().tensors().find(input.name);
if (where == request->input_tensors().tensors().end()) {
// If the input tensor with the expected name is not found, return an error
std::ostringstream message;
message << "Expected tensor input `" << input.name
<< "` was not found; if you believe you have this tensor under a "
"different name, rename it to the expected tensor name";
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
}
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
const auto tensor_type =
MLModelService::tensor_info::tensor_views_to_data_type(tensor);
if (tensor_type != input.data_type) {
std::ostringstream message;
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
message << "Tensor input `" << input.name
<< "` was the wrong type; expected type "
<< static_cast<ut>(input.data_type) << " but got type "
<< static_cast<ut>(tensor_type);
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
}
inputs.emplace(std::move(input.name), std::move(tensor));
}
}

const auto outputs = mlms->infer(inputs, helper.getExtra());
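A matching sketch of the new single-tensor special case (same toy model; names are illustrative, not SDK API): when both the request and the metadata contain exactly one tensor, the name check is skipped and the client's name is kept.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
    // Metadata expects exactly one input ("image"), and the client sent
    // exactly one tensor, but under a different name ("input_0").
    const std::vector<std::string> metadata_inputs = {"image"};
    const std::map<std::string, int> request_tensors = {{"input_0", 7}};

    std::map<std::string, int> inputs;
    if (request_tensors.size() == 1 && metadata_inputs.size() == 1) {
        // Accepted despite the name mismatch; the type check (elided
        // here) still runs, and the client's name is preserved.
        inputs.insert(*request_tensors.begin());
    }
    std::cout << inputs.begin()->first << "\n";  // prints: input_0
}
```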