Skip to content

Commit a01ca5f

Browse files
authored
Merge pull request #163 from sunya-ch/server-api
change node-type check logic
2 parents b4571cd + d323c0e commit a01ca5f

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/server/model_server.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from util.train_types import get_valid_feature_groups, ModelOutputType, FeatureGroups, FeatureGroup
1616
from util.config import getConfig, model_toppath, ERROR_KEY, MODEL_SERVER_MODEL_REQ_PATH, MODEL_SERVER_MODEL_LIST_PATH, initial_pipeline_url
17-
from util.loader import parse_filters, is_valid_model, load_json, load_weight, get_model_group_path, get_archived_file, METADATA_FILENAME, CHECKPOINT_FOLDERNAME, get_pipeline_path
17+
from util.loader import parse_filters, is_valid_model, load_json, load_weight, get_model_group_path, get_archived_file, METADATA_FILENAME, CHECKPOINT_FOLDERNAME, get_pipeline_path, any_node_type, is_matched_type
1818

1919
###############################################
2020
# model request
@@ -42,16 +42,17 @@ def __init__(self, metrics, output_type, source='rapl', node_type=-1, weight=Fal
4242
MODEL_SERVER_PORT = getConfig('MODEL_SERVER_PORT', MODEL_SERVER_PORT)
4343
MODEL_SERVER_PORT = int(MODEL_SERVER_PORT)
4444

45-
def select_best_model(valid_groupath, filters, trainer_name="", node_type=-1, weight=False):
45+
def select_best_model(valid_groupath, filters, trainer_name="", node_type=any_node_type, weight=False):
4646
model_names = [f for f in os.listdir(valid_groupath) if \
4747
f != CHECKPOINT_FOLDERNAME \
4848
and not os.path.isfile(os.path.join(valid_groupath,f)) \
49-
and (trainer_name == "" or trainer_name in f) \
50-
and (node_type == -1 or str(node_type) in f) ]
49+
and (trainer_name == "" or trainer_name in f)]
5150
# Load metadata of trainers
5251
best_cadidate = None
5352
best_response = None
5453
for model_name in model_names:
54+
if not is_matched_type(model_name, node_type):
55+
continue
5556
model_savepath = os.path.join(valid_groupath, model_name)
5657
metadata = load_json(model_savepath, METADATA_FILENAME)
5758
if metadata is None or not is_valid_model(metadata, filters) or ERROR_KEY not in metadata:

src/util/loader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
default_init_pipeline_name = "Linux-4.15.0-213-generic-x86_64_v0.6"
2323
default_trainer_name = "GradientBoostingRegressorTrainer"
2424
default_node_type = "1"
25+
any_node_type = -1
2526
default_feature_group = FeatureGroup.KubeletOnly
2627

2728
def load_json(path, name):
@@ -123,6 +124,11 @@ def is_valid_model(metadata, filters):
123124
def get_model_name(trainer_name, node_type):
124125
return "{}_{}".format(trainer_name, node_type)
125126

127+
def is_matched_type(model_name, node_type):
128+
if node_type == any_node_type:
129+
return True
130+
return model_name.split("_")[-1] == str(node_type)
131+
126132
def get_pipeline_path(model_toppath, pipeline_name=DEFAULT_PIPELINE):
127133
return os.path.join(model_toppath, pipeline_name)
128134

0 commit comments

Comments
 (0)