Skip to content

Commit 18cdc4a

Browse files
authored
[TensorRT EP] Update trt option to v2 (#233)
* update
* revert
* format
* format
* fix
* fix
* lint
* replace
* fix
* fix
* lint
* check when option is empty
1 parent 6817662 commit 18cdc4a

File tree

1 file changed

+45
-41
lines changed

1 file changed

+45
-41
lines changed

src/onnxruntime.cc

Lines changed: 45 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -433,42 +433,32 @@ ModelState::LoadModel(
433433
#ifdef TRITON_ENABLE_ONNXRUNTIME_TENSORRT
434434
if (name == kTensorRTExecutionAccelerator) {
435435
// create tensorrt options with default values
436+
OrtTensorRTProviderOptionsV2* trt_options;
437+
THROW_IF_BACKEND_MODEL_ORT_ERROR(
438+
ort_api->CreateTensorRTProviderOptions(&trt_options));
439+
std::unique_ptr<
440+
OrtTensorRTProviderOptionsV2,
441+
decltype(ort_api->ReleaseTensorRTProviderOptions)>
442+
rel_trt_options(
443+
trt_options, ort_api->ReleaseTensorRTProviderOptions);
436444
std::string int8_calibration_table_name;
437445
std::string trt_engine_cache_path;
438-
OrtTensorRTProviderOptions trt_options{
439-
instance_group_device_id,
440-
stream != nullptr ? 1 : 0,
441-
stream != nullptr ? (void*)stream : nullptr,
442-
1000, // trt_max_partition_iterations
443-
1, // trt_min_subgraph_size
444-
1 << 30, // max_workspace_size
445-
0, // trt_fp16_enable
446-
0, // trt_int8_enable
447-
nullptr, // trt_int8_calibration_table_name
448-
0, // trt_int8_use_native_calibration_table
449-
0, // trt_dla_enable
450-
0, // trt_dla_core
451-
0, // trt_dump_subgraphs
452-
0, // trt_engine_cache_enable
453-
nullptr, // trt_engine_cache_path
454-
0, // trt_engine_decryption_enable
455-
nullptr, // trt_engine_decryption_lib_path
456-
0 // trt_force_sequential_engine_build
457-
};
458446
// Validate and set parameters
459447
triton::common::TritonJson::Value params;
460448
if (ea.Find("parameters", &params)) {
461-
std::vector<std::string> param_keys;
449+
std::vector<std::string> param_keys, keys, values;
462450
RETURN_IF_ERROR(params.Members(&param_keys));
463451
for (const auto& param_key : param_keys) {
464-
std::string value_string;
452+
std::string value_string, key, value;
465453
if (param_key == "precision_mode") {
466454
RETURN_IF_ERROR(params.MemberAsString(
467455
param_key.c_str(), &value_string));
468456
if (value_string == "FP16") {
469-
trt_options.trt_fp16_enable = 1;
457+
key = "trt_fp16_enable";
458+
value = "1";
470459
} else if (value_string == "INT8") {
471-
trt_options.trt_int8_enable = 1;
460+
key = "trt_int8_enable";
461+
value = "1";
472462
} else if (value_string != "FP32") {
473463
RETURN_ERROR_IF_FALSE(
474464
false, TRITONSERVER_ERROR_INVALID_ARG,
@@ -481,33 +471,32 @@ ModelState::LoadModel(
481471
size_t max_workspace_size_bytes;
482472
RETURN_IF_ERROR(ParseUnsignedLongLongValue(
483473
value_string, &max_workspace_size_bytes));
484-
trt_options.trt_max_workspace_size =
485-
max_workspace_size_bytes;
474+
key = "trt_max_workspace_size";
475+
value = value_string;
486476
} else if (param_key == "int8_calibration_table_name") {
487-
RETURN_IF_ERROR(params.MemberAsString(
488-
param_key.c_str(), &int8_calibration_table_name));
489-
trt_options.trt_int8_calibration_table_name =
490-
int8_calibration_table_name.c_str();
477+
RETURN_IF_ERROR(
478+
params.MemberAsString(param_key.c_str(), &value));
479+
key = "trt_int8_calibration_table_name";
491480
} else if (param_key == "int8_use_native_calibration_table") {
492481
RETURN_IF_ERROR(params.MemberAsString(
493482
param_key.c_str(), &value_string));
494-
int use_native_calibration_table;
495-
RETURN_IF_ERROR(ParseIntValue(
483+
bool use_native_calibration_table;
484+
RETURN_IF_ERROR(ParseBoolValue(
496485
value_string, &use_native_calibration_table));
497-
trt_options.trt_int8_use_native_calibration_table =
498-
use_native_calibration_table;
486+
key = "trt_int8_use_native_calibration_table";
487+
value = value_string;
499488
} else if (param_key == "trt_engine_cache_enable") {
500489
RETURN_IF_ERROR(params.MemberAsString(
501490
param_key.c_str(), &value_string));
502491
bool enable_cache;
503492
RETURN_IF_ERROR(
504493
ParseBoolValue(value_string, &enable_cache));
505-
trt_options.trt_engine_cache_enable = enable_cache;
494+
key = "trt_engine_cache_enable";
495+
value = value_string;
506496
} else if (param_key == "trt_engine_cache_path") {
507-
RETURN_IF_ERROR(params.MemberAsString(
508-
param_key.c_str(), &trt_engine_cache_path));
509-
trt_options.trt_engine_cache_path =
510-
trt_engine_cache_path.c_str();
497+
RETURN_IF_ERROR(
498+
params.MemberAsString(param_key.c_str(), &value));
499+
key = "trt_engine_cache_path";
511500
} else {
512501
return TRITONSERVER_ErrorNew(
513502
TRITONSERVER_ERROR_INVALID_ARG,
@@ -517,12 +506,27 @@ ModelState::LoadModel(
517506
"Accelerator")
518507
.c_str());
519508
}
509+
if (!key.empty() && !value.empty()) {
510+
keys.push_back(key);
511+
values.push_back(value);
512+
}
513+
}
514+
std::vector<const char*> c_keys, c_values;
515+
if (!keys.empty() && !values.empty()) {
516+
for (size_t i = 0; i < keys.size(); ++i) {
517+
c_keys.push_back(keys[i].c_str());
518+
c_values.push_back(values[i].c_str());
519+
}
520+
RETURN_IF_ORT_ERROR(ort_api->UpdateTensorRTProviderOptions(
521+
rel_trt_options.get(), c_keys.data(), c_values.data(),
522+
keys.size()));
520523
}
521524
}
522525

523526
RETURN_IF_ORT_ERROR(
524-
ort_api->SessionOptionsAppendExecutionProvider_TensorRT(
525-
soptions, &trt_options));
527+
ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2(
528+
static_cast<OrtSessionOptions*>(soptions),
529+
rel_trt_options.get()));
526530
LOG_MESSAGE(
527531
TRITONSERVER_LOG_VERBOSE,
528532
(std::string("TensorRT Execution Accelerator is set for '") +

0 commit comments

Comments
 (0)