Skip to content

Commit 8880940

Browse files
authored
Merge branch 'main' into RSDK-11703-and-11602
2 parents 101c91f + 3829c8f commit 8880940

File tree

8 files changed

+62
-17
lines changed

8 files changed

+62
-17
lines changed

.github/workflows/ai-updater.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
call-ai-updater:
11-
uses: gabegottlob/viam-ai-updater/.github/workflows/ai-updater.yml@main
11+
uses: viamrobotics/viam-ai-updater/.github/workflows/ai-updater.yml@main
1212
with:
1313
target_branch: workflow/update-protos
1414
sdk: cpp

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ else()
4141
cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
4242
endif()
4343

44-
set(CMAKE_PROJECT_VERSION 0.17.0)
44+
set(CMAKE_PROJECT_VERSION 0.18.0)
4545

4646
# Identify the project.
4747
project(viam-cpp-sdk

src/viam/sdk/resource/resource_api.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,28 @@ const API& RPCSubtype::api() const {
191191
ModelFamily::ModelFamily(std::string namespace_, std::string family)
192192
: namespace_(std::move(namespace_)), family_(std::move(family)) {}
193193

194+
const std::string& ModelFamily::get_namespace() const {
195+
return namespace_;
196+
}
197+
198+
const std::string& ModelFamily::family() const {
199+
return family_;
200+
}
201+
194202
Model::Model(ModelFamily model_family, std::string model_name)
195203
: model_family_(std::move(model_family)), model_name_(std::move(model_name)) {}
196204

197205
Model::Model(std::string namespace_, std::string family, std::string model_name)
198206
: Model(ModelFamily(std::move(namespace_), std::move(family)), std::move(model_name)) {}
199207

208+
const ModelFamily& Model::model_family() const {
209+
return model_family_;
210+
}
211+
212+
const std::string& Model::model_name() const {
213+
return model_name_;
214+
}
215+
200216
Model Model::from_str(std::string model) {
201217
if (std::regex_match(model, MODEL_REGEX)) {
202218
std::vector<std::string> model_parts;

src/viam/sdk/resource/resource_api.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ class RPCSubtype {
113113
class ModelFamily {
114114
public:
115115
ModelFamily(std::string namespace_, std::string family);
116+
117+
const std::string& get_namespace() const;
118+
const std::string& family() const;
119+
116120
std::string to_string() const;
117121

118122
private:
@@ -124,12 +128,15 @@ class ModelFamily {
124128
/// @brief Defines the namespace_, family, and name for a particular resource model.
125129
class Model {
126130
public:
127-
std::string to_string() const;
128-
129131
Model(std::string namespace_, std::string family, std::string model_name);
130132
Model(ModelFamily model, std::string model_name);
131133
Model();
132134

135+
const ModelFamily& model_family() const;
136+
const std::string& model_name() const;
137+
138+
std::string to_string() const;
139+
133140
/// @brief Parses a single model string into a Model, using default values for namespace and
134141
/// family if not provided.
135142
///

src/viam/sdk/services/mlmodel.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@ const char* MLModelService::tensor_info::data_type_to_string(const data_types da
100100
return nullptr;
101101
}
102102

103+
std::ostream& operator<<(std::ostream& os, MLModelService::tensor_info::data_types dt) {
104+
const char* str = MLModelService::tensor_info::data_type_to_string(dt);
105+
106+
if (str) {
107+
os << str;
108+
} else {
109+
// Cast to unsigned because uint8_t is unsigned char, and 0-9 are whitespace or non printing
110+
// characters
111+
os << static_cast<unsigned>(dt);
112+
}
113+
114+
return os;
115+
}
116+
103117
MLModelService::tensor_info::data_types MLModelService::tensor_info::tensor_views_to_data_type(
104118
const tensor_views& view) {
105119
class visitor : public boost::static_visitor<data_types> {

src/viam/sdk/services/mlmodel.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <iosfwd>
18+
1719
#include <boost/mpl/joint_view.hpp>
1820
#include <boost/mpl/list.hpp>
1921
#include <boost/mpl/transform_view.hpp>
@@ -177,5 +179,7 @@ struct API::traits<MLModelService> {
177179
static API api();
178180
};
179181

182+
std::ostream& operator<<(std::ostream&, MLModelService::tensor_info::data_types);
183+
180184
} // namespace sdk
181185
} // namespace viam

src/viam/sdk/services/private/mlmodel_server.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,8 @@ ::grpc::Status MLModelServiceServer::Infer(
4848
const auto tensor_type = MLModelService::tensor_info::tensor_views_to_data_type(tensor);
4949
if (tensor_type != input.data_type) {
5050
std::ostringstream message;
51-
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
5251
message << "Tensor input `" << input.name << "` was the wrong type; expected type "
53-
<< static_cast<ut>(input.data_type) << " but got type "
54-
<< static_cast<ut>(tensor_type);
52+
<< input.data_type << " but got type " << tensor_type;
5553
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
5654
}
5755
inputs.emplace(input.name, std::move(tensor));
@@ -74,11 +72,9 @@ ::grpc::Status MLModelServiceServer::Infer(
7472
MLModelService::tensor_info::tensor_views_to_data_type(tensor);
7573
if (tensor_type != input.data_type) {
7674
std::ostringstream message;
77-
using ut = std::underlying_type<MLModelService::tensor_info::data_types>::type;
7875
message << "Tensor input `" << input.name
79-
<< "` was the wrong type; expected type "
80-
<< static_cast<ut>(input.data_type) << " but got type "
81-
<< static_cast<ut>(tensor_type);
76+
<< "` was the wrong type; expected type " << input.data_type
77+
<< " but got type " << tensor_type;
8278
return helper.fail(::grpc::INVALID_ARGUMENT, message.str().c_str());
8379
}
8480
inputs.emplace(std::move(input.name), std::move(tensor));
@@ -122,12 +118,8 @@ ::grpc::Status MLModelServiceServer::Metadata(
122118
MLModelService::tensor_info::data_type_to_string(s.data_type);
123119
if (!string_for_data_type) {
124120
std::ostringstream message;
125-
message
126-
<< "Served MLModelService returned an unknown data type with value `"
127-
<< static_cast<
128-
std::underlying_type<MLModelService::tensor_info::data_types>::type>(
129-
s.data_type)
130-
<< "` in its metadata";
121+
message << "Served MLModelService returned an unknown data type with value `"
122+
<< s.data_type << "` in its metadata";
131123
return helper.fail(grpc::INTERNAL, message.str().c_str());
132124
}
133125
new_entry.set_data_type(string_for_data_type);

src/viam/sdk/tests/test_resource.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,32 @@ BOOST_AUTO_TEST_CASE(test_name) {
7777
BOOST_AUTO_TEST_CASE(test_model) {
7878
ModelFamily mf("ns", "mf");
7979
BOOST_CHECK_EQUAL(mf.to_string(), "ns:mf");
80+
BOOST_CHECK_EQUAL(mf.get_namespace(), "ns");
81+
BOOST_CHECK_EQUAL(mf.family(), "mf");
8082

8183
Model model1(mf, "model1");
8284
BOOST_CHECK_EQUAL(model1.to_string(), "ns:mf:model1");
85+
BOOST_CHECK_EQUAL(model1.model_family().to_string(), "ns:mf");
86+
BOOST_CHECK_EQUAL(model1.model_name(), "model1");
8387
Model model2("ns", "mf", "model2");
8488
BOOST_CHECK_EQUAL(model2.to_string(), "ns:mf:model2");
89+
BOOST_CHECK_EQUAL(model2.model_family().to_string(), "ns:mf");
90+
BOOST_CHECK_EQUAL(model2.model_name(), "model2");
8591

8692
Model model3 = Model::from_str("ns:mf:model3");
8793
BOOST_CHECK_EQUAL(model3.to_string(), "ns:mf:model3");
94+
BOOST_CHECK_EQUAL(model3.model_family().to_string(), "ns:mf");
95+
BOOST_CHECK_EQUAL(model3.model_name(), "model3");
8896
Model model4 = Model::from_str("model4");
8997
BOOST_CHECK_EQUAL(model4.to_string(), "rdk:builtin:model4");
98+
BOOST_CHECK_EQUAL(model4.model_family().to_string(), "rdk:builtin");
99+
BOOST_CHECK_EQUAL(model4.model_name(), "model4");
90100

91101
ModelFamily empty("", "");
92102
Model model5(empty, "model5");
93103
BOOST_CHECK_EQUAL(model5.to_string(), "model5");
104+
BOOST_CHECK_EQUAL(model5.model_family().to_string(), "");
105+
BOOST_CHECK_EQUAL(model5.model_name(), "model5");
94106

95107
BOOST_CHECK_THROW(Model::from_str("@"), Exception);
96108
}

0 commit comments

Comments
 (0)