Skip to content

Commit 27a13ed

Browse files
author
Sunil Thaha
authored
Merge pull request #456 from sthaha/feat-config-dir-as-arg
feat: support --config-dir arg to point to the configuration directory
2 parents c7b9b47 + 38436a0 commit 27a13ed

File tree

4 files changed

+42
-33
lines changed

4 files changed

+42
-33
lines changed

src/kepler_model/estimate/estimator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from kepler_model.estimate.model.model import load_downloaded_model
1414
from kepler_model.estimate.model_server_connector import is_model_server_enabled, make_request
1515
from kepler_model.train.profiler.node_type_index import NodeTypeSpec, discover_spec_values, get_machine_spec
16-
from kepler_model.util.config import SERVE_SOCKET, download_path, set_env_from_model_config
16+
from kepler_model.util.config import CONFIG_PATH, SERVE_SOCKET, download_path, set_env_from_model_config, set_config_dir
1717
from kepler_model.util.loader import get_download_output_path, load_metadata
1818
from kepler_model.util.train_types import ModelOutputType, convert_enery_source, is_output_type_supported
1919

@@ -185,7 +185,14 @@ def sig_handler(signum, frame) -> None:
185185
type=click.Path(exists=True),
186186
required=False,
187187
)
188-
def run(log_level: str, machine_spec: str):
188+
@click.option(
189+
"--config-dir",
190+
"-c",
191+
type=click.Path(exists=False, dir_okay=True, file_okay=False),
192+
default=CONFIG_PATH,
193+
required=False,
194+
)
195+
def run(log_level: str, machine_spec: str, config_dir: str) -> int:
189196
level = getattr(logging, log_level.upper())
190197
logging.basicConfig(
191198
level=level,
@@ -194,6 +201,8 @@ def run(log_level: str, machine_spec: str):
194201
)
195202

196203
logger.info("starting estimator")
204+
set_config_dir(config_dir)
205+
197206
set_env_from_model_config()
198207
clean_socket()
199208
signal.signal(signal.SIGTERM, sig_handler)

src/kepler_model/server/model_server.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111

1212
from kepler_model.train import NodeTypeIndexCollection, NodeTypeSpec
1313
from kepler_model.util.config import (
14+
CONFIG_PATH,
1415
ERROR_KEY,
1516
MODEL_SERVER_MODEL_LIST_PATH,
1617
MODEL_SERVER_MODEL_REQ_PATH,
1718
download_path,
1819
getConfig,
1920
initial_pipeline_urls,
2021
model_toppath,
22+
set_config_dir,
2123
)
2224
from kepler_model.util.loader import (
2325
CHECKPOINT_FOLDERNAME,
@@ -430,11 +432,25 @@ def fill_machine_spec():
430432
default="info",
431433
required=False,
432434
)
433-
def run(log_level: str):
435+
@click.option(
436+
"--config-dir",
437+
"-c",
438+
type=click.Path(exists=False, dir_okay=True, file_okay=False),
439+
default=CONFIG_PATH,
440+
required=False,
441+
)
442+
def run(log_level: str, config_dir: str) -> int:
434443
level = getattr(logging, log_level.upper())
435-
logging.basicConfig(level=level)
444+
logging.basicConfig(
445+
level=level,
446+
format="%(asctime)s %(levelname)s %(filename)s:%(lineno)s: %(message)s",
447+
datefmt="%Y-%m-%d %H:%M:%S",
448+
)
449+
450+
set_config_dir(config_dir)
436451
load_init_pipeline()
437452
app.run(host="0.0.0.0", port=MODEL_SERVER_PORT)
453+
return 0
438454

439455

440456
if __name__ == "__main__":

src/kepler_model/util/config.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848
SERVE_SOCKET = "/tmp/estimator.sock"
4949

5050

51+
def set_config_dir(config_dir: str):
52+
global CONFIG_PATH
53+
CONFIG_PATH = config_dir
54+
55+
5156
def getConfig(key: str, default):
5257
# check configmap path
5358
file = os.path.join(CONFIG_PATH, key)
@@ -74,8 +79,6 @@ def getPath(subpath):
7479
# use local path if not exists or cannot write
7580
MNT_PATH = os.path.join(os.path.dirname(__file__), "..")
7681

77-
CONFIG_PATH = getConfig("CONFIG_PATH", CONFIG_PATH)
78-
7982
model_topurl = getConfig("MODEL_TOPURL", base_model_url)
8083
initial_pipeline_urls = getConfig("INITIAL_PIPELINE_URL", "")
8184
if initial_pipeline_urls == "":
@@ -123,10 +126,16 @@ def set_env_from_model_config():
123126
return
124127

125128
for line in model_config.splitlines():
126-
splits = line.split("=")
129+
line = line.strip()
130+
# ignore comments and blanks
131+
if not line or line.startswith("#"):
132+
continue
133+
134+
# pick only the first part until # and ignore the rest
135+
splits = line.split("#")[0].strip().split("=")
127136
if len(splits) > 1:
128137
os.environ[splits[0].strip()] = splits[1].strip()
129-
logging.info(f"set {splits[0]} to {splits[1]}.")
138+
logging.info(f"set env {splits[0]} to '{splits[1]}'.")
130139

131140

132141
def is_estimator_enable(prefix):

tests/query_test.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)