Skip to content

Commit 3b7ddad

Browse files
author
Sunil Thaha
committed
feat(estimator): allow config dir to be overridden
Signed-off-by: Sunil Thaha <[email protected]>
1 parent c7b9b47 commit 3b7ddad

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/kepler_model/estimate/estimator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from kepler_model.estimate.model.model import load_downloaded_model
1414
from kepler_model.estimate.model_server_connector import is_model_server_enabled, make_request
1515
from kepler_model.train.profiler.node_type_index import NodeTypeSpec, discover_spec_values, get_machine_spec
16-
from kepler_model.util.config import SERVE_SOCKET, download_path, set_env_from_model_config
16+
from kepler_model.util.config import CONFIG_PATH, SERVE_SOCKET, download_path, set_env_from_model_config, set_config_dir
1717
from kepler_model.util.loader import get_download_output_path, load_metadata
1818
from kepler_model.util.train_types import ModelOutputType, convert_enery_source, is_output_type_supported
1919

@@ -185,7 +185,14 @@ def sig_handler(signum, frame) -> None:
185185
type=click.Path(exists=True),
186186
required=False,
187187
)
188-
def run(log_level: str, machine_spec: str):
188+
@click.option(
189+
"--config-dir",
190+
"-c",
191+
type=click.Path(exists=False, dir_okay=True, file_okay=False),
192+
default=CONFIG_PATH,
193+
required=False,
194+
)
195+
def run(log_level: str, machine_spec: str, config_dir: str) -> int:
189196
level = getattr(logging, log_level.upper())
190197
logging.basicConfig(
191198
level=level,
@@ -194,6 +201,8 @@ def run(log_level: str, machine_spec: str):
194201
)
195202

196203
logger.info("starting estimator")
204+
set_config_dir(config_dir)
205+
197206
set_env_from_model_config()
198207
clean_socket()
199208
signal.signal(signal.SIGTERM, sig_handler)

src/kepler_model/util/config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848
SERVE_SOCKET = "/tmp/estimator.sock"
4949

5050

51+
def set_config_dir(config_dir: str):
52+
global CONFIG_PATH
53+
CONFIG_PATH = config_dir
54+
55+
5156
def getConfig(key: str, default):
5257
# check configmap path
5358
file = os.path.join(CONFIG_PATH, key)
@@ -74,8 +79,6 @@ def getPath(subpath):
7479
# use local path if not exists or cannot write
7580
MNT_PATH = os.path.join(os.path.dirname(__file__), "..")
7681

77-
CONFIG_PATH = getConfig("CONFIG_PATH", CONFIG_PATH)
78-
7982
model_topurl = getConfig("MODEL_TOPURL", base_model_url)
8083
initial_pipeline_urls = getConfig("INITIAL_PIPELINE_URL", "")
8184
if initial_pipeline_urls == "":

0 commit comments

Comments
 (0)