Skip to content

Commit 86b48ec

Browse files
committed
[RSDK-10768] Use the expected name for single input tensors
1 parent 3b862ad commit 86b48ec

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

src/viam/sdk/services/private/mlmodel_server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ ::grpc::Status MLModelServiceServer::Infer(
5454
<< static_cast<ut>(tensor_type);
5555
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
5656
}
57-
inputs.emplace(tensor_pair.first, std::move(tensor));
57+
inputs.emplace(input.name, std::move(tensor));
5858
} else {
5959
// Normal case: multiple tensors, do metadata checks
6060
// If there are extra tensors in the inputs that not found in the metadata,

src/viam/sdk/tests/test_mlmodel.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,49 @@ BOOST_AUTO_TEST_CASE(mock_infer_grpc_roundtrip) {
306306

307307
BOOST_AUTO_TEST_SUITE_END()
308308

309+
310+
BOOST_AUTO_TEST_SUITE(test_mlmodel_bugfixes)
311+
312+
BOOST_AUTO_TEST_CASE(RSDK_10768) {
313+
auto mock = std::make_shared<MockMLModelService>();
314+
315+
mock->set_metadata({"foo",
316+
"bar",
317+
"baz",
318+
// `inputs`
319+
{{"input",
320+
"the input",
321+
MLModelService::tensor_info::data_types::k_float32,
322+
{1},
323+
{},
324+
{}},
325+
},
326+
// no `outputs`
327+
{}});
328+
329+
mock->set_infer_handler([](const MLModelService::named_tensor_views& request) {
330+
BOOST_REQUIRE(request.size() == 1);
331+
BOOST_REQUIRE(request.count("input") == 1);
332+
return std::make_shared<MLModelService::named_tensor_views>();
333+
});
334+
335+
client_to_mock_pipeline<MLModelService>(mock, [&mock](auto& client) {
336+
MLModelService::named_tensor_views request;
337+
338+
std::array<float, 1> input_data{};
339+
input_data[0] = 1.0;
340+
auto input_view =
341+
MLModelService::make_tensor_view(input_data.data(), input_data.size(), {1});
342+
343+
auto mismatched_name = mock->metadata({}).inputs[0].name + "_mismatched";
344+
request.emplace(std::move(mismatched_name), std::move(input_view));
345+
auto response = client.infer(request);
346+
BOOST_TEST(response->size() == 0);
347+
});
348+
}
349+
350+
BOOST_AUTO_TEST_SUITE_END()
351+
309352
// This test suite is to validate that we can use xtensor for all of
310353
// the tensor data shuttling we need.
311354
BOOST_AUTO_TEST_SUITE(xtensor_experiment)

0 commit comments

Comments
 (0)