Skip to content
Merged
25 changes: 3 additions & 22 deletions src/viam/sdk/services/private/mlmodel_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,10 @@ ::grpc::Status MLModelServiceServer::Infer(
return helper.fail(::grpc::INVALID_ARGUMENT, "Called with no input tensors");
}

const auto md = mlms->metadata({});
MLModelService::named_tensor_views inputs;
for (const auto& input : md.inputs) {
const auto where = request->input_tensors().tensors().find(input.name);
if (where == request->input_tensors().tensors().end()) {
// Ignore any inputs for which we don't have metadata, since
// we can't validate the type info.
//
// TODO: Should this be an error? For now we just don't decode
// those tensors.
continue;
}
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
if (tensor_type != input.data_type) {
std::ostringstream message;
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
message << "Tensor input `" << input.name << "` was the wrong type; expected type "
<< static_cast<ut>(input.data_type) << " but got type "
<< static_cast<ut>(tensor_type);
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
}
inputs.emplace(std::move(input.name), std::move(tensor));
for (const auto& [tensor_name, api_tensor] : request->input_tensors().tensors()) {
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(api_tensor);
inputs.emplace(std::move(tensor_name), std::move(tensor));
}

const auto outputs = mlms->infer(inputs, helper.getExtra());
Expand Down
Loading