@@ -83,7 +83,7 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
8383
8484"""
8585select_best_model:
86- 1. list model_names from valid_grouppath (determined by valid features)
86+ 1. list model_names from valid_group_path (determined by valid features)
87872. filter weight-supported model if requesting for model weight
88883. filter matched type by requesting node_type or node_collection over node spec
89894. if no candidate left, list model with largest number of cores
@@ -93,8 +93,8 @@ def __init__(self, metrics, output_type, source="rapl-sysfs", node_type=-1, weig
9393"""
9494
9595
96- def select_best_model (spec , valid_groupath , filters , energy_source , pipeline_name = "" , trainer_name = "" , node_type = any_node_type , weight = False ):
97- model_names = [f for f in os .listdir (valid_groupath ) if f != CHECKPOINT_FOLDERNAME and not os .path .isfile (os .path .join (valid_groupath , f )) and (trainer_name == "" or trainer_name in f )]
96+ def select_best_model (spec , valid_group_path , filters , energy_source , pipeline_name = "" , trainer_name = "" , node_type = any_node_type , weight = False ):
97+ model_names = [f for f in os .listdir (valid_group_path ) if f != CHECKPOINT_FOLDERNAME and not os .path .isfile (os .path .join (valid_group_path , f )) and (trainer_name == "" or trainer_name in f )]
9898 if weight :
9999 model_names = [name for name in model_names if name .split ("_" )[0 ] in weight_support_trainers ]
100100 # Load metadata of trainers
@@ -113,7 +113,7 @@ def select_best_model(spec, valid_groupath, filters, energy_source, pipeline_nam
113113 logger .warning ("no large candidates; selecting from all available" )
114114 candidates = model_names
115115 for model_name in candidates :
116- model_savepath = os .path .join (valid_groupath , model_name )
116+ model_savepath = os .path .join (valid_group_path , model_name )
117117 metadata = load_json (model_savepath , METADATA_FILENAME )
118118 if metadata is None or not is_valid_model (metadata , filters ) or ERROR_KEY not in metadata :
119119 # invalid metadata
@@ -126,7 +126,7 @@ def select_best_model(spec, valid_groupath, filters, energy_source, pipeline_nam
126126 logger .warning (f"weight failed: { model_savepath } " )
127127 continue
128128 else :
129- response = get_archived_file (valid_groupath , model_name )
129+ response = get_archived_file (valid_group_path , model_name )
130130 if not os .path .exists (response ):
131131 # archived model file does not exists
132132 logger .warning (f"archive failed: { response } " )
@@ -160,9 +160,9 @@ def get_model():
160160 best_response = None
161161 # find best model comparing best candidate from each valid feature group complied with filtering conditions
162162 for fg in valid_fgs :
163- valid_groupath = get_model_group_path (model_toppath , output_type , fg , energy_source , pipeline_name = pipelineName [energy_source ])
164- if os .path .exists (valid_groupath ):
165- best_candidate , response = select_best_model (req .spec , valid_groupath , filters , energy_source , req .pipeline_name , req .trainer_name , req .node_type , req .weight )
163+ valid_group_path = get_model_group_path (model_toppath , output_type , fg , energy_source , pipeline_name = pipelineName [energy_source ])
164+ if os .path .exists (valid_group_path ):
165+ best_candidate , response = select_best_model (req .spec , valid_group_path , filters , energy_source , req .pipeline_name , req .trainer_name , req .node_type , req .weight )
166166 if best_candidate is None :
167167 continue
168168 if best_model is None or best_model [ERROR_KEY ] > best_candidate [ERROR_KEY ]:
@@ -222,9 +222,9 @@ def get_available_models():
222222 model_names [output_type .name ] = dict ()
223223 for fg in valid_fgs :
224224 logger .debug (f"Searching feature group { fg } " )
225- valid_groupath = get_model_group_path (model_toppath , output_type , fg , energy_source , pipeline_name = pipelineName [energy_source ])
226- if os .path .exists (valid_groupath ):
227- best_candidate , _ = select_best_model (None , valid_groupath , filters , energy_source , node_type = node_type )
225+ valid_group_path = get_model_group_path (model_toppath , output_type , fg , energy_source , pipeline_name = pipelineName [energy_source ])
226+ if os .path .exists (valid_group_path ):
227+ best_candidate , _ = select_best_model (None , valid_group_path , filters , energy_source , node_type = node_type )
228228 if best_candidate is None :
229229 continue
230230 model_names [output_type .name ][fg .name ] = best_candidate ["model_name" ]
0 commit comments