|
14 | 14 |
|
15 | 15 | from util.train_types import get_valid_feature_groups, ModelOutputType, FeatureGroups, FeatureGroup |
16 | 16 | from util.config import getConfig, model_toppath, ERROR_KEY, MODEL_SERVER_MODEL_REQ_PATH, MODEL_SERVER_MODEL_LIST_PATH, initial_pipeline_url |
17 | | -from util.loader import parse_filters, is_valid_model, load_json, load_weight, get_model_group_path, get_archived_file, METADATA_FILENAME, CHECKPOINT_FOLDERNAME, get_pipeline_path |
| 17 | +from util.loader import parse_filters, is_valid_model, load_json, load_weight, get_model_group_path, get_archived_file, METADATA_FILENAME, CHECKPOINT_FOLDERNAME, get_pipeline_path, any_node_type, is_matched_type |
18 | 18 |
|
19 | 19 | ############################################### |
20 | 20 | # model request |
@@ -42,16 +42,17 @@ def __init__(self, metrics, output_type, source='rapl', node_type=-1, weight=Fal |
42 | 42 | MODEL_SERVER_PORT = getConfig('MODEL_SERVER_PORT', MODEL_SERVER_PORT) |
43 | 43 | MODEL_SERVER_PORT = int(MODEL_SERVER_PORT) |
44 | 44 |
|
45 | | -def select_best_model(valid_groupath, filters, trainer_name="", node_type=-1, weight=False): |
| 45 | +def select_best_model(valid_groupath, filters, trainer_name="", node_type=any_node_type, weight=False): |
46 | 46 | model_names = [f for f in os.listdir(valid_groupath) if \ |
47 | 47 | f != CHECKPOINT_FOLDERNAME \ |
48 | 48 | and not os.path.isfile(os.path.join(valid_groupath,f)) \ |
49 | | - and (trainer_name == "" or trainer_name in f) \ |
50 | | - and (node_type == -1 or str(node_type) in f) ] |
| 49 | + and (trainer_name == "" or trainer_name in f)] |
51 | 50 | # Load metadata of trainers |
52 | 51 | best_cadidate = None |
53 | 52 | best_response = None |
54 | 53 | for model_name in model_names: |
| 54 | + if not is_matched_type(model_name, node_type): |
| 55 | + continue |
55 | 56 | model_savepath = os.path.join(valid_groupath, model_name) |
56 | 57 | metadata = load_json(model_savepath, METADATA_FILENAME) |
57 | 58 | if metadata is None or not is_valid_model(metadata, filters) or ERROR_KEY not in metadata: |
|
0 commit comments