Skip to content

Commit 8525a41

Browse files
authored
Merge pull request #391 from sunya-ch/server-api-rebase-patch-3
chore: fix typo valid_groupath to valid_group_path
2 parents a0ad11f + 542b471 commit 8525a41

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/kepler_model/server/model_server.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
8383

8484
"""
8585
select_best_model:
86-
1. list model_names from valid_grouppath (determined by valid features)
86+
1. list model_names from valid_group_path (determined by valid features)
8787
2. filter weight-supported model if requesting for model weight
8888
3. filter matched type by requesting node_type or node_collection over node spec
8989
4. if no candidate left, list model with largest number of cores
@@ -93,8 +93,8 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
9393
"""
9494

9595

96-
def select_best_model(spec, valid_groupath, filters, energy_source, pipeline_name="", trainer_name="", node_type=any_node_type, weight=False):
97-
model_names = [f for f in os.listdir(valid_groupath) if f != CHECKPOINT_FOLDERNAME and not os.path.isfile(os.path.join(valid_groupath, f)) and (trainer_name == "" or trainer_name in f)]
96+
def select_best_model(spec, valid_group_path, filters, energy_source, pipeline_name="", trainer_name="", node_type=any_node_type, weight=False):
97+
model_names = [f for f in os.listdir(valid_group_path) if f != CHECKPOINT_FOLDERNAME and not os.path.isfile(os.path.join(valid_group_path, f)) and (trainer_name == "" or trainer_name in f)]
9898
if weight:
9999
model_names = [name for name in model_names if name.split("_")[0] in weight_support_trainers]
100100
# Load metadata of trainers
@@ -113,7 +113,7 @@ def select_best_model(spec, valid_groupath, filters, energy_source, pipeline_nam
113113
logger.warning("no large candidates; selecting from all available")
114114
candidates = model_names
115115
for model_name in candidates:
116-
model_savepath = os.path.join(valid_groupath, model_name)
116+
model_savepath = os.path.join(valid_group_path, model_name)
117117
metadata = load_json(model_savepath, METADATA_FILENAME)
118118
if metadata is None or not is_valid_model(metadata, filters) or ERROR_KEY not in metadata:
119119
# invalid metadata
@@ -126,7 +126,7 @@ def select_best_model(spec, valid_groupath, filters, energy_source, pipeline_nam
126126
logger.warning(f"weight failed: {model_savepath}")
127127
continue
128128
else:
129-
response = get_archived_file(valid_groupath, model_name)
129+
response = get_archived_file(valid_group_path, model_name)
130130
if not os.path.exists(response):
131131
# archived model file does not exists
132132
logger.warning(f"archive failed: {response}")
@@ -160,9 +160,9 @@ def get_model():
160160
best_response = None
161161
# find best model comparing best candidate from each valid feature group complied with filtering conditions
162162
for fg in valid_fgs:
163-
valid_groupath = get_model_group_path(model_toppath, output_type, fg, energy_source, pipeline_name=pipelineName[energy_source])
164-
if os.path.exists(valid_groupath):
165-
best_candidate, response = select_best_model(req.spec, valid_groupath, filters, energy_source, req.pipeline_name, req.trainer_name, req.node_type, req.weight)
163+
valid_group_path = get_model_group_path(model_toppath, output_type, fg, energy_source, pipeline_name=pipelineName[energy_source])
164+
if os.path.exists(valid_group_path):
165+
best_candidate, response = select_best_model(req.spec, valid_group_path, filters, energy_source, req.pipeline_name, req.trainer_name, req.node_type, req.weight)
166166
if best_candidate is None:
167167
continue
168168
if best_model is None or best_model[ERROR_KEY] > best_candidate[ERROR_KEY]:
@@ -222,9 +222,9 @@ def get_available_models():
222222
model_names[output_type.name] = dict()
223223
for fg in valid_fgs:
224224
logger.debug(f"Searching feature group {fg}")
225-
valid_groupath = get_model_group_path(model_toppath, output_type, fg, energy_source, pipeline_name=pipelineName[energy_source])
226-
if os.path.exists(valid_groupath):
227-
best_candidate, _ = select_best_model(None, valid_groupath, filters, energy_source, node_type=node_type)
225+
valid_group_path = get_model_group_path(model_toppath, output_type, fg, energy_source, pipeline_name=pipelineName[energy_source])
226+
if os.path.exists(valid_group_path):
227+
best_candidate, _ = select_best_model(None, valid_group_path, filters, energy_source, node_type=node_type)
228228
if best_candidate is None:
229229
continue
230230
model_names[output_type.name][fg.name] = best_candidate["model_name"]

0 commit comments

Comments
 (0)