Skip to content

Commit dd0e8c2

Browse files
committed
add type check in 1 input case too
1 parent 117a1f1 commit dd0e8c2

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/viam/sdk/services/private/mlmodel_server.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,19 @@ ::grpc::Status MLModelServiceServer::Infer(
4040

4141
// Check if there's only one input tensor and metadata only expects one, too
4242
if (request->input_tensors().tensors().size() == 1 && md.inputs.size() == 1) {
43-
// Special case: just one tensor, add it without metadata checks
43+
// Special case: just one tensor, add it without name check
44+
const MLModelService::tensor_info::tensor_info* input = &md.inputs[0];
4445
const auto& tensor_pair = *request->input_tensors().tensors().begin();
4546
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(tensor_pair.second);
47+
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
48+
if (tensor_type != input.data_type) {
49+
std::ostringstream message;
50+
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
51+
message << "Tensor input `" << input.name << "` was the wrong type; expected type "
52+
<< static_cast<ut>(input.data_type) << " but got type "
53+
<< static_cast<ut>(tensor_type);
54+
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
55+
}
4656
inputs.emplace(tensor_pair.first, std::move(tensor));
4757
} else {
4858
// Normal case: multiple tensors, do metadata checks

0 commit comments

Comments
 (0)