diff --git a/src/viam/sdk/services/mlmodel.cpp b/src/viam/sdk/services/mlmodel.cpp index 28bc9c51d..8a749b6d5 100644 --- a/src/viam/sdk/services/mlmodel.cpp +++ b/src/viam/sdk/services/mlmodel.cpp @@ -100,6 +100,20 @@ const char* MLModelService::tensor_info::data_type_to_string(const data_types da return nullptr; } +std::ostream& operator<<(std::ostream& os, MLModelService::tensor_info::data_types dt) { + const char* str = MLModelService::tensor_info::data_type_to_string(dt); + + if (str) { + os << str; + } else { + // Cast to unsigned because uint8_t is unsigned char, and 0-9 are whitespace or non printing + // characters + os << static_cast(dt); + } + + return os; +} + MLModelService::tensor_info::data_types MLModelService::tensor_info::tensor_views_to_data_type( const tensor_views& view) { class visitor : public boost::static_visitor { diff --git a/src/viam/sdk/services/mlmodel.hpp b/src/viam/sdk/services/mlmodel.hpp index fec2ad0d3..7797e86bd 100644 --- a/src/viam/sdk/services/mlmodel.hpp +++ b/src/viam/sdk/services/mlmodel.hpp @@ -14,6 +14,8 @@ #pragma once +#include + #include #include #include @@ -177,5 +179,7 @@ struct API::traits { static API api(); }; +std::ostream& operator<<(std::ostream&, MLModelService::tensor_info::data_types); + } // namespace sdk } // namespace viam diff --git a/src/viam/sdk/services/private/mlmodel_server.cpp b/src/viam/sdk/services/private/mlmodel_server.cpp index ddd831b9e..680446568 100644 --- a/src/viam/sdk/services/private/mlmodel_server.cpp +++ b/src/viam/sdk/services/private/mlmodel_server.cpp @@ -48,10 +48,8 @@ ::grpc::Status MLModelServiceServer::Infer( const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor); if (tensor_type != input.data_type) { std::ostringstream message; - using ut = std::underlying_type::type; message << "Tensor input `" << input.name << "` was the wrong type; expected type " - << static_cast(input.data_type) << " but got type " - << static_cast(tensor_type); + << input.data_type << " but got type " << tensor_type; return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str()); } inputs.emplace(input.name, std::move(tensor)); @@ -74,11 +72,9 @@ ::grpc::Status MLModelServiceServer::Infer( MLModelService::tensor_info::tensor_views_to_data_type(tensor); if (tensor_type != input.data_type) { std::ostringstream message; - using ut = std::underlying_type::type; message << "Tensor input `" << input.name - << "` was the wrong type; expected type " - << static_cast(input.data_type) << " but got type " - << static_cast(tensor_type); + << "` was the wrong type; expected type " << input.data_type + << " but got type " << tensor_type; return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str()); } inputs.emplace(std::move(input.name), std::move(tensor)); @@ -122,12 +118,8 @@ ::grpc::Status MLModelServiceServer::Metadata( MLModelService::tensor_info::data_type_to_string(s.data_type); if (!string_for_data_type) { std::ostringstream message; - message - << "Served MLModelService returned an unknown data type with value `" - << static_cast< - std::underlying_type::type>( - s.data_type) - << "` in its metadata"; + message << "Served MLModelService returned an unknown data type with value `" + << s.data_type << "` in its metadata"; return helper.fail(grpc::INTERNAL, message.str().c_str()); } new_entry.set_data_type(string_for_data_type);