33import os
44import shutil
55import sys
6+ import enum
67
78import click
89import requests
@@ -71,6 +72,13 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
7172 if spec is not None :
7273 self .spec = NodeTypeSpec (** spec )
7374
75+ # ModelListParams defines parameters for /best-models API
76+ class ModelListParam (enum .Enum ):
77+ EnergySource = "energy-source"
78+ OutputType = "output-type"
79+ FeatureGroup = "feature-group"
80+ NodeType = "node-type"
81+ Filter = "filter"
7482
7583###########################################
7684MODEL_SERVER_PORT = int (getConfig ("MODEL_SERVER_PORT" , "8100" ))
@@ -93,7 +101,7 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
93101"""
94102
95103
96- def select_best_model (spec , valid_group_path , filters , energy_source , pipeline_name = "" , trainer_name = "" , node_type = any_node_type , weight = False ):
104+ def select_best_model (spec , valid_group_path : str , filters : dict , energy_source : str , pipeline_name : str = "" , trainer_name : str = "" , node_type : int = any_node_type , weight : bool = False ):
97105 model_names = [f for f in os .listdir (valid_group_path ) if f != CHECKPOINT_FOLDERNAME and not os .path .isfile (os .path .join (valid_group_path , f )) and (trainer_name == "" or trainer_name in f )]
98106 if weight :
99107 model_names = [name for name in model_names if name .split ("_" )[0 ] in weight_support_trainers ]
@@ -186,11 +194,11 @@ def get_model():
186194# get_available_models: return name list of best-candidate pipelines
187195@app .route (MODEL_SERVER_MODEL_LIST_PATH , methods = ["GET" ])
188196def get_available_models ():
189- fg = request .args .get ("fg" )
190- ot = request .args .get ("ot" )
191- energy_source = request .args .get ("source" )
192- node_type = request .args .get ("type" )
193- filter = request .args .get ("filter" )
197+ fg = request .args .get (ModelListParam . FeatureGroup . value )
198+ ot = request .args .get (ModelListParam . OutputType . value )
199+ energy_source = request .args .get (ModelListParam . EnergySource . value )
200+ node_type = request .args .get (ModelListParam . NodeType . value )
201+ filter = request .args .get (ModelListParam . Filter . value )
194202
195203 try :
196204 if fg is None :
0 commit comments