@@ -433,42 +433,32 @@ ModelState::LoadModel(
433433#ifdef TRITON_ENABLE_ONNXRUNTIME_TENSORRT
434434 if (name == kTensorRTExecutionAccelerator ) {
435435 // create tensorrt options with default values
436+ OrtTensorRTProviderOptionsV2* trt_options;
437+ THROW_IF_BACKEND_MODEL_ORT_ERROR (
438+ ort_api->CreateTensorRTProviderOptions (&trt_options));
439+ std::unique_ptr<
440+ OrtTensorRTProviderOptionsV2,
441+ decltype (ort_api->ReleaseTensorRTProviderOptions )>
442+ rel_trt_options (
443+ trt_options, ort_api->ReleaseTensorRTProviderOptions );
436444 std::string int8_calibration_table_name;
437445 std::string trt_engine_cache_path;
438- OrtTensorRTProviderOptions trt_options{
439- instance_group_device_id,
440- stream != nullptr ? 1 : 0 ,
441- stream != nullptr ? (void *)stream : nullptr ,
442- 1000 , // trt_max_partition_iterations
443- 1 , // trt_min_subgraph_size
444- 1 << 30 , // max_workspace_size
445- 0 , // trt_fp16_enable
446- 0 , // trt_int8_enable
447- nullptr , // trt_int8_calibration_table_name
448- 0 , // trt_int8_use_native_calibration_table
449- 0 , // trt_dla_enable
450- 0 , // trt_dla_core
451- 0 , // trt_dump_subgraphs
452- 0 , // trt_engine_cache_enable
453- nullptr , // trt_engine_cache_path
454- 0 , // trt_engine_decryption_enable
455- nullptr , // trt_engine_decryption_lib_path
456- 0 // trt_force_sequential_engine_build
457- };
458446 // Validate and set parameters
459447 triton::common::TritonJson::Value params;
460448 if (ea.Find (" parameters" , &params)) {
461- std::vector<std::string> param_keys;
449+ std::vector<std::string> param_keys, keys, values ;
462450 RETURN_IF_ERROR (params.Members (&param_keys));
463451 for (const auto & param_key : param_keys) {
464- std::string value_string;
452+ std::string value_string, key, value ;
465453 if (param_key == " precision_mode" ) {
466454 RETURN_IF_ERROR (params.MemberAsString (
467455 param_key.c_str (), &value_string));
468456 if (value_string == " FP16" ) {
469- trt_options.trt_fp16_enable = 1 ;
457+ key = " trt_fp16_enable" ;
458+ value = " 1" ;
470459 } else if (value_string == " INT8" ) {
471- trt_options.trt_int8_enable = 1 ;
460+ key = " trt_int8_enable" ;
461+ value = " 1" ;
472462 } else if (value_string != " FP32" ) {
473463 RETURN_ERROR_IF_FALSE (
474464 false , TRITONSERVER_ERROR_INVALID_ARG,
@@ -481,33 +471,32 @@ ModelState::LoadModel(
481471 size_t max_workspace_size_bytes;
482472 RETURN_IF_ERROR (ParseUnsignedLongLongValue (
483473 value_string, &max_workspace_size_bytes));
484- trt_options. trt_max_workspace_size =
485- max_workspace_size_bytes ;
474+ key = " trt_max_workspace_size " ;
475+ value = value_string ;
486476 } else if (param_key == " int8_calibration_table_name" ) {
487- RETURN_IF_ERROR (params.MemberAsString (
488- param_key.c_str (), &int8_calibration_table_name));
489- trt_options.trt_int8_calibration_table_name =
490- int8_calibration_table_name.c_str ();
477+ RETURN_IF_ERROR (
478+ params.MemberAsString (param_key.c_str (), &value));
479+ key = " trt_int8_calibration_table_name" ;
491480 } else if (param_key == " int8_use_native_calibration_table" ) {
492481 RETURN_IF_ERROR (params.MemberAsString (
493482 param_key.c_str (), &value_string));
494- int use_native_calibration_table;
495- RETURN_IF_ERROR (ParseIntValue (
483+ bool use_native_calibration_table;
484+ RETURN_IF_ERROR (ParseBoolValue (
496485 value_string, &use_native_calibration_table));
497- trt_options. trt_int8_use_native_calibration_table =
498- use_native_calibration_table ;
486+ key = " trt_int8_use_native_calibration_table " ;
487+ value = value_string ;
499488 } else if (param_key == " trt_engine_cache_enable" ) {
500489 RETURN_IF_ERROR (params.MemberAsString (
501490 param_key.c_str (), &value_string));
502491 bool enable_cache;
503492 RETURN_IF_ERROR (
504493 ParseBoolValue (value_string, &enable_cache));
505- trt_options.trt_engine_cache_enable = enable_cache;
494+ key = " trt_engine_cache_enable" ;
495+ value = value_string;
506496 } else if (param_key == " trt_engine_cache_path" ) {
507- RETURN_IF_ERROR (params.MemberAsString (
508- param_key.c_str (), &trt_engine_cache_path));
509- trt_options.trt_engine_cache_path =
510- trt_engine_cache_path.c_str ();
497+ RETURN_IF_ERROR (
498+ params.MemberAsString (param_key.c_str (), &value));
499+ key = " trt_engine_cache_path" ;
511500 } else {
512501 return TRITONSERVER_ErrorNew (
513502 TRITONSERVER_ERROR_INVALID_ARG,
@@ -517,12 +506,27 @@ ModelState::LoadModel(
517506 " Accelerator" )
518507 .c_str ());
519508 }
509+ if (!key.empty () && !value.empty ()) {
510+ keys.push_back (key);
511+ values.push_back (value);
512+ }
513+ }
514+ std::vector<const char *> c_keys, c_values;
515+ if (!keys.empty () && !values.empty ()) {
516+ for (size_t i = 0 ; i < keys.size (); ++i) {
517+ c_keys.push_back (keys[i].c_str ());
518+ c_values.push_back (values[i].c_str ());
519+ }
520+ RETURN_IF_ORT_ERROR (ort_api->UpdateTensorRTProviderOptions (
521+ rel_trt_options.get (), c_keys.data (), c_values.data (),
522+ keys.size ()));
520523 }
521524 }
522525
523526 RETURN_IF_ORT_ERROR (
524- ort_api->SessionOptionsAppendExecutionProvider_TensorRT (
525- soptions, &trt_options));
527+ ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2 (
528+ static_cast <OrtSessionOptions*>(soptions),
529+ rel_trt_options.get ()));
526530 LOG_MESSAGE (
527531 TRITONSERVER_LOG_VERBOSE,
528532 (std::string (" TensorRT Execution Accelerator is set for '" ) +
0 commit comments