Skip to content

Commit 56ef16d

Browse files
author
Sunil Thaha
authored
Merge pull request #396 from sunya-ch/server-api-rebase-patch-4
address issue in comment #392, define ModelListParam
2 parents 8525a41 + 9aae372 commit 56ef16d

File tree

3 files changed

+26
-12
lines changed

3 files changed

+26
-12
lines changed

src/kepler_model/estimate/model_server_connector.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from kepler_model.util.loader import get_download_output_path
1515
from kepler_model.util.train_types import ModelOutputType
16-
16+
from kepler_model.server.model_server import ModelListParam
1717

1818
# discover_spec: determine node spec in json format (refer to NodeTypeSpec)
1919
def discover_spec():
@@ -68,16 +68,23 @@ def make_request(power_request):
6868
return unpack(power_request.energy_source, output_type, response)
6969

7070

71-
def list_all_models(energy_source=None, node_type=None):
71+
def list_all_models(energy_source=None, output_type=None, feature_group=None, node_type=None, filter=None):
7272
if not is_model_server_enabled():
7373
return dict()
7474
try:
7575
endpoint = get_model_server_list_endpoint()
7676
params= {}
7777
if energy_source:
78-
params["source"] = energy_source
78+
params[ModelListParam.EnergySource.value] = energy_source
79+
if output_type:
80+
params[ModelListParam.OutputType.value] = output_type
81+
if feature_group:
82+
params[ModelListParam.FeatureGroup.value] = feature_group
7983
if node_type:
80-
params["type"] = node_type
84+
params[ModelListParam.NodeType.value] = node_type
85+
if filter:
86+
params[ModelListParam.Filter.value] = filter
87+
8188
response = requests.get(endpoint, params=params)
8289
except Exception as err:
8390
print(f"cannot list model: {err}")
@@ -86,4 +93,3 @@ def list_all_models(energy_source=None, node_type=None):
8693
return dict()
8794
model_names = json.loads(response.content.decode("utf-8"))
8895
return model_names
89-

src/kepler_model/server/model_server.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import shutil
55
import sys
6+
import enum
67

78
import click
89
import requests
@@ -71,6 +72,13 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
7172
if spec is not None:
7273
self.spec = NodeTypeSpec(**spec)
7374

75+
# ModelListParams defines parameters for /best-models API
76+
class ModelListParam(enum.Enum):
77+
EnergySource = "energy-source"
78+
OutputType = "output-type"
79+
FeatureGroup = "feature-group"
80+
NodeType = "node-type"
81+
Filter = "filter"
7482

7583
###########################################
7684
MODEL_SERVER_PORT = int(getConfig("MODEL_SERVER_PORT", "8100"))
@@ -93,7 +101,7 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
93101
"""
94102

95103

96-
def select_best_model(spec, valid_group_path, filters, energy_source, pipeline_name="", trainer_name="", node_type=any_node_type, weight=False):
104+
def select_best_model(spec, valid_group_path: str, filters: dict, energy_source: str, pipeline_name: str="", trainer_name: str="", node_type: int=any_node_type, weight: bool=False):
97105
model_names = [f for f in os.listdir(valid_group_path) if f != CHECKPOINT_FOLDERNAME and not os.path.isfile(os.path.join(valid_group_path, f)) and (trainer_name == "" or trainer_name in f)]
98106
if weight:
99107
model_names = [name for name in model_names if name.split("_")[0] in weight_support_trainers]
@@ -186,11 +194,11 @@ def get_model():
186194
# get_available_models: return name list of best-candidate pipelines
187195
@app.route(MODEL_SERVER_MODEL_LIST_PATH, methods=["GET"])
188196
def get_available_models():
189-
fg = request.args.get("fg")
190-
ot = request.args.get("ot")
191-
energy_source = request.args.get("source")
192-
node_type = request.args.get("type")
193-
filter = request.args.get("filter")
197+
fg = request.args.get(ModelListParam.FeatureGroup.value)
198+
ot = request.args.get(ModelListParam.OutputType.value)
199+
energy_source = request.args.get(ModelListParam.EnergySource.value)
200+
node_type = request.args.get(ModelListParam.NodeType.value)
201+
filter = request.args.get(ModelListParam.Filter.value)
194202

195203
try:
196204
if fg is None:

src/kepler_model/util/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def parse_filters(filter):
173173
return filters
174174

175175

176-
def is_valid_model(metadata, filters):
176+
def is_valid_model(metadata, filters: dict):
177177
for attrb, val in filters.items():
178178
if not hasattr(metadata, attrb) or getattr(metadata, attrb) is None:
179179
print("{} has no {}".format(metadata["model_name"], attrb))

0 commit comments

Comments
 (0)