1313from kepler_model .estimate .model .model import load_downloaded_model
1414from kepler_model .estimate .model_server_connector import is_model_server_enabled , make_request
1515from kepler_model .train .profiler .node_type_index import NodeTypeSpec , discover_spec_values , get_machine_spec
16- from kepler_model .util .config import SERVE_SOCKET , download_path , set_env_from_model_config
16+ from kepler_model .util .config import CONFIG_PATH , SERVE_SOCKET , download_path , set_env_from_model_config , set_config_dir
1717from kepler_model .util .loader import get_download_output_path , load_metadata
1818from kepler_model .util .train_types import ModelOutputType , convert_enery_source , is_output_type_supported
1919
@@ -185,7 +185,14 @@ def sig_handler(signum, frame) -> None:
185185 type = click .Path (exists = True ),
186186 required = False ,
187187)
188- def run (log_level : str , machine_spec : str ):
188+ @click .option (
189+ "--config-dir" ,
190+ "-c" ,
191+ type = click .Path (exists = False , dir_okay = True , file_okay = False ),
192+ default = CONFIG_PATH ,
193+ required = False ,
194+ )
195+ def run (log_level : str , machine_spec : str , config_dir : str ) -> int :
189196 level = getattr (logging , log_level .upper ())
190197 logging .basicConfig (
191198 level = level ,
@@ -194,6 +201,8 @@ def run(log_level: str, machine_spec: str):
194201 )
195202
196203 logger .info ("starting estimator" )
204+ set_config_dir (config_dir )
205+
197206 set_env_from_model_config ()
198207 clean_socket ()
199208 signal .signal (signal .SIGTERM , sig_handler )
0 commit comments