@@ -302,6 +302,30 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     }
   }
 
+  // Set use_device_allocator_for_initializers
+  {
+    triton::common::TritonJson::Value params;
+    if (ModelConfig().Find("parameters", &params)) {
+      triton::common::TritonJson::Value json_value;
+      const char* use_device_allocator_for_initializers_key =
+          "session.use_device_allocator_for_initializers";
+      if (params.Find(use_device_allocator_for_initializers_key, &json_value)) {
+        std::string string_value;
+        THROW_IF_BACKEND_MODEL_ERROR(
+            json_value.MemberAsString("string_value", &string_value));
+
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_VERBOSE,
+            (std::string("Configuring ") +
+             use_device_allocator_for_initializers_key + " to " + string_value)
+                .c_str());
+        THROW_IF_BACKEND_MODEL_ORT_ERROR(ort_api->AddSessionConfigEntry(
+            soptions, use_device_allocator_for_initializers_key,
+            string_value.c_str()));
+      }
+    }
+  }
+
   // memory configs
   // enable/disable mem arena
   {
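
Note (reviewer sketch, not part of the diff): the key forwarded above is an ordinary ONNX Runtime session config entry, read from the model's `parameters` map in config.pbtxt (key `session.use_device_allocator_for_initializers`, `string_value` "1" to enable). A minimal standalone sketch of the equivalent call against the public ORT C API, assuming a caller-owned `OrtSessionOptions` and a hypothetical helper name:

```cpp
// Illustration only: forwards the same config key the backend now reads from
// the Triton model config. Follows the ORT C API convention of returning an
// OrtStatus* (nullptr on success).
#include <onnxruntime_c_api.h>

OrtStatus*
EnableDeviceAllocatorForInitializers(
    const OrtApi* ort_api, OrtSessionOptions* soptions, const char* value)
{
  // "1" enables the device allocator for initializers; "0" keeps the default.
  return ort_api->AddSessionConfigEntry(
      soptions, "session.use_device_allocator_for_initializers", value);
}
```
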
@@ -762,8 +786,90 @@ ModelState::LoadModel(
         rel_cuda_options(cuda_options, ort_api->ReleaseCUDAProviderOptions);
     cuda_options_map["device_id"] = std::to_string(instance_group_device_id);
     cuda_options_map["has_user_compute_stream"] = stream != nullptr ? "1" : "0";
+
+    // Memory arena config
+    OrtArenaCfg* arena_cfg = nullptr;
+    {
+      triton::common::TritonJson::Value params;
+      if (model_config_.Find("parameters", &params)) {
+        triton::common::TritonJson::Value json_value;
+        std::vector<const char*> keys;
+        std::vector<size_t> values;
+        if (params.Find("max_mem", &json_value)) {
+          std::string string_value;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              json_value.MemberAsString("string_value", &string_value));
+          keys.push_back("max_mem");
+          size_t value;
+          RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
+          values.push_back(value);
+        }
+        if (params.Find("arena_extend_strategy", &json_value)) {
+          std::string string_value;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              json_value.MemberAsString("string_value", &string_value));
+          keys.push_back("arena_extend_strategy");
+          size_t value;
+          RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
+          values.push_back(value);
+        }
+        if (params.Find("initial_chunk_size_bytes", &json_value)) {
+          std::string string_value;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              json_value.MemberAsString("string_value", &string_value));
+          keys.push_back("initial_chunk_size_bytes");
+          size_t value;
+          RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
+          values.push_back(value);
+        }
+        if (params.Find("initial_growth_chunk_size_bytes", &json_value)) {
+          std::string string_value;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              json_value.MemberAsString("string_value", &string_value));
+          keys.push_back("initial_growth_chunk_size_bytes");
+          size_t value;
+          RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
+          values.push_back(value);
+        }
+        if (params.Find("max_dead_bytes_per_chunk", &json_value)) {
+          std::string string_value;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              json_value.MemberAsString("string_value", &string_value));
+          keys.push_back("max_dead_bytes_per_chunk");
+          size_t value;
+          RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
+          values.push_back(value);
+        }
+        if (params.Find("max_power_of_two_extend_bytes", &json_value)) {
+          std::string string_value;
+          THROW_IF_BACKEND_MODEL_ERROR(
+              json_value.MemberAsString("string_value", &string_value));
+          keys.push_back("max_power_of_two_extend_bytes");
+          size_t value;
+          RETURN_IF_ERROR(ParseUnsignedLongLongValue(string_value, &value));
+          values.push_back(value);
+        }
+        if (!keys.empty()) {
+          RETURN_IF_ORT_ERROR(ort_api->CreateArenaCfgV2(
+              keys.data(), values.data(), keys.size(), &arena_cfg));
+
+          std::ostringstream oss;
+          for (size_t i = 0; i < keys.size(); ++i) {
+            oss << keys[i] << "=" << values[i] << ", ";
+          }
+          LOG_MESSAGE(
+              TRITONSERVER_LOG_VERBOSE,
+              (std::string("Updated arena config options: ") + oss.str())
+                  .c_str());
+        }
+      }
+    }
+    std::unique_ptr<OrtArenaCfg, decltype(ort_api->ReleaseArenaCfg)>
+        rel_arena_cfg(arena_cfg, ort_api->ReleaseArenaCfg);
     RETURN_IF_ORT_ERROR(ort_api->UpdateCUDAProviderOptionsWithValue(
-        rel_cuda_options.get(), "default_memory_arena_cfg", nullptr));
+        rel_cuda_options.get(), "default_memory_arena_cfg",
+        rel_arena_cfg.get()));
+
     {
       // Parse CUDA EP configurations directly from the parameters field.
       // This is deprecated with adding support for CUDA EP in the
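
Note (reviewer sketch, not part of the diff): the arena parameters collected above become an `OrtArenaCfg` via `CreateArenaCfgV2`, which is then attached to the CUDA execution provider as its `default_memory_arena_cfg`. A condensed standalone sketch of that interplay, using a hypothetical helper name and arbitrary example values (the real values come from the model config parameters parsed above):

```cpp
// Illustration only: builds an OrtArenaCfg from key/value pairs and hands it to
// CUDA provider options, mirroring the flow in LoadModel above. The numeric
// values are placeholders, not recommendations.
#include <onnxruntime_c_api.h>
#include <cstddef>

OrtStatus*
AttachDefaultArenaCfg(
    const OrtApi* ort_api, OrtCUDAProviderOptionsV2* cuda_options,
    OrtArenaCfg** out_arena_cfg)
{
  const char* keys[] = {"max_mem", "arena_extend_strategy"};
  const size_t values[] = {1ULL << 30 /* 1 GiB cap */, 0 /* kNextPowerOfTwo */};

  OrtArenaCfg* arena_cfg = nullptr;
  OrtStatus* status = ort_api->CreateArenaCfgV2(keys, values, 2, &arena_cfg);
  if (status != nullptr) {
    return status;
  }

  // Hand the cfg to the CUDA EP as its default arena. The caller owns the cfg
  // and must release it with ort_api->ReleaseArenaCfg; the diff does this via
  // a scoped unique_ptr with ReleaseArenaCfg as the deleter.
  status = ort_api->UpdateCUDAProviderOptionsWithValue(
      cuda_options, "default_memory_arena_cfg", arena_cfg);
  *out_arena_cfg = arena_cfg;
  return status;
}
```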