Skip to content

Commit d22afbe

Browse files
committed
Add support for ArenaCfg configuration options
1 parent 2be37f7 commit d22afbe

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

src/onnxruntime.cc

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,30 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
302302
}
303303
}
304304

305+
// Set use_device_allocator_for_initializers
306+
{
307+
triton::common::TritonJson::Value params;
308+
if (ModelConfig().Find("parameters", &params)) {
309+
triton::common::TritonJson::Value json_value;
310+
const char* use_device_allocator_for_initializers_key =
311+
"session.use_device_allocator_for_initializers";
312+
if (params.Find(use_device_allocator_for_initializers_key, &json_value)) {
313+
std::string string_value;
314+
THROW_IF_BACKEND_MODEL_ERROR(
315+
json_value.MemberAsString("string_value", &string_value));
316+
317+
LOG_MESSAGE(
318+
TRITONSERVER_LOG_VERBOSE,
319+
(std::string("Configuring ") +
320+
use_device_allocator_for_initializers_key + " to " + string_value)
321+
.c_str());
322+
THROW_IF_BACKEND_MODEL_ORT_ERROR(ort_api->AddSessionConfigEntry(
323+
soptions, use_device_allocator_for_initializers_key,
324+
string_value.c_str()));
325+
}
326+
}
327+
}
328+
305329
// memory configs
306330
// enable/disable mem arena
307331
{
@@ -762,8 +786,90 @@ ModelState::LoadModel(
762786
rel_cuda_options(cuda_options, ort_api->ReleaseCUDAProviderOptions);
763787
cuda_options_map["device_id"] = std::to_string(instance_group_device_id);
764788
cuda_options_map["has_user_compute_stream"] = stream != nullptr ? "1" : "0";
789+
790+
// Memory arena config
791+
OrtArenaCfg* arena_cfg = nullptr;
792+
{
793+
triton::common::TritonJson::Value params;
794+
if (model_config_.Find("parameters", &params)) {
795+
triton::common::TritonJson::Value json_value;
796+
std::vector<const char*> keys;
797+
std::vector<size_t> values;
798+
if (params.Find("max_mem", &json_value)) {
799+
std::string string_value;
800+
THROW_IF_BACKEND_MODEL_ERROR(
801+
json_value.MemberAsString("string_value", &string_value));
802+
keys.push_back("max_mem");
803+
size_t value;
804+
RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
805+
values.push_back(value);
806+
}
807+
if (params.Find("arena_extend_strategy", &json_value)) {
808+
std::string string_value;
809+
THROW_IF_BACKEND_MODEL_ERROR(
810+
json_value.MemberAsString("string_value", &string_value));
811+
keys.push_back("arena_extend_strategy");
812+
size_t value;
813+
RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
814+
values.push_back(value);
815+
}
816+
if (params.Find("initial_chunk_size_bytes", &json_value)) {
817+
std::string string_value;
818+
THROW_IF_BACKEND_MODEL_ERROR(
819+
json_value.MemberAsString("string_value", &string_value));
820+
keys.push_back("initial_chunk_size_bytes");
821+
size_t value;
822+
RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
823+
values.push_back(value);
824+
}
825+
if (params.Find("initial_growth_chunk_size_bytes", &json_value)) {
826+
std::string string_value;
827+
THROW_IF_BACKEND_MODEL_ERROR(
828+
json_value.MemberAsString("string_value", &string_value));
829+
keys.push_back("initial_growth_chunk_size_bytes");
830+
size_t value;
831+
RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
832+
values.push_back(value);
833+
}
834+
if (params.Find("max_dead_bytes_per_chunk", &json_value)) {
835+
std::string string_value;
836+
THROW_IF_BACKEND_MODEL_ERROR(
837+
json_value.MemberAsString("string_value", &string_value));
838+
keys.push_back("max_dead_bytes_per_chunk");
839+
size_t value;
840+
RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
841+
values.push_back(value);
842+
}
843+
if (params.Find("max_power_of_two_extend_bytes", &json_value)) {
844+
std::string string_value;
845+
THROW_IF_BACKEND_MODEL_ERROR(
846+
json_value.MemberAsString("string_value", &string_value));
847+
keys.push_back("max_power_of_two_extend_bytes");
848+
size_t value;
849+
RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
850+
values.push_back(value);
851+
}
852+
if (!keys.empty()) {
853+
RETURN_IF_ORT_ERROR(ort_api->CreateArenaCfgV2(
854+
keys.data(), values.data(), keys.size(), &arena_cfg));
855+
856+
std::ostringstream oss;
857+
for (size_t i = 0; i < keys.size(); ++i) {
858+
oss << keys[i] << "=" << values[i] << ", ";
859+
}
860+
LOG_MESSAGE(
861+
TRITONSERVER_LOG_VERBOSE,
862+
(std::string("Updated arena config options: ") + oss.str())
863+
.c_str());
864+
}
865+
}
866+
}
867+
std::unique_ptr<OrtArenaCfg, decltype(ort_api->ReleaseArenaCfg)>
868+
rel_arena_cfg(arena_cfg, ort_api->ReleaseArenaCfg);
765869
RETURN_IF_ORT_ERROR(ort_api->UpdateCUDAProviderOptionsWithValue(
766-
rel_cuda_options.get(), "default_memory_arena_cfg", nullptr));
870+
rel_cuda_options.get(), "default_memory_arena_cfg",
871+
rel_arena_cfg.get()));
872+
767873
{
768874
// Parse CUDA EP configurations directly from the parameters field.
769875
// This is deprecated with adding support for CUDA EP in the

0 commit comments

Comments
 (0)