Skip to content

Commit 1e39c25

Browse files
author
Sunil Thaha
authored
Merge pull request #443 from sthaha/estimator-improvements
Estimator improvements
2 parents cb69d9a + 21fe291 commit 1e39c25

File tree

4 files changed

+33
-22
lines changed

4 files changed

+33
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ dependencies = [
4444
"boto3==1.34.155",
4545
"pymarkdownlnt==0.9.22",
4646
"yamllint==1.35.1",
47+
"requests-file==2.1.0",
4748
]
4849

4950
[project.scripts]
@@ -67,8 +68,6 @@ extra-dependencies = [
6768
"coverage[toml]>=6.5",
6869
"ipdb",
6970
"ipython",
70-
"ipdb",
71-
"ipython",
7271
"pytest",
7372
]
7473

src/kepler_model/estimate/archived_model.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import logging
2+
13
import requests
4+
from requests_file import FileAdapter
25

36
from kepler_model.estimate.model_server_connector import unpack
47
from kepler_model.util.config import get_init_model_url
58
from kepler_model.util.loader import load_metadata
69
from kepler_model.util.train_types import ModelOutputType
710

11+
logger = logging.getLogger(__name__)
12+
813
failed_list = []
914

1015
FILTER_ITEM_DELIMIT = ";"
@@ -38,19 +43,21 @@ def valid_metrics(metrics, features):
3843
def is_valid_model(metrics, metadata, filters):
3944
if not valid_metrics(metrics, metadata["features"]):
4045
return False
46+
4147
for attrb, val in filters.items():
4248
if not hasattr(metadata, attrb) or getattr(metadata, attrb) is None:
43-
print("{} has no {}".format(metadata["model_name"], attrb))
49+
logger.warning(f"{metadata['model_name']} has no {attrb}")
4450
return False
45-
else:
46-
cmp_val = getattr(metadata, attrb)
47-
val = float(val)
48-
if attrb == "abs_max_corr": # higher is better
49-
valid = cmp_val >= val
50-
else: # lower is better
51-
valid = cmp_val <= val
52-
if not valid:
53-
return False
51+
52+
cmp_val = getattr(metadata, attrb)
53+
val = float(val)
54+
if attrb == "abs_max_corr": # higher is better
55+
valid = cmp_val >= val
56+
else: # lower is better
57+
valid = cmp_val <= val
58+
if not valid:
59+
return False
60+
5461
return True
5562

5663

@@ -60,21 +67,25 @@ def reset_failed_list():
6067

6168

6269
def get_achived_model(power_request):
63-
print("get archived model")
6470
global failed_list
6571
output_type_name = power_request.output_type
6672
if output_type_name in failed_list:
6773
return None
6874
output_type = ModelOutputType[power_request.output_type]
6975
url = get_init_model_url(power_request.energy_source, output_type_name)
7076
if url == "":
71-
print("no URL set for ", output_type_name, power_request.energy_source)
77+
logger.warning(f"no URL set for {output_type_name}, {power_request.energy_source}")
7278
return None
73-
print(f"try getting archieved model from URL: {url} for {output_type_name}")
74-
response = requests.get(url)
75-
print(response)
79+
logger.info(f"try getting archieved model from URL: {url} for {output_type_name}")
80+
81+
s = requests.Session()
82+
s.mount("file://", FileAdapter())
83+
response = s.get(url)
84+
logger.debug(f"response: {response}")
85+
7686
if response.status_code != 200:
7787
return None
88+
7889
output_path = unpack(power_request.energy_source, output_type, response, replace=False)
7990
if output_path is not None:
8091
metadata = load_metadata(output_path)
@@ -83,7 +94,8 @@ def get_achived_model(power_request):
8394
if not is_valid_model(power_request.metrics, metadata, filters):
8495
failed_list += [output_type_name]
8596
return None
86-
except:
87-
print("cannot validate the archived model")
97+
except Exception as e:
98+
logger.warning(f"cannot validate the archived model: {e}")
8899
return None
100+
89101
return output_path

src/kepler_model/estimate/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from kepler_model.estimate.model.model import load_downloaded_model
1414
from kepler_model.estimate.model_server_connector import is_model_server_enabled, make_request
1515
from kepler_model.train.profiler.node_type_index import NodeTypeSpec, discover_spec_values, get_machine_spec
16-
from kepler_model.util.config import CONFIG_PATH, SERVE_SOCKET, download_path, set_env_from_model_config, set_config_dir
16+
from kepler_model.util.config import CONFIG_PATH, SERVE_SOCKET, download_path, set_config_dir, set_env_from_model_config
1717
from kepler_model.util.loader import get_download_output_path, load_metadata
1818
from kepler_model.util.train_types import ModelOutputType, convert_enery_source, is_output_type_supported
1919

src/kepler_model/util/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
#
1313
#################################################
1414

15-
import os
1615
import logging
16+
import os
1717

1818
import requests
1919

@@ -165,7 +165,7 @@ def get_init_model_url(energy_source, output_type, model_topurl=model_topurl):
165165
for prefix in modelConfigPrefix:
166166
if get_energy_source(prefix) == energy_source:
167167
modelURL = get_init_url(prefix)
168-
logger.info("get init url", modelURL)
168+
logger.info(f"get init url: {modelURL}")
169169
url = get_url(
170170
feature_group=FeatureGroup.BPFOnly,
171171
output_type=ModelOutputType[output_type],

0 commit comments

Comments
 (0)