Skip to content

Commit 1b8159d

Browse files
add plot method and other minor edits (#206)
1 parent 7a2621a commit 1b8159d

File tree

1 file changed

+55
-22
lines changed

1 file changed

+55
-22
lines changed

src/trustyai/utils/extras/metrics_service.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import datetime as dt
66
import pandas as pd
77
import requests
8+
import matplotlib.pyplot as plt
89

910
from trustyai.utils.api.api import TrustyAIApi
1011

@@ -128,7 +129,7 @@ def print_name_mapping(self):
128129
f"{self.trusty_url}/info/names",
129130
json=payload,
130131
headers=self.headers,
131-
verify=True,
132+
verify=self.verify,
132133
timeout=timeout,
133134
)
134135
if response.status_code == 200:
@@ -182,27 +183,59 @@ def upload_data_to_model(self, model_name: str, json_file: str, timeout=5):
182183
return response.text
183184
raise RuntimeError(f"Error {response.status_code}: {response.reason}")
184185

185-
def get_metric_data(
186-
self, namespace: str, metric: str, time_interval: List[str], timeout=5
187-
):
186+
def get_metric_data(self, metric: str, time_interval: List[str], timeout=5):
188187
"""
189-
Retrives metric data for a specific range in time
188+
Retrives metric data for a specific range in time for each subcategory in data field
190189
"""
191-
params = {"query": f"{metric}{{namespace='{namespace}'}}{time_interval}"}
192-
response = requests.get(
193-
f"{self.thanos_url}/api/v1/query?",
194-
params=params,
195-
headers=self.headers,
196-
verify=self.verify,
197-
timeout=timeout,
198-
)
199-
if response.status_code == 200:
200-
data_dict = json.loads(response.text)["data"]["result"][0]["values"]
201-
metric_df = pd.DataFrame(data_dict, columns=["timestamp", metric])
202-
metric_df["timestamp"] = metric_df["timestamp"].apply(
203-
lambda epoch: dt.datetime.fromtimestamp(epoch).strftime(
204-
"%Y-%m-%d %H:%M:%S"
205-
)
190+
metric_df = pd.DataFrame()
191+
for subcategory in list(
192+
self.get_model_metadata()[0]["data"]["inputSchema"]["nameMapping"].values()
193+
):
194+
params = {
195+
"query": f"{metric}{{subcategory='{subcategory}'}}{time_interval}"
196+
}
197+
198+
response = requests.get(
199+
f"{self.thanos_url}/api/v1/query?",
200+
params=params,
201+
headers=self.headers,
202+
verify=self.verify,
203+
timeout=timeout,
206204
)
207-
return metric_df
208-
raise RuntimeError(f"Error {response.status_code}: {response.reason}")
205+
if response.status_code == 200:
206+
if "timestamp" in metric_df.columns:
207+
pass
208+
else:
209+
metric_df["timestamp"] = [
210+
item[0]
211+
for item in json.loads(response.text)["data"]["result"][0][
212+
"values"
213+
]
214+
]
215+
metric_df[subcategory] = [
216+
item[1]
217+
for item in json.loads(response.text)["data"]["result"][0]["values"]
218+
]
219+
else:
220+
raise RuntimeError(f"Error {response.status_code}: {response.reason}")
221+
222+
metric_df["timestamp"] = metric_df["timestamp"].apply(
223+
lambda epoch: dt.datetime.fromtimestamp(epoch).strftime("%Y-%m-%d %H:%M:%S")
224+
)
225+
return metric_df
226+
227+
@staticmethod
228+
def plot_metric(metric_df: pd.DataFrame, metric: str):
229+
"""
230+
Plots a line for each subcategory in the pandas DataFrame returned by get_metric_request
231+
with the timestamp on x-axis and specified metric on the y-axis
232+
"""
233+
plt.figure(figsize=(12, 5))
234+
for col in metric_df.columns[1:]:
235+
plt.plot(metric_df["timestamp"], metric_df[col])
236+
plt.xlabel("timestamp")
237+
plt.ylabel(metric)
238+
plt.xticks(rotation=45)
239+
plt.legend(metric_df.columns[1:])
240+
plt.tight_layout()
241+
plt.show()

0 commit comments

Comments
 (0)