Skip to content

Commit cea2f94

Browse files
committed
fix bug on same output type of power requests
Signed-off-by: Sunyanan Choochotkaew <[email protected]>
1 parent b6dad1e commit cea2f94

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

src/estimate/estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,12 @@ def handle_request(data):
5454
return {"powers": dict(), "msg": msg}
5555

5656
output_type = ModelOutputType[power_request.output_type]
57+
energy_source = power_request.energy_source
5758

5859
if output_type.name not in loaded_model:
60+
loaded_model[output_type.name] = dict()
61+
62+
if energy_source not in loaded_model[output_type.name]:
5963
output_path = get_download_output_path(download_path, power_request.energy_source, output_type)
6064
if not os.path.exists(output_path):
6165
# try connecting to model server
@@ -71,11 +75,11 @@ def handle_request(data):
7175
print("load model from config: ", output_path)
7276
else:
7377
print("load model from model server: ", output_path)
74-
loaded_model[output_type.name] = load_downloaded_model(power_request.energy_source, output_type)
78+
loaded_model[output_type.name][energy_source] = load_downloaded_model(power_request.energy_source, output_type)
7579
# remove loaded model
7680
shutil.rmtree(output_path)
7781

78-
model = loaded_model[output_type.name]
82+
model = loaded_model[output_type.name][energy_source]
7983
powers, msg = model.get_power(power_request.datapoint)
8084
if msg != "":
8185
print("{} fail to predict, removed".format(model.model_name))

tests/estimator_model_request_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454
for fg_name, best_model in valid_fgs.items():
5555
if os.path.exists(output_path):
5656
shutil.rmtree(output_path)
57-
if output_type.name in loaded_model:
58-
del loaded_model[output_type.name]
57+
if output_type.name in loaded_model and energy_source in loaded_model[output_type.name]:
58+
del loaded_model[output_type.name][energy_source]
5959
metrics = FeatureGroups[FeatureGroup[fg_name]]
6060
request_json = generate_request(None, n=10, metrics=metrics, output_type=output_type_name)
6161
data = json.dumps(request_json)
@@ -74,8 +74,8 @@
7474
response = requests.get(url)
7575
if response.status_code == 200:
7676
output_path = get_download_output_path(download_path, energy_source, output_type)
77-
if output_type_name in loaded_model:
78-
del loaded_model[output_type_name]
77+
if output_type_name in loaded_model and energy_source in loaded_model[output_type.name]:
78+
del loaded_model[output_type_name][energy_source]
7979
if os.path.exists(output_path):
8080
shutil.rmtree(output_path)
8181
request_json = generate_request(None, n=10, metrics=FeatureGroups[FeatureGroup.Full], output_type=output_type_name)
@@ -94,8 +94,8 @@
9494
os.environ['MODEL_SERVER_ENABLE'] = "false"
9595
output_type = ModelOutputType[output_type_name]
9696
output_path = get_download_output_path(download_path, energy_source, output_type)
97-
if output_type_name in loaded_model:
98-
del loaded_model[output_type_name]
97+
if output_type_name in loaded_model and energy_source in loaded_model[output_type.name]:
98+
del loaded_model[output_type_name][energy_source]
9999
if os.path.exists(output_path):
100100
shutil.rmtree(output_path)
101101
# valid model
@@ -106,7 +106,7 @@
106106
output = handle_request(data)
107107
assert len(output['powers']) > 0, "cannot get power {}\n {}".format(output['msg'], request_json)
108108
print("result {}/{} from static set: {}".format(output_type_name, FeatureGroup.KubeletOnly.name, output))
109-
del loaded_model[output_type_name]
109+
del loaded_model[output_type_name][energy_source]
110110
# invalid model
111111
os.environ[init_url_key] = get_url(output_type=output_type, feature_group=FeatureGroup.BPFOnly)
112112
print("Requesting from ", os.environ[init_url_key])
@@ -119,8 +119,8 @@
119119
set_env_from_model_config()
120120
print("Requesting from ", os.environ[init_url_key])
121121
reset_failed_list()
122-
if output_type_name in loaded_model:
123-
del loaded_model[output_type_name]
122+
if output_type_name in loaded_model and energy_source in loaded_model[output_type.name]:
123+
del loaded_model[output_type_name][energy_source]
124124
output_path = get_download_output_path(download_path, energy_source, output_type)
125125
if os.path.exists(output_path):
126126
shutil.rmtree(output_path)

0 commit comments

Comments
 (0)