Skip to content

Commit 8b274d5

Browse files
committed
logging & update entityset tests
1 parent dd7fc95 commit 8b274d5

File tree

3 files changed

+130
-24
lines changed

3 files changed

+130
-24
lines changed

tests/test_entityset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import pytest
33

4-
from zephyr_ml import create_pidata_entityset, create_scada_entityset
4+
from zephyr_ml import _create_entityset
55

66

77
@pytest.fixture
@@ -118,7 +118,11 @@ def scada_dfs(base_dfs):
118118
})
119119
return {**base_dfs, 'scada': scada_df}
120120

121+
def create_pidata_entityset(pidata_dfs):
122+
return _create_entityset(pidata_dfs, es_type = "pidata")
121123

124+
def create_scada_entityset(scada_dfs):
125+
return _create_entityset(scada_dfs, es_type = "scada")
122126
def test_create_pidata_missing_entities(pidata_dfs):
123127
error_msg = 'Missing dataframes for entities notifications.'
124128

zephyr_ml/core.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from itertools import chain
1919
import logging
2020
import matplotlib.pyplot as plt
21+
from functools import wraps
2122

2223
DEFAULT_METRICS = [
2324
"sklearn.metrics.accuracy_score",
@@ -29,6 +30,63 @@
2930
]
3031

3132
LOGGER = logging.getLogger(__name__)
33+
34+
def guide(method):
35+
36+
@wraps(method)
37+
def guided_step(self, *method_args, **method_kwargs):
38+
expected_next_step = self.current_step + 1
39+
method_name = method.__name__
40+
if method_name in self.producer_to_step_map:
41+
actual_next_step = self.producer_to_step_map[method_name]
42+
if actual_next_step > expected_next_step:
43+
necessary_steps_str = self._get_necessary_steps(actual_next_step)
44+
LOGGER.error(f"Required steps have been skipped! Unable to run {method_name}. Please perform the following steps first {necessary_steps_str}")
45+
return
46+
elif actual_next_step < self.current_step:  # regressing to an earlier step: data becomes stale, so warn
47+
try:
48+
res = method(self, *method_args, **method_kwargs)
49+
LOGGER.warning(f"The last run step was {self.current_step}. The following methods will return stale data. Please perform the following steps in order to get up to date.")
50+
self.current_step = actual_next_step
51+
return res
52+
except Exception as e:
53+
LOGGER.error(f"{method_name} threw an exception", exc_info = e)
54+
return
55+
else:
56+
try:
57+
res = method(self, *method_args, **method_kwargs)
58+
self.current_step = actual_next_step
59+
60+
# do logging here
61+
return res
62+
except Exception as e:
63+
LOGGER.error(f"{method_name} threw an exception", exc_info = e)
64+
65+
elif method_name in self.getter_to_step_map:
66+
actual_next_step = self.getter_to_step_map[method_name]
67+
if actual_next_step > expected_next_step:
68+
try:
69+
res = method(self, *method_args, **method_kwargs)
70+
if res is None:
71+
LOGGER.error(f"Required steps have been skipped! {method_name} does not have a value to return. Please perform the following steps in order before running this method.")
72+
else:
73+
LOGGER.warning(f"This data may be stale. Please perform the following steps in order to ensure the response is up to date.")
74+
return res
75+
except Exception as e:
76+
LOGGER.error(f"{method_name} threw an exception", exc_info = e)
77+
return
78+
else:
79+
try:
80+
res = method(self, *method_args, **method_kwargs)
81+
return res
82+
except Exception as e:
83+
LOGGER.error(f"{method_name} threw an exception", exc_info = e)
84+
else:
85+
print(f"Method {method_name} does not need to be wrapped")
86+
87+
88+
return guided_step
89+
3290
class Zephyr:
3391

3492
def __init__(self):
@@ -42,14 +100,40 @@ def __init__(self):
42100
self.X_test = None
43101
self.y_train = None
44102
self.y_test = None
103+
self.is_fitted = None
45104
self.results = None
46105

106+
self.current_step = -1
107+
# tuple of 2 arrays: producers and attributes
108+
self.step_order = [
109+
([self.create_entityset, self.set_entityset], [self.get_entityset]),
110+
([self.set_labeling_function], [self.get_labeling_function]),
111+
([self.generate_label_times], [self.get_label_times]),
112+
([self.generate_feature_matrix_and_labels, self.set_feature_matrix_and_labels], [self.get_feature_matrix_and_labels]),
113+
([self.generate_train_test_split, self.set_train_test_split], [self.get_train_test_split]),
114+
([self.set_pipeline], [self.get_pipeline]),
115+
([self.fit], []),
116+
([self.predict, self.evaluate], [])
117+
]
118+
119+
self.producer_to_step_map = {}
120+
self.getter_to_step_map = {}
121+
for idx, (producers, getters) in enumerate(self.step_order):
122+
for prod in producers:
123+
self.producer_to_step_map[prod.__name__] = idx
124+
for get in getters:
125+
self.getter_to_step_map[get.__name__] = idx
126+
127+
def _get_necessary_steps(self, actual_step):
128+
pass
129+
47130
def get_entityset_types(self):
48131
"""
49-
Returns the supported entityset types (PI/SCADA) and the required dataframes and their columns
132+
Returns the supported entityset types (PI/SCADA/Vibrations) and the required dataframes and their columns
50133
"""
51134
return VALIDATE_DATA_FUNCTIONS.keys()
52135

136+
@guide
53137
def create_entityset(self, data_paths, es_type, new_kwargs_mapping=None):
54138
"""
55139
Generate an entityset
@@ -68,12 +152,7 @@ def create_entityset(self, data_paths, es_type, new_kwargs_mapping=None):
68152
self.entityset = entityset
69153
return self.entityset
70154

71-
def get_entityset(self):
72-
if self.entityset is None:
73-
raise ValueError("No entityset has been created or set in this instance.")
74-
75-
return self.entityset
76-
155+
@guide
77156
def set_entityset(self, entityset, es_type, new_kwargs_mapping=None):
78157
dfs = entityset.to_dictionary()
79158

@@ -82,9 +161,18 @@ def set_entityset(self, entityset, es_type, new_kwargs_mapping=None):
82161

83162
self.entityset = entityset
84163

164+
@guide
165+
def get_entityset(self):
166+
if self.entityset is None:
167+
raise ValueError("No entityset has been created or set in this instance.")
168+
169+
return self.entityset
170+
171+
85172
def get_predefined_labeling_functions(self):
86173
return get_labeling_functions()
87174

175+
@guide
88176
def set_labeling_function(self, name=None, func=None):
89177
print(f"labeling function name {name}")
90178
if name is not None:
@@ -103,7 +191,12 @@ def set_labeling_function(self, name=None, func=None):
103191
else:
104192
raise ValueError(f"Custom function is not callable")
105193
raise ValueError("No labeling function given.")
106-
194+
195+
@guide
196+
def get_labeling_function(self):
197+
return self.labeling_function
198+
199+
@guide
107200
def generate_label_times(
108201
self, num_samples=-1, subset=None, column_map={}, verbose=False, **kwargs
109202
):
@@ -143,12 +236,14 @@ def generate_label_times(
143236

144237
return label_times, meta
145238

146-
def plot_label_times(self):
147-
assert self.label_times is not None
148-
cp.label_times.plots.LabelPlots(self.label_times).distribution()
149-
150-
def generate_features(self, **kwargs):
239+
@guide
240+
def get_label_times(self, visualize = True):
241+
if visualize:
242+
cp.label_times.plots.LabelPlots(self.label_times).distribution()
243+
return self.label_times
151244

245+
@guide
246+
def generate_feature_matrix_and_labels(self, **kwargs):
152247
feature_matrix, features = ft.dfs(
153248
entityset=self.entityset, cutoff_time=self.label_times, **kwargs
154249
)
@@ -157,15 +252,18 @@ def generate_features(self, **kwargs):
157252
print(feature_matrix)
158253
return feature_matrix, features
159254

255+
@guide
160256
def get_feature_matrix_and_labels(self):
161257
return self.feature_matrix_and_labels
162258

259+
@guide
163260
def set_feature_matrix_and_labels(self, feature_matrix, label_col_name="label"):
164261
assert label_col_name in feature_matrix.columns
165262
self.feature_matrix_and_labels = self._clean_feature_matrix(
166263
feature_matrix, label_col_name=label_col_name
167264
)
168265

266+
@guide
169267
def generate_train_test_split(
170268
self,
171269
test_size=None,
@@ -191,25 +289,32 @@ def generate_train_test_split(
191289

192290
return
193291

292+
@guide
194293
def set_train_test_split(self, X_train, X_test, y_train, y_test):
195294
self.X_train = X_train
196295
self.X_test = X_test
197296
self.y_train = y_train
198297
self.y_test = y_test
199298

299+
@guide
200300
def get_train_test_split(self):
301+
if self.X_train is None or self.X_test is None or self.y_train is None or self.y_test is None:
302+
return None
201303
return self.X_train, self.X_test, self.y_train, self.y_test
202304

203305
def get_predefined_pipelines(self):
204306
pass
205307

308+
@guide
206309
def set_pipeline(self, pipeline, pipeline_hyperparameters=None):
207310
self.pipeline = self._get_mlpipeline(pipeline, pipeline_hyperparameters)
208311
self.pipeline_hyperparameters = pipeline_hyperparameters
209312

313+
@guide
210314
def get_pipeline(self):
211315
return self.pipeline
212-
316+
317+
@guide
213318
def fit(
214319
self, X=None, y=None, visual=False, **kwargs
215320
): # kwargs indicate the parameters of the current pipeline
@@ -228,6 +333,7 @@ def fit(
228333
if visual and outputs is not None:
229334
return dict(zip(visual_names, outputs))
230335

336+
@guide
231337
def predict(self, X=None, visual=False, **kwargs):
232338
if X is None:
233339
X = self.X_test
@@ -244,6 +350,7 @@ def predict(self, X=None, visual=False, **kwargs):
244350

245351
return outputs
246352

353+
@guide
247354
def evaluate(self, X=None, y=None, metrics=None, show_plots = True):
248355
if X is None:
249356
X = self.X_test
@@ -271,10 +378,6 @@ def evaluate(self, X=None, y=None, metrics=None, show_plots = True):
271378
return results
272379

273380

274-
def _validate_step(self, **kwargs):
275-
for key, value in kwargs:
276-
assert (value is not None, f"{key} has not been set or created")
277-
278381
def _clean_feature_matrix(self, feature_matrix, label_col_name="label"):
279382
labels = feature_matrix.pop(label_col_name)
280383

@@ -473,9 +576,11 @@ def _get_outputs_spec(self, default=True):
473576
obj.set_labeling_function(name="brake_pad_presence")
474577

475578
obj.generate_label_times(num_samples=35, gap="20d")
476-
obj.plot_label_times()
579+
obj.get_label_times()
580+
581+
obj.generate_train_test_split()
477582

478-
obj.generate_features(
583+
obj.generate_feature_matrix_and_labels(
479584
target_dataframe_name="turbines",
480585
cutoff_time_in_index=True,
481586
agg_primitives=["count", "sum", "max"],

zephyr_ml/entityset.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,6 @@ def _create_entityset(entities, es_type, new_kwargs_mapping=None):
249249

250250
validate_func = VALIDATE_DATA_FUNCTIONS[es_type]
251251
es_kwargs = validate_func(entities, new_kwargs_mapping)
252-
print(entities)
253-
print(es_type)
254-
print(es_kwargs)
255252

256253
# filter out stated logical types for missing columns
257254
for entity, df in entities.items():

0 commit comments

Comments
 (0)