Skip to content

Commit 9aae372

Browse files
committed
chore: add type hint for select_best_model,is_valid_model
Signed-off-by: Sunyanan Choochotkaew <[email protected]>
1 parent 7745a19 commit 9aae372

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

src/kepler_model/estimate/model_server_connector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,15 @@ def list_all_models(energy_source=None, output_type=None, feature_group=None, no
7575
endpoint = get_model_server_list_endpoint()
7676
params= {}
7777
if energy_source:
78-
params[ModelListParam.EnergySource] = energy_source
78+
params[ModelListParam.EnergySource.value] = energy_source
7979
if output_type:
80-
params[ModelListParam.OutputType] = output_type
80+
params[ModelListParam.OutputType.value] = output_type
8181
if feature_group:
82-
params[ModelListParam.FeatureGroup] = feature_group
82+
params[ModelListParam.FeatureGroup.value] = feature_group
8383
if node_type:
84-
params[ModelListParam.NodeType] = node_type
84+
params[ModelListParam.NodeType.value] = node_type
8585
if filter:
86-
params[ModelListParam.Filter] = filter
86+
params[ModelListParam.Filter.value] = filter
8787

8888
response = requests.get(endpoint, params=params)
8989
except Exception as err:

src/kepler_model/server/model_server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class ModelListParam(enum.Enum):
101101
"""
102102

103103

104-
def select_best_model(spec, valid_group_path, filters, energy_source, pipeline_name="", trainer_name="", node_type=any_node_type, weight=False):
104+
def select_best_model(spec, valid_group_path: str, filters: dict, energy_source: str, pipeline_name: str="", trainer_name: str="", node_type: int=any_node_type, weight: bool=False):
105105
model_names = [f for f in os.listdir(valid_group_path) if f != CHECKPOINT_FOLDERNAME and not os.path.isfile(os.path.join(valid_group_path, f)) and (trainer_name == "" or trainer_name in f)]
106106
if weight:
107107
model_names = [name for name in model_names if name.split("_")[0] in weight_support_trainers]
@@ -194,11 +194,11 @@ def get_model():
194194
# get_available_models: return name list of best-candidate pipelines
195195
@app.route(MODEL_SERVER_MODEL_LIST_PATH, methods=["GET"])
196196
def get_available_models():
197-
fg = request.args.get(ModelListParam.FeatureGroup)
198-
ot = request.args.get(ModelListParam.OutputType)
199-
energy_source = request.args.get(ModelListParam.EnergySource)
200-
node_type = request.args.get(ModelListParam.NodeType)
201-
filter = request.args.get(ModelListParam.Filter)
197+
fg = request.args.get(ModelListParam.FeatureGroup.value)
198+
ot = request.args.get(ModelListParam.OutputType.value)
199+
energy_source = request.args.get(ModelListParam.EnergySource.value)
200+
node_type = request.args.get(ModelListParam.NodeType.value)
201+
filter = request.args.get(ModelListParam.Filter.value)
202202

203203
try:
204204
if fg is None:

src/kepler_model/util/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def parse_filters(filter):
173173
return filters
174174

175175

176-
def is_valid_model(metadata, filters):
176+
def is_valid_model(metadata, filters: dict):
177177
for attrb, val in filters.items():
178178
if not hasattr(metadata, attrb) or getattr(metadata, attrb) is None:
179179
print("{} has no {}".format(metadata["model_name"], attrb))

0 commit comments

Comments
 (0)