Skip to content

Commit 328bb29

Browse files
committed
fix examples and tests
1 parent 91e8070 commit 328bb29

File tree

3 files changed

+51
-38
lines changed

3 files changed

+51
-38
lines changed

src/viam/examples/mlmodel/example_audio_classification_client.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <boost/format.hpp>
2828
#include <boost/optional.hpp>
2929
#include <boost/program_options.hpp>
30+
#include <boost/variant/get.hpp>
3031

3132
#include <viam/sdk/robot/client.hpp>
3233
#include <viam/sdk/services/mlmodel.hpp>

src/viam/examples/modules/tflite/main.cpp

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <grpcpp/security/credentials.h>
2727
#include <tensorflow/lite/c/c_api.h>
2828

29+
#include <viam/sdk/common/proto_value.hpp>
2930
#include <viam/sdk/components/component.hpp>
3031
#include <viam/sdk/config/resource.hpp>
3132
#include <viam/sdk/module/service.hpp>
@@ -290,64 +291,67 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
290291
// Now we can begin parsing and validating the provided `configuration`.
291292
// Pull the model path out of the configuration.
292293
const auto& attributes = state->configuration.attributes();
293-
auto model_path = attributes->find("model_path");
294-
if (model_path == attributes->end()) {
294+
auto model_path = attributes.find("model_path");
295+
if (model_path == attributes.end()) {
295296
std::ostringstream buffer;
296297
buffer << service_name
297298
<< ": Required parameter `model_path` not found in configuration";
298299
throw std::invalid_argument(buffer.str());
299300
}
300-
const auto* const model_path_string = model_path->second->get<std::string>();
301-
if (!model_path_string || model_path_string->empty()) {
301+
302+
const vsdk::ProtoValue& model_path_val = model_path->second;
303+
if (!model_path_val.is_a<std::string>() ||
304+
model_path_val.get_unchecked<std::string>().empty()) {
302305
std::ostringstream buffer;
303306
buffer << service_name
304307
<< ": Required non-empty string parameter `model_path` is either not a string "
305308
"or is an empty string";
306309
throw std::invalid_argument(buffer.str());
307310
}
311+
const std::string& model_path_string = model_path_val.get_unchecked<std::string>();
308312

309313
// Process any tensor name remappings provided in the config.
310-
auto remappings = attributes->find("tensor_name_remappings");
311-
if (remappings != attributes->end()) {
312-
const auto remappings_attributes = remappings->second->get<vsdk::ProtoStruct>();
313-
if (!remappings_attributes) {
314+
auto remappings = attributes.find("tensor_name_remappings");
315+
if (remappings != attributes.end()) {
316+
if (!remappings->second.is_a<vsdk::ProtoStruct>()) {
314317
std::ostringstream buffer;
315318
buffer << service_name
316319
<< ": Optional parameter `tensor_name_remappings` must be a dictionary";
317320
throw std::invalid_argument(buffer.str());
318321
}
322+
const auto remappings_attributes =
323+
remappings->second.get_unchecked<vsdk::ProtoStruct>();
319324

320-
const auto populate_remappings = [](const vsdk::ProtoType& source, auto& target) {
321-
const auto source_attributes = source.get<vsdk::ProtoStruct>();
322-
if (!source_attributes) {
325+
const auto populate_remappings = [](const vsdk::ProtoValue& source, auto& target) {
326+
if (!source.is_a<vsdk::ProtoStruct>()) {
323327
std::ostringstream buffer;
324328
buffer << service_name
325-
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` must be "
329+
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` "
330+
"must be "
326331
"dictionaries";
327332
throw std::invalid_argument(buffer.str());
328333
}
329-
for (const auto& kv : *source_attributes) {
334+
for (const auto& kv : source.get_unchecked<vsdk::ProtoStruct>()) {
330335
const auto& k = kv.first;
331-
const auto* const kv_string = kv.second->get<std::string>();
332-
if (!kv_string) {
336+
if (!kv.second.is_a<std::string>()) {
333337
std::ostringstream buffer;
334-
buffer
335-
<< service_name
336-
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` must "
337-
"be dictionaries with string values";
338+
buffer << service_name
339+
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` "
340+
"must "
341+
"be dictionaries with string values";
338342
throw std::invalid_argument(buffer.str());
339343
}
340-
target[kv.first] = *kv_string;
344+
target[kv.first] = kv.second.get_unchecked<std::string>();
341345
}
342346
};
343347

344-
const auto inputs_where = remappings_attributes->find("inputs");
345-
if (inputs_where != remappings_attributes->end()) {
346-
populate_remappings(*inputs_where->second, state->input_name_remappings);
348+
const auto inputs_where = remappings_attributes.find("inputs");
349+
if (inputs_where != remappings_attributes.end()) {
350+
populate_remappings(inputs_where->second, state->input_name_remappings);
347351
}
348-
const auto outputs_where = remappings_attributes->find("outputs");
349-
if (outputs_where != remappings_attributes->end()) {
350-
populate_remappings(*outputs_where->second, state->output_name_remappings);
352+
const auto outputs_where = remappings_attributes.find("outputs");
353+
if (outputs_where != remappings_attributes.end()) {
354+
populate_remappings(outputs_where->second, state->output_name_remappings);
351355
}
352356
}
353357

@@ -362,11 +366,11 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
362366
// buffer which we can use with `TfLiteModelCreate`. That
363367
// still requires that the buffer be kept valid, but that's
364368
// more easily done.
365-
const std::ifstream in(*model_path_string, std::ios::in | std::ios::binary);
369+
const std::ifstream in(model_path_string, std::ios::in | std::ios::binary);
366370
if (!in) {
367371
std::ostringstream buffer;
368372
buffer << service_name << ": Failed to open file for `model_path` "
369-
<< *model_path_string;
373+
<< model_path_string;
370374
throw std::invalid_argument(buffer.str());
371375
}
372376
std::ostringstream model_path_contents_stream;
@@ -399,23 +403,29 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
399403
// If present, extract and validate the number of threads to
400404
// use in the interpreter and create an interpreter options
401405
// object to carry that information.
402-
auto num_threads = attributes->find("num_threads");
403-
if (num_threads != attributes->end()) {
404-
const auto* num_threads_double = num_threads->second->get<double>();
405-
if (!num_threads_double || !std::isnormal(*num_threads_double) ||
406-
(*num_threads_double < 0) ||
407-
(*num_threads_double >= std::numeric_limits<std::int32_t>::max()) ||
408-
(std::trunc(*num_threads_double) != *num_threads_double)) {
406+
auto num_threads = attributes.find("num_threads");
407+
if (num_threads != attributes.end()) {
408+
auto throwError = [&] {
409409
std::ostringstream buffer;
410410
buffer << service_name
411-
<< ": Value for field `num_threads` is not a positive integer: "
412-
<< *num_threads_double;
411+
<< ": Value for field `num_threads` is not a positive integer";
413412
throw std::invalid_argument(buffer.str());
413+
};
414+
415+
if (!num_threads->second.is_a<double>()) {
416+
throwError();
417+
}
418+
419+
double num_threads_double = num_threads->second.get_unchecked<double>();
420+
if (!std::isnormal(num_threads_double) || (num_threads_double < 0) ||
421+
(num_threads_double >= std::numeric_limits<std::int32_t>::max()) ||
422+
(std::trunc(num_threads_double) != num_threads_double)) {
423+
throwError();
414424
}
415425

416426
state->interpreter_options.reset(TfLiteInterpreterOptionsCreate());
417427
TfLiteInterpreterOptionsSetNumThreads(state->interpreter_options.get(),
418-
static_cast<int32_t>(*num_threads_double));
428+
static_cast<int32_t>(num_threads_double));
419429
}
420430

421431
// Build the single interpreter.

src/viam/sdk/tests/test_mlmodel.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <tuple>
1919
#include <unordered_map>
2020

21+
#include <boost/variant/get.hpp>
22+
2123
#include <viam/sdk/tests/mocks/mlmodel_mocks.hpp>
2224
#include <viam/sdk/tests/test_utils.hpp>
2325

0 commit comments

Comments
 (0)