2626#include < grpcpp/security/credentials.h>
2727#include < tensorflow/lite/c/c_api.h>
2828
29+ #include < viam/sdk/common/proto_value.hpp>
2930#include < viam/sdk/components/component.hpp>
3031#include < viam/sdk/config/resource.hpp>
3132#include < viam/sdk/module/service.hpp>
@@ -290,64 +291,67 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
290291 // Now we can begin parsing and validating the provided `configuration`.
291292 // Pull the model path out of the configuration.
292293 const auto & attributes = state->configuration .attributes ();
293- auto model_path = attributes-> find (" model_path" );
294- if (model_path == attributes-> end ()) {
294+ auto model_path = attributes. find (" model_path" );
295+ if (model_path == attributes. end ()) {
295296 std::ostringstream buffer;
296297 buffer << service_name
297298 << " : Required parameter `model_path` not found in configuration" ;
298299 throw std::invalid_argument (buffer.str ());
299300 }
300- const auto * const model_path_string = model_path->second ->get <std::string>();
301- if (!model_path_string || model_path_string->empty ()) {
301+
302+ const vsdk::ProtoValue& model_path_val = model_path->second ;
303+ if (!model_path_val.is_a <std::string>() ||
304+ model_path_val.get_unchecked <std::string>().empty ()) {
302305 std::ostringstream buffer;
303306 buffer << service_name
304307 << " : Required non-empty string parameter `model_path` is either not a string "
305308 " or is an empty string" ;
306309 throw std::invalid_argument (buffer.str ());
307310 }
311+ const std::string& model_path_string = model_path_val.get_unchecked <std::string>();
308312
309313 // Process any tensor name remappings provided in the config.
310- auto remappings = attributes->find (" tensor_name_remappings" );
311- if (remappings != attributes->end ()) {
312- const auto remappings_attributes = remappings->second ->get <vsdk::ProtoStruct>();
313- if (!remappings_attributes) {
314+ auto remappings = attributes.find (" tensor_name_remappings" );
315+ if (remappings != attributes.end ()) {
316+ if (!remappings->second .is_a <vsdk::ProtoStruct>()) {
314317 std::ostringstream buffer;
315318 buffer << service_name
316319 << " : Optional parameter `tensor_name_remappings` must be a dictionary" ;
317320 throw std::invalid_argument (buffer.str ());
318321 }
322+ const auto remappings_attributes =
323+ remappings->second .get_unchecked <vsdk::ProtoStruct>();
319324
320- const auto populate_remappings = [](const vsdk::ProtoType& source, auto & target) {
321- const auto source_attributes = source.get <vsdk::ProtoStruct>();
322- if (!source_attributes) {
325+ const auto populate_remappings = [](const vsdk::ProtoValue& source, auto & target) {
326+ if (!source.is_a <vsdk::ProtoStruct>()) {
323327 std::ostringstream buffer;
324328 buffer << service_name
325- << " : Fields `inputs` and `outputs` of `tensor_name_remappings` must be "
329+ << " : Fields `inputs` and `outputs` of `tensor_name_remappings` "
330+ " must be "
326331 " dictionaries" ;
327332 throw std::invalid_argument (buffer.str ());
328333 }
329- for (const auto & kv : *source_attributes ) {
334+ for (const auto & kv : source. get_unchecked <vsdk::ProtoStruct>() ) {
330335 const auto & k = kv.first ;
331- const auto * const kv_string = kv.second ->get <std::string>();
332- if (!kv_string) {
336+ if (!kv.second .is_a <std::string>()) {
333337 std::ostringstream buffer;
334- buffer
335- << service_name
336- << " : Fields `inputs` and `outputs` of `tensor_name_remappings` must "
337- " be dictionaries with string values" ;
338+ buffer << service_name
339+ << " : Fields `inputs` and `outputs` of `tensor_name_remappings` "
340+ " must "
341+ " be dictionaries with string values" ;
338342 throw std::invalid_argument (buffer.str ());
339343 }
340- target[kv.first ] = *kv_string ;
344+ target[kv.first ] = kv. second . get_unchecked <std::string>() ;
341345 }
342346 };
343347
344- const auto inputs_where = remappings_attributes-> find (" inputs" );
345- if (inputs_where != remappings_attributes-> end ()) {
346- populate_remappings (* inputs_where->second , state->input_name_remappings );
348+ const auto inputs_where = remappings_attributes. find (" inputs" );
349+ if (inputs_where != remappings_attributes. end ()) {
350+ populate_remappings (inputs_where->second , state->input_name_remappings );
347351 }
348- const auto outputs_where = remappings_attributes-> find (" outputs" );
349- if (outputs_where != remappings_attributes-> end ()) {
350- populate_remappings (* outputs_where->second , state->output_name_remappings );
352+ const auto outputs_where = remappings_attributes. find (" outputs" );
353+ if (outputs_where != remappings_attributes. end ()) {
354+ populate_remappings (outputs_where->second , state->output_name_remappings );
351355 }
352356 }
353357
@@ -362,11 +366,11 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
362366 // buffer which we can use with `TfLiteModelCreate`. That
363367 // still requires that the buffer be kept valid, but that's
364368 // more easily done.
365- const std::ifstream in (* model_path_string, std::ios::in | std::ios::binary);
369+ const std::ifstream in (model_path_string, std::ios::in | std::ios::binary);
366370 if (!in) {
367371 std::ostringstream buffer;
368372 buffer << service_name << " : Failed to open file for `model_path` "
369- << * model_path_string;
373+ << model_path_string;
370374 throw std::invalid_argument (buffer.str ());
371375 }
372376 std::ostringstream model_path_contents_stream;
@@ -399,23 +403,29 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
399403 // If present, extract and validate the number of threads to
400404 // use in the interpreter and create an interpreter options
401405 // object to carry that information.
402- auto num_threads = attributes->find (" num_threads" );
403- if (num_threads != attributes->end ()) {
404- const auto * num_threads_double = num_threads->second ->get <double >();
405- if (!num_threads_double || !std::isnormal (*num_threads_double) ||
406- (*num_threads_double < 0 ) ||
407- (*num_threads_double >= std::numeric_limits<std::int32_t >::max ()) ||
408- (std::trunc (*num_threads_double) != *num_threads_double)) {
406+ auto num_threads = attributes.find (" num_threads" );
407+ if (num_threads != attributes.end ()) {
408+ auto throwError = [&] {
409409 std::ostringstream buffer;
410410 buffer << service_name
411- << " : Value for field `num_threads` is not a positive integer: "
412- << *num_threads_double;
411+ << " : Value for field `num_threads` is not a positive integer" ;
413412 throw std::invalid_argument (buffer.str ());
413+ };
414+
415+ if (!num_threads->second .is_a <double >()) {
416+ throwError ();
417+ }
418+
419+ double num_threads_double = num_threads->second .get_unchecked <double >();
420+ if (!std::isnormal (num_threads_double) || (num_threads_double < 0 ) ||
421+ (num_threads_double >= std::numeric_limits<std::int32_t >::max ()) ||
422+ (std::trunc (num_threads_double) != num_threads_double)) {
423+ throwError ();
414424 }
415425
416426 state->interpreter_options .reset (TfLiteInterpreterOptionsCreate ());
417427 TfLiteInterpreterOptionsSetNumThreads (state->interpreter_options .get (),
418- static_cast <int32_t >(* num_threads_double));
428+ static_cast <int32_t >(num_threads_double));
419429 }
420430
421431 // Build the single interpreter.
0 commit comments