@@ -101,7 +101,7 @@ class ModelListParam(enum.Enum):
101101"""
102102
103103
104- def select_best_model (spec , valid_group_path , filters , energy_source , pipeline_name = "" , trainer_name = "" , node_type = any_node_type , weight = False ):
104+ def select_best_model (spec , valid_group_path : str , filters : dict , energy_source : str , pipeline_name : str = "" , trainer_name : str = "" , node_type : int = any_node_type , weight : bool = False ):
105105 model_names = [f for f in os .listdir (valid_group_path ) if f != CHECKPOINT_FOLDERNAME and not os .path .isfile (os .path .join (valid_group_path , f )) and (trainer_name == "" or trainer_name in f )]
106106 if weight :
107107 model_names = [name for name in model_names if name .split ("_" )[0 ] in weight_support_trainers ]
@@ -194,11 +194,11 @@ def get_model():
194194# get_available_models: return name list of best-candidate pipelines
195195@app .route (MODEL_SERVER_MODEL_LIST_PATH , methods = ["GET" ])
196196def get_available_models ():
197- fg = request .args .get (ModelListParam .FeatureGroup )
198- ot = request .args .get (ModelListParam .OutputType )
199- energy_source = request .args .get (ModelListParam .EnergySource )
200- node_type = request .args .get (ModelListParam .NodeType )
201- filter = request .args .get (ModelListParam .Filter )
197+ fg = request .args .get (ModelListParam .FeatureGroup . value )
198+ ot = request .args .get (ModelListParam .OutputType . value )
199+ energy_source = request .args .get (ModelListParam .EnergySource . value )
200+ node_type = request .args .get (ModelListParam .NodeType . value )
201+ filter = request .args .get (ModelListParam .Filter . value )
202202
203203 try :
204204 if fg is None :
0 commit comments