1- // Copyright 2018-2023 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ // Copyright 2018-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22//
33// Redistribution and use in source and binary forms, with or without
44// modification, are permitted provided that the following conditions
@@ -182,10 +182,39 @@ class LocalizeRepoAgent : public TritonRepoAgent {
182182 }
183183};
184184
185+ // / Get the model config path to load for the model.
186+ const std::string
187+ GetModelConfigFullPath (
188+ const std::string& model_dir_path, const std::string& custom_config_name)
189+ {
190+ // "--model-config-name" is set. Select custom config from
191+ // "<model_dir_path>/configs" folder if config file exists.
192+ if (!custom_config_name.empty ()) {
193+ bool custom_config_exists = false ;
194+ const std::string custom_config_path = JoinPath (
195+ {model_dir_path, kModelConfigFolder ,
196+ custom_config_name + kPbTxtExtension });
197+
198+ Status status = FileExists (custom_config_path, &custom_config_exists);
199+ if (!status.IsOk ()) {
200+ LOG_ERROR << " Failed to get model configuration full path for '"
201+ << model_dir_path << " ': " << status.AsString ();
202+ return " " ;
203+ }
204+
205+ if (custom_config_exists) {
206+ return custom_config_path;
207+ }
208+ }
209+ // "--model-config-name" is not set or custom config file does not exist.
210+ return JoinPath ({model_dir_path, kModelConfigPbTxt });
211+ }
212+
185213Status
186214CreateAgentModelListWithLoadAction (
187215 const inference::ModelConfig& original_model_config,
188216 const std::string& original_model_path,
217+ const std::string& model_config_name,
189218 std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
190219{
191220 if (original_model_config.has_model_repository_agents ()) {
@@ -218,7 +247,8 @@ CreateAgentModelListWithLoadAction(
218247 std::unique_ptr<TritonRepoAgentModel> agent_model;
219248 if (lagent_model_list->Size () != 0 ) {
220249 lagent_model_list->Back ()->Location (&artifact_type, &location);
221- const auto config_path = JoinPath ({location, kModelConfigPbTxt });
250+ const auto config_path =
251+ GetModelConfigFullPath (location, model_config_name);
222252 if (!ReadTextProto (config_path, &model_config).IsOk ()) {
223253 model_config.Clear ();
224254 }
@@ -283,10 +313,12 @@ GetModifiedTime(const std::string& path)
283313}
284314// Return the latest modification time in ns for '<config.pbtxt, model files>'
285315// in a model directory path. The time for "config.pbtxt" will be 0 if not
286- // found at "[model_dir_path]/config.pbtxt". The time for "model files" includes
287- // the time for 'model_dir_path'.
316+ // found at "[model_dir_path]/config.pbtxt" or "[model_dir_path]/configs/
317+ // <custom-config-name>.pbtxt" if "--model-config-name" is set. The time for
318+ // "model files" includes the time for 'model_dir_path'.
288319std::pair<int64_t , int64_t >
289- GetDetailedModifiedTime (const std::string& model_dir_path)
320+ GetDetailedModifiedTime (
321+ const std::string& model_dir_path, const std::string& model_config_path)
290322{
291323 // Check if 'model_dir_path' is a directory.
292324 bool is_dir;
@@ -322,12 +354,10 @@ GetDetailedModifiedTime(const std::string& model_dir_path)
322354 }
323355 // Get latest modification time for each files/folders, and place it at the
324356 // correct category.
325- const std::string model_config_full_path (
326- JoinPath ({model_dir_path, kModelConfigPbTxt }));
327357 for (const auto & child : contents) {
328358 const auto full_path = JoinPath ({model_dir_path, child});
329- if (full_path == model_config_full_path ) {
330- // config.pbtxt
359+ if (full_path == model_config_path ) {
360+ // config.pbtxt or customized config file in configs folder
331361 mtime.first = GetModifiedTime (full_path);
332362 } else {
333363 // model files
@@ -343,9 +373,10 @@ GetDetailedModifiedTime(const std::string& model_dir_path)
343373// modified time.
344374bool
345375IsModified (
346- const std::string& model_dir_path, std::pair<int64_t , int64_t >* last_ns)
376+ const std::string& model_dir_path, const std::string& model_config_path,
377+ std::pair<int64_t , int64_t >* last_ns)
347378{
348- auto new_ns = GetDetailedModifiedTime (model_dir_path);
379+ auto new_ns = GetDetailedModifiedTime (model_dir_path, model_config_path );
349380 bool modified = std::max (new_ns.first , new_ns.second ) >
350381 std::max (last_ns->first , last_ns->second );
351382 last_ns->swap (new_ns);
@@ -356,10 +387,12 @@ IsModified(
356387
357388ModelRepositoryManager::ModelRepositoryManager (
358389 const std::set<std::string>& repository_paths, const bool autofill,
359- const bool polling_enabled, const bool model_control_enabled,
360- const double min_compute_capability, const bool enable_model_namespacing,
390+ const std::string& model_config_name, const bool polling_enabled,
391+ const bool model_control_enabled, const double min_compute_capability,
392+ const bool enable_model_namespacing,
361393 std::unique_ptr<ModelLifeCycle> life_cycle)
362- : autofill_(autofill), polling_enabled_(polling_enabled),
394+ : autofill_(autofill), model_config_name_(model_config_name),
395+ polling_enabled_ (polling_enabled),
363396 model_control_enabled_(model_control_enabled),
364397 min_compute_capability_(min_compute_capability),
365398 dependency_graph_(&global_map_),
@@ -385,7 +418,8 @@ ModelRepositoryManager::Create(
385418 InferenceServer* server, const std::string& server_version,
386419 const std::set<std::string>& repository_paths,
387420 const std::set<std::string>& startup_models, const bool strict_model_config,
388- const bool polling_enabled, const bool model_control_enabled,
421+ const std::string& model_config_name, const bool polling_enabled,
422+ const bool model_control_enabled,
389423 const ModelLifeCycleOptions& life_cycle_options,
390424 const bool enable_model_namespacing,
391425 std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
@@ -414,9 +448,10 @@ ModelRepositoryManager::Create(
414448 // Not setting the smart pointer directly to simplify clean up
415449 std::unique_ptr<ModelRepositoryManager> local_manager (
416450 new ModelRepositoryManager (
417- repository_paths, !strict_model_config, polling_enabled,
418- model_control_enabled, life_cycle_options.min_compute_capability ,
419- enable_model_namespacing, std::move (life_cycle)));
451+ repository_paths, !strict_model_config, model_config_name,
452+ polling_enabled, model_control_enabled,
453+ life_cycle_options.min_compute_capability , enable_model_namespacing,
454+ std::move (life_cycle)));
420455 *model_repository_manager = std::move (local_manager);
421456
422457 // Support loading all models on startup in explicit model control mode with
@@ -549,7 +584,7 @@ ModelRepositoryManager::LoadModelByDependency(
549584 // encapsulate the interaction:
550585 // Each iteration:
551586 // - Check dependency graph for nodes that are ready for lifecycle changes:
552- // - load if all dependencies are satisfied and the node is 'heathy '
587+ // - load if all dependencies are satisfied and the node is 'healthy '
553588 // - unload otherwise (should revisit this, logically will only happen in
554589 // ensemble, the ensemble is requested to be re-loaded, at this point
555590 // it is too late to revert model changes so the ensemble will not be
@@ -1298,10 +1333,11 @@ ModelRepositoryManager::Poll(
12981333 // its state will fallback to the state before the polling.
12991334 for (const auto & pair : model_to_path) {
13001335 std::unique_ptr<ModelInfo> model_info;
1336+ const auto & model_name = pair.first .name_ ;
13011337 // Load with parameters will be appiled to all models with the same
13021338 // name (namespace can be different), unless namespace is specified
13031339 // in the future.
1304- const auto & mit = models.find (pair. first . name_ );
1340+ const auto & mit = models.find (model_name );
13051341 static std::vector<const InferenceParameter*> empty_params;
13061342 auto status = InitializeModelInfo (
13071343 pair.first , pair.second ,
@@ -1401,17 +1437,22 @@ ModelRepositoryManager::InitializeModelInfo(
14011437 // the override while the local files may still be unchanged.
14021438 linfo->mtime_nsec_ = std::make_pair (0 , 0 );
14031439 linfo->model_path_ = location;
1440+ linfo->model_config_path_ = JoinPath ({location, kModelConfigPbTxt });
14041441 linfo->agent_model_list_ .reset (new TritonRepoAgentModelList ());
14051442 linfo->agent_model_list_ ->AddAgentModel (std::move (localize_agent_model));
14061443 } else {
1444+ linfo->model_config_path_ =
1445+ GetModelConfigFullPath (linfo->model_path_ , model_config_name_);
1446+ // Model is not loaded.
14071447 if (iitr == infos_.end ()) {
1408- linfo->mtime_nsec_ = GetDetailedModifiedTime (linfo->model_path_ );
1448+ linfo->mtime_nsec_ = GetDetailedModifiedTime (
1449+ linfo->model_path_ , linfo->model_config_path_ );
14091450 } else {
14101451 // Check the current timestamps to determine if model actually has been
14111452 // modified
14121453 linfo->mtime_nsec_ = linfo->prev_mtime_ns_ ;
1413- unmodified =
1414- ! IsModified ( std::string ( linfo->model_path_ ) , &linfo->mtime_nsec_ );
1454+ unmodified = ! IsModified (
1455+ linfo->model_path_ , linfo-> model_config_path_ , &linfo->mtime_nsec_ );
14151456 }
14161457 }
14171458
@@ -1461,7 +1502,7 @@ ModelRepositoryManager::InitializeModelInfo(
14611502 // this must be done before normalizing model config as agents might
14621503 // redirect to use the model config at a different location
14631504 if (!parsed_config) {
1464- const auto config_path = JoinPath ({ linfo->model_path_ , kModelConfigPbTxt }) ;
1505+ const auto config_path = linfo->model_config_path_ ;
14651506 bool model_config_exists = false ;
14661507 RETURN_IF_ERROR (FileExists (config_path, &model_config_exists));
14671508 // model config can be missing if auto fill is set
@@ -1474,7 +1515,8 @@ ModelRepositoryManager::InitializeModelInfo(
14741515 }
14751516 if (parsed_config) {
14761517 RETURN_IF_ERROR (CreateAgentModelListWithLoadAction (
1477- linfo->model_config_ , linfo->model_path_ , &linfo->agent_model_list_ ));
1518+ linfo->model_config_ , linfo->model_path_ , model_config_name_,
1519+ &linfo->agent_model_list_ ));
14781520 if (linfo->agent_model_list_ != nullptr ) {
14791521 // Get the latest repository path
14801522 const char * location;
0 commit comments