Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/viam/sdk/services/mlmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ const char* MLModelService::tensor_info::data_type_to_string(const data_types da
return nullptr;
}

std::ostream& operator<<(std::ostream& os, MLModelService::tensor_info::data_types dt) {
const char* str = MLModelService::tensor_info::data_type_to_string(dt);

if (str) {
os << str;
} else {
// Cast to unsigned because uint8_t is unsigned char, and 0-9 are whitespace or non printing
// characters
os << static_cast<unsigned>(static_cast<std::underlying_type_t<decltype(dt)>>(dt));
}

return os;
}

MLModelService::tensor_info::data_types MLModelService::tensor_info::tensor_views_to_data_type(
const tensor_views& view) {
class visitor : public boost::static_visitor<data_types> {
Expand Down
4 changes: 4 additions & 0 deletions src/viam/sdk/services/mlmodel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <iosfwd>

#include <boost/mpl/joint_view.hpp>
#include <boost/mpl/list.hpp>
#include <boost/mpl/transform_view.hpp>
Expand Down Expand Up @@ -177,5 +179,7 @@ struct API::traits<MLModelService> {
static API api();
};

std::ostream& operator<<(std::ostream&, MLModelService::tensor_info::data_types);

} // namespace sdk
} // namespace viam
18 changes: 5 additions & 13 deletions src/viam/sdk/services/private/mlmodel_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ ::grpc::Status MLModelServiceServer::Infer(
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
if (tensor_type != input.data_type) {
std::ostringstream message;
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
message << "Tensor input `" << input.name << "` was the wrong type; expected type "
<< static_cast<ut>(input.data_type) << " but got type "
<< static_cast<ut>(tensor_type);
<< input.data_type << " but got type " << tensor_type;
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
}
inputs.emplace(input.name, std::move(tensor));
Expand All @@ -74,11 +72,9 @@ ::grpc::Status MLModelServiceServer::Infer(
MLModelService::tensor_info::tensor_views_to_data_type(tensor);
if (tensor_type != input.data_type) {
std::ostringstream message;
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
message << "Tensor input `" << input.name
<< "` was the wrong type; expected type "
<< static_cast<ut>(input.data_type) << " but got type "
<< static_cast<ut>(tensor_type);
<< "` was the wrong type; expected type " << input.data_type
<< " but got type " << tensor_type;
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
}
inputs.emplace(std::move(input.name), std::move(tensor));
Expand Down Expand Up @@ -122,12 +118,8 @@ ::grpc::Status MLModelServiceServer::Metadata(
MLModelService::tensor_info::data_type_to_string(s.data_type);
if (!string_for_data_type) {
std::ostringstream message;
message
<< "Served MLModelService returned an unknown data type with value `"
<< static_cast<
std::underlying_type<MLModelService::tensor_info::data_types>::type>(
s.data_type)
<< "` in its metadata";
message << "Served MLModelService returned an unknown data type with value `"
<< s.data_type << "` in its metadata";
return helper.fail(grpc::INTERNAL, message.str().c_str());
}
new_entry.set_data_type(string_for_data_type);
Expand Down
Loading