Skip to content

Commit ad72349

Browse files
committed
makes single tensor exception, and adds error
1 parent 3c586c1 commit ad72349

File tree

11 files changed

+144
-181
lines changed

11 files changed

+144
-181
lines changed

src/viam/sdk/services/private/discovery_client.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <viam/sdk/services/private/discovery_client.hpp>
2-
31
#include <viam/api/service/discovery/v1/discovery.grpc.pb.h>
42
#include <viam/api/service/discovery/v1/discovery.pb.h>
53

@@ -8,6 +6,7 @@
86
#include <viam/sdk/common/proto_value.hpp>
97
#include <viam/sdk/common/utils.hpp>
108
#include <viam/sdk/services/discovery.hpp>
9+
#include <viam/sdk/services/private/discovery_client.hpp>
1110

1211
namespace viam {
1312
namespace sdk {

src/viam/sdk/services/private/discovery_server.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
#include <viam/sdk/services/private/discovery_client.hpp>
2-
31
#include <viam/sdk/common/private/repeated_ptr_convert.hpp>
42
#include <viam/sdk/common/private/service_helper.hpp>
53
#include <viam/sdk/common/proto_value.hpp>
64
#include <viam/sdk/common/utils.hpp>
5+
#include <viam/sdk/services/private/discovery_client.hpp>
76
#include <viam/sdk/services/private/discovery_server.hpp>
87

98
namespace viam {
@@ -13,11 +12,10 @@ namespace impl {
1312
using namespace service::discovery::v1;
1413

1514
::grpc::Status DiscoveryServer::DiscoverResources(
16-
::grpc::ServerContext*,
17-
const ::viam::service::discovery::v1::DiscoverResourcesRequest* request,
15+
::grpc::ServerContext*, const ::viam::service::discovery::v1::DiscoverResourcesRequest* request,
1816
::viam::service::discovery::v1::DiscoverResourcesResponse* response) noexcept {
19-
return make_service_helper<Discovery>(
20-
"DiscoveryServer::DiscoverResources", this, request)([&](auto& helper, auto& discovery) {
17+
return make_service_helper<Discovery>("DiscoveryServer::DiscoverResources", this,
18+
request)([&](auto& helper, auto& discovery) {
2119
const std::vector<ResourceConfig> resources =
2220
discovery->discover_resources(helper.getExtra());
2321
for (const auto& resource : resources) {
@@ -27,11 +25,10 @@ ::grpc::Status DiscoveryServer::DiscoverResources(
2725
}
2826

2927
::grpc::Status DiscoveryServer::DoCommand(
30-
::grpc::ServerContext*,
31-
const ::viam::common::v1::DoCommandRequest* request,
28+
::grpc::ServerContext*, const ::viam::common::v1::DoCommandRequest* request,
3229
::viam::common::v1::DoCommandResponse* response) noexcept {
33-
return make_service_helper<Discovery>(
34-
"DiscoveryServer::DoCommand", this, request)([&](auto&, auto& discovery) {
30+
return make_service_helper<Discovery>("DiscoveryServer::DoCommand", this,
31+
request)([&](auto&, auto& discovery) {
3532
const ProtoStruct result = discovery->do_command(from_proto(request->command()));
3633
*response->mutable_result() = to_proto(result);
3734
});

src/viam/sdk/services/private/generic_client.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
#include <viam/sdk/services/private/generic_client.hpp>
2-
3-
#include <utility>
4-
51
#include <viam/api/common/v1/common.pb.h>
62
#include <viam/api/service/generic/v1/generic.grpc.pb.h>
73

4+
#include <utility>
85
#include <viam/sdk/common/client_helper.hpp>
96
#include <viam/sdk/common/proto_value.hpp>
107
#include <viam/sdk/common/utils.hpp>
118
#include <viam/sdk/config/resource.hpp>
129
#include <viam/sdk/robot/client.hpp>
10+
#include <viam/sdk/services/private/generic_client.hpp>
1311

1412
namespace viam {
1513
namespace sdk {

src/viam/sdk/services/private/generic_server.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
#include <viam/sdk/services/private/generic_server.hpp>
2-
31
#include <viam/sdk/common/private/service_helper.hpp>
42
#include <viam/sdk/rpc/server.hpp>
53
#include <viam/sdk/services/generic.hpp>
4+
#include <viam/sdk/services/private/generic_server.hpp>
65

76
namespace viam {
87
namespace sdk {
@@ -12,11 +11,10 @@ GenericServiceServer::GenericServiceServer(std::shared_ptr<ResourceManager> mana
1211
: ResourceServer(std::move(manager)) {}
1312

1413
::grpc::Status GenericServiceServer::DoCommand(
15-
::grpc::ServerContext*,
16-
const ::viam::common::v1::DoCommandRequest* request,
14+
::grpc::ServerContext*, const ::viam::common::v1::DoCommandRequest* request,
1715
::viam::common::v1::DoCommandResponse* response) noexcept {
18-
return make_service_helper<GenericService>(
19-
"GenericServiceServer::DoCommand", this, request)([&](auto&, auto& generic) {
16+
return make_service_helper<GenericService>("GenericServiceServer::DoCommand", this,
17+
request)([&](auto&, auto& generic) {
2018
const ProtoStruct result = generic->do_command(from_proto(request->command()));
2119
*response->mutable_result() = to_proto(result);
2220
});

src/viam/sdk/services/private/mlmodel.cpp

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,14 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <viam/sdk/services/private/mlmodel.hpp>
16-
15+
#include <boost/variant/get.hpp>
1716
#include <memory>
1817
#include <stack>
1918
#include <stdexcept>
2019
#include <type_traits>
2120
#include <utility>
22-
23-
#include <boost/variant/get.hpp>
24-
2521
#include <viam/sdk/common/exception.hpp>
22+
#include <viam/sdk/services/private/mlmodel.hpp>
2623

2724
namespace viam {
2825
namespace sdk {
@@ -68,8 +65,7 @@ class copy_sdk_tensor_to_api_tensor_visitor : public boost::static_visitor<void>
6865
static_cast<int>((t.size() + 1) * sizeof(std::int16_t) / sizeof(std::uint32_t));
6966
target_->mutable_int16_tensor()->mutable_data()->Clear();
7067
target_->mutable_int16_tensor()->mutable_data()->Resize(num32s, 0);
71-
std::memcpy(target_->mutable_int16_tensor()->mutable_data()->mutable_data(),
72-
t.begin(),
68+
std::memcpy(target_->mutable_int16_tensor()->mutable_data()->mutable_data(), t.begin(),
7369
t.size() * sizeof(std::int16_t));
7470
}
7571

@@ -79,8 +75,7 @@ class copy_sdk_tensor_to_api_tensor_visitor : public boost::static_visitor<void>
7975
static_cast<int>((t.size() + 1) * sizeof(std::uint16_t) / sizeof(std::uint32_t));
8076
target_->mutable_uint16_tensor()->mutable_data()->Clear();
8177
target_->mutable_uint16_tensor()->mutable_data()->Resize(num32s, 0);
82-
std::memcpy(target_->mutable_uint16_tensor()->mutable_data()->mutable_data(),
83-
t.begin(),
78+
std::memcpy(target_->mutable_uint16_tensor()->mutable_data()->mutable_data(), t.begin(),
8479
t.size() * sizeof(std::uint16_t));
8580
}
8681

@@ -118,8 +113,7 @@ class copy_sdk_tensor_to_api_tensor_visitor : public boost::static_visitor<void>
118113
};
119114

120115
template <typename T>
121-
MLModelService::tensor_views make_sdk_tensor_from_api_tensor_t(const T* data,
122-
std::size_t size,
116+
MLModelService::tensor_views make_sdk_tensor_from_api_tensor_t(const T* data, std::size_t size,
123117
std::vector<std::size_t>&& shape,
124118
tensor_storage* ts) {
125119
if (!data || (size == 0) || shape.empty()) {
@@ -184,63 +178,49 @@ MLModelService::tensor_views make_sdk_tensor_from_api_tensor(
184178
if (api_tensor.has_int8_tensor()) {
185179
return make_sdk_tensor_from_api_tensor_t(
186180
reinterpret_cast<const std::int8_t*>(api_tensor.int8_tensor().data().data()),
187-
api_tensor.int8_tensor().data().size(),
188-
std::move(shape),
189-
storage);
181+
api_tensor.int8_tensor().data().size(), std::move(shape), storage);
190182
} else if (api_tensor.has_uint8_tensor()) {
191183
return make_sdk_tensor_from_api_tensor_t(
192184
reinterpret_cast<const std::uint8_t*>(api_tensor.uint8_tensor().data().data()),
193-
api_tensor.uint8_tensor().data().size(),
194-
std::move(shape),
195-
storage);
185+
api_tensor.uint8_tensor().data().size(), std::move(shape), storage);
196186
} else if (api_tensor.has_int16_tensor()) {
197187
// TODO: be deswizzle
198188
return make_sdk_tensor_from_api_tensor_t(
199189
reinterpret_cast<const std::int16_t*>(api_tensor.int16_tensor().data().data()),
200-
std::size_t{2} * api_tensor.int16_tensor().data().size(),
201-
std::move(shape),
202-
storage);
190+
std::size_t{2} * api_tensor.int16_tensor().data().size(), std::move(shape), storage);
203191
} else if (api_tensor.has_uint16_tensor()) {
204192
// TODO: be deswizzle
205193
return make_sdk_tensor_from_api_tensor_t(
206194
reinterpret_cast<const std::uint16_t*>(api_tensor.uint16_tensor().data().data()),
207-
std::size_t{2} * api_tensor.uint16_tensor().data().size(),
208-
std::move(shape),
209-
storage);
195+
std::size_t{2} * api_tensor.uint16_tensor().data().size(), std::move(shape), storage);
210196
} else if (api_tensor.has_int32_tensor()) {
211197
return make_sdk_tensor_from_api_tensor_t(api_tensor.int32_tensor().data().data(),
212198
api_tensor.int32_tensor().data().size(),
213-
std::move(shape),
214-
storage);
199+
std::move(shape), storage);
215200
} else if (api_tensor.has_uint32_tensor()) {
216201
return make_sdk_tensor_from_api_tensor_t(api_tensor.uint32_tensor().data().data(),
217202
api_tensor.uint32_tensor().data().size(),
218-
std::move(shape),
219-
storage);
203+
std::move(shape), storage);
220204

221205
} else if (api_tensor.has_int64_tensor()) {
222206
return make_sdk_tensor_from_api_tensor_t(api_tensor.int64_tensor().data().data(),
223207
api_tensor.int64_tensor().data().size(),
224-
std::move(shape),
225-
storage);
208+
std::move(shape), storage);
226209

227210
} else if (api_tensor.has_uint64_tensor()) {
228211
return make_sdk_tensor_from_api_tensor_t(api_tensor.uint64_tensor().data().data(),
229212
api_tensor.uint64_tensor().data().size(),
230-
std::move(shape),
231-
storage);
213+
std::move(shape), storage);
232214

233215
} else if (api_tensor.has_float_tensor()) {
234216
return make_sdk_tensor_from_api_tensor_t(api_tensor.float_tensor().data().data(),
235217
api_tensor.float_tensor().data().size(),
236-
std::move(shape),
237-
storage);
218+
std::move(shape), storage);
238219

239220
} else if (api_tensor.has_double_tensor()) {
240221
return make_sdk_tensor_from_api_tensor_t(api_tensor.double_tensor().data().data(),
241222
api_tensor.double_tensor().data().size(),
242-
std::move(shape),
243-
storage);
223+
std::move(shape), storage);
244224
}
245225
throw Exception(ErrorCondition::k_not_supported, "Unsupported tensor data type");
246226
}

src/viam/sdk/services/private/mlmodel_client.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <viam/sdk/services/private/mlmodel_client.hpp>
16-
1715
#include <grpcpp/channel.h>
1816

1917
#include <viam/sdk/common/client_helper.hpp>
20-
#include <viam/sdk/services/private/mlmodel.hpp>
21-
2218
#include <viam/sdk/common/exception.hpp>
19+
#include <viam/sdk/services/private/mlmodel.hpp>
20+
#include <viam/sdk/services/private/mlmodel_client.hpp>
2321

2422
// As of proto version 27 (full version number 5.27) Arena::CreateMessage is deprecated in favor of
2523
// Arena::Create. We use this macro to accomodate earlier supported versions of proto where

src/viam/sdk/services/private/mlmodel_server.cpp

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <viam/sdk/services/private/mlmodel_server.hpp>
16-
1715
#include <viam/sdk/common/private/service_helper.hpp>
1816
#include <viam/sdk/rpc/server.hpp>
1917
#include <viam/sdk/services/mlmodel.hpp>
2018
#include <viam/sdk/services/private/mlmodel.hpp>
19+
#include <viam/sdk/services/private/mlmodel_server.hpp>
2120

2221
namespace viam {
2322
namespace sdk {
@@ -27,19 +26,51 @@ MLModelServiceServer::MLModelServiceServer(std::shared_ptr<ResourceManager> mana
2726
: ResourceServer(std::move(manager)) {}
2827

2928
::grpc::Status MLModelServiceServer::Infer(
30-
::grpc::ServerContext*,
31-
const ::viam::service::mlmodel::v1::InferRequest* request,
29+
::grpc::ServerContext*, const ::viam::service::mlmodel::v1::InferRequest* request,
3230
::viam::service::mlmodel::v1::InferResponse* response) noexcept {
33-
return make_service_helper<MLModelService>(
34-
"MLModelServiceServer::Infer", this, request)([&](auto& helper, auto& mlms) {
31+
return make_service_helper<MLModelService>("MLModelServiceServer::Infer", this,
32+
request)([&](auto& helper, auto& mlms) {
3533
if (!request->has_input_tensors()) {
3634
return helper.fail(::grpc::INVALID_ARGUMENT, "Called with no input tensors");
3735
}
3836

37+
const auto md = mlms->metadata({});
3938
MLModelService::named_tensor_views inputs;
40-
for (const auto& [tensor_name, api_tensor] : request->input_tensors().tensors()) {
41-
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(api_tensor);
42-
inputs.emplace(std::move(tensor_name), std::move(tensor));
39+
40+
// Check if there's only one input tensor and metadata only expects one, too
41+
if (request->input_tensors().tensors().size() == 1 && md.inputs.size() == 1) {
42+
// Special case: just one tensor, add it without metadata checks
43+
const auto& tensor_pair = *request->input_tensors().tensors().begin();
44+
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(tensor_pair.second);
45+
inputs.emplace(tensor_pair.first, std::move(tensor));
46+
} else {
47+
// Normal case: multiple tensors, do metadata checks
48+
for (const auto& input : md.inputs) {
49+
const auto where = request->input_tensors().tensors().find(input.name);
50+
if (where == request->input_tensors().tensors().end()) {
51+
// Ignore any inputs for which we don't have metadata, since
52+
// we can't validate the type info.
53+
// if the input vector of the expected name is not found, return an error
54+
std::ostringstream message;
55+
message << "Expected tensor input `" << input.name
56+
<< "` was not found; if you believe you have this tensor under a "
57+
"different name, rename it to the expected tensor name";
58+
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
59+
}
60+
auto tensor = mlmodel::make_sdk_tensor_from_api_tensor(where->second);
61+
const auto tensor_type =
62+
MLModelService::tensor_info::tensor_views_to_data_type(tensor);
63+
if (tensor_type != input.data_type) {
64+
std::ostringstream message;
65+
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
66+
message << "Tensor input `" << input.name
67+
<< "` was the wrong type; expected type "
68+
<< static_cast<ut>(input.data_type) << " but got type "
69+
<< static_cast<ut>(tensor_type);
70+
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
71+
}
72+
inputs.emplace(std::move(input.name), std::move(tensor));
73+
}
4374
}
4475

4576
const auto outputs = mlms->infer(inputs, helper.getExtra());
@@ -55,11 +86,10 @@ ::grpc::Status MLModelServiceServer::Infer(
5586
}
5687

5788
::grpc::Status MLModelServiceServer::Metadata(
58-
::grpc::ServerContext*,
59-
const ::viam::service::mlmodel::v1::MetadataRequest* request,
89+
::grpc::ServerContext*, const ::viam::service::mlmodel::v1::MetadataRequest* request,
6090
::viam::service::mlmodel::v1::MetadataResponse* response) noexcept {
61-
return make_service_helper<MLModelService>(
62-
"MLModelServiceServer::Metadata", this, request)([&](auto& helper, auto& mlms) {
91+
return make_service_helper<MLModelService>("MLModelServiceServer::Metadata", this,
92+
request)([&](auto& helper, auto& mlms) {
6393
auto md = mlms->metadata(helper.getExtra());
6494

6595
auto& metadata_pb = *response->mutable_metadata();

0 commit comments

Comments
 (0)