1818from itertools import chain
1919import logging
2020import matplotlib .pyplot as plt
21+ from functools import wraps
2122
2223DEFAULT_METRICS = [
2324 "sklearn.metrics.accuracy_score" ,
2930]
3031
# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)
33+
def guide(method):
    """Decorate a workflow method so that it is checked against the step order.

    The wrapper reads three attributes from the instance it is called on:
    ``current_step`` (index of the last completed step) and the
    ``producer_to_step_map`` / ``getter_to_step_map`` dictionaries, which map
    method names to step indices.

    * Producers are refused (error logged, ``None`` returned) when they would
      skip ahead of the next expected step. Otherwise they run and, on
      success, ``current_step`` is set to their step. Going back to an
      earlier step additionally logs a stale-data warning.
    * Getters always run. When their step is ahead of ``current_step`` the
      result is reported as missing (``None``) or possibly stale.
    * Methods found in neither map are run unchanged.

    Exceptions raised by a mapped method are logged and swallowed; the
    wrapper then returns ``None``.

    Args:
        method: The unbound method to wrap.

    Returns:
        The wrapped method.
    """
    method_name = method.__name__

    def _run_logged(instance, args, kwargs):
        """Run ``method``; return ``(succeeded, result)`` and log any failure."""
        try:
            return True, method(instance, *args, **kwargs)
        except Exception as e:
            LOGGER.error(f"{method_name} threw an exception", exc_info=e)
            return False, None

    @wraps(method)
    def guided_step(self, *method_args, **method_kwargs):
        if method_name in self.producer_to_step_map:
            step = self.producer_to_step_map[method_name]
            if step > self.current_step + 1:
                necessary_steps_str = self._get_necessary_steps(step)
                LOGGER.error(
                    f"Required steps have been skipped! Unable to run {method_name}. "
                    f"Please perform the following steps first {necessary_steps_str}"
                )
                return None

            # Decided before the call, exactly like the original branch test.
            regressing = step < self.current_step
            succeeded, res = _run_logged(self, method_args, method_kwargs)
            if not succeeded:
                return None
            if regressing:
                LOGGER.warning(
                    f"The last run step was {self.current_step}. The following methods "
                    "will return stale data. Please perform the following steps in "
                    "order to get up to date."
                )
            self.current_step = step
            return res

        if method_name in self.getter_to_step_map:
            step = self.getter_to_step_map[method_name]
            # A getter is only backed by fresh data once its own step has been
            # completed, so compare against current_step itself. Comparing
            # against current_step + 1 let the getter of the very next step
            # pass without any missing/stale notice.
            ahead_of_progress = step > self.current_step
            succeeded, res = _run_logged(self, method_args, method_kwargs)
            if succeeded and ahead_of_progress:
                if res is None:
                    LOGGER.error(
                        f"Required steps have been skipped!. {method_name} does not "
                        "have a value to return. Please perform the following steps "
                        "in order before running this method."
                    )
                else:
                    LOGGER.warning(
                        "This data may be stale. Please perform the following steps "
                        "in order to ensure the response is up to date."
                    )
            return res

        # Not part of the guided workflow: still run the method. Previously the
        # wrapper only printed this notice and returned None without calling it.
        print(f"Method {method_name} does not need to be wrapped")
        return method(self, *method_args, **method_kwargs)

    return guided_step
89+
3290class Zephyr :
3391
3492 def __init__ (self ):
@@ -42,14 +100,40 @@ def __init__(self):
42100 self .X_test = None
43101 self .y_train = None
44102 self .y_test = None
103+ self .is_fitted = None
45104 self .results = None
46105
106+ self .current_step = - 1
107+ # tuple of 2 arrays: producers and attributes
108+ self .step_order = [
109+ ([self .create_entityset , self .set_entityset ], [self .get_entityset ]),
110+ ([self .set_labeling_function ], [self .get_labeling_function ]),
111+ ([self .generate_label_times ], [self .get_label_times ]),
112+ ([self .generate_feature_matrix_and_labels , self .set_feature_matrix_and_labels ], [self .get_feature_matrix_and_labels ]),
113+ ([self .generate_train_test_split , self .set_train_test_split ], [self .get_train_test_split ]),
114+ ([self .set_pipeline ], [self .get_pipeline ]),
115+ ([self .fit ], []),
116+ ([self .predict , self .evaluate ], [])
117+ ]
118+
119+ self .producer_to_step_map = {}
120+ self .getter_to_step_map = {}
121+ for idx , (producers , getters ) in enumerate (self .step_order ):
122+ for prod in producers :
123+ self .producer_to_step_map [prod .__name__ ] = idx
124+ for get in getters :
125+ self .getter_to_step_map [get .__name__ ] = idx
126+
127+ def _get_necessary_steps (self , actual_step ):
128+ pass
129+
    def get_entityset_types(self):
        """
        Returns the supported entityset types (PI/SCADA/Vibrations).

        Only the keys of ``VALIDATE_DATA_FUNCTIONS`` are returned; the required
        dataframes and their columns are not part of the return value.
        """
        return VALIDATE_DATA_FUNCTIONS.keys()
52135
136+ @guide
53137 def create_entityset (self , data_paths , es_type , new_kwargs_mapping = None ):
54138 """
55139 Generate an entityset
@@ -68,12 +152,7 @@ def create_entityset(self, data_paths, es_type, new_kwargs_mapping=None):
68152 self .entityset = entityset
69153 return self .entityset
70154
71- def get_entityset (self ):
72- if self .entityset is None :
73- raise ValueError ("No entityset has been created or set in this instance." )
74-
75- return self .entityset
76-
155+ @guide
77156 def set_entityset (self , entityset , es_type , new_kwargs_mapping = None ):
78157 dfs = entityset .to_dictionary ()
79158
@@ -82,9 +161,18 @@ def set_entityset(self, entityset, es_type, new_kwargs_mapping=None):
82161
83162 self .entityset = entityset
84163
164+ @guide
165+ def get_entityset (self ):
166+ if self .entityset is None :
167+ raise ValueError ("No entityset has been created or set in this instance." )
168+
169+ return self .entityset
170+
171+
    def get_predefined_labeling_functions(self):
        """Return whatever ``get_labeling_functions()`` provides, unchanged."""
        return get_labeling_functions()
87174
175+ @guide
88176 def set_labeling_function (self , name = None , func = None ):
89177 print (f"labeling fucntion name { name } " )
90178 if name is not None :
@@ -103,7 +191,12 @@ def set_labeling_function(self, name=None, func=None):
103191 else :
104192 raise ValueError (f"Custom function is not callable" )
105193 raise ValueError ("No labeling function given." )
106-
194+
    @guide
    def get_labeling_function(self):
        """Return the labeling function currently stored on this instance."""
        return self.labeling_function
198+
199+ @guide
107200 def generate_label_times (
108201 self , num_samples = - 1 , subset = None , column_map = {}, verbose = False , ** kwargs
109202 ):
@@ -143,12 +236,14 @@ def generate_label_times(
143236
144237 return label_times , meta
145238
146- def plot_label_times ( self ):
147- assert self . label_times is not None
148- cp . label_times . plots . LabelPlots ( self . label_times ). distribution ()
149-
150- def generate_features ( self , ** kwargs ):
239+ @ guide
240+ def get_label_times ( self , visualize = True ):
241+ if visualize :
242+ cp . label_times . plots . LabelPlots ( self . label_times ). distribution ()
243+ return self . label_times
151244
245+ @guide
246+ def generate_feature_matrix_and_labels (self , ** kwargs ):
152247 feature_matrix , features = ft .dfs (
153248 entityset = self .entityset , cutoff_time = self .label_times , ** kwargs
154249 )
@@ -157,15 +252,18 @@ def generate_features(self, **kwargs):
157252 print (feature_matrix )
158253 return feature_matrix , features
159254
    @guide
    def get_feature_matrix_and_labels(self):
        """Return the value stored in ``self.feature_matrix_and_labels``."""
        return self.feature_matrix_and_labels
162258
259+ @guide
163260 def set_feature_matrix_and_labels (self , feature_matrix , label_col_name = "label" ):
164261 assert label_col_name in feature_matrix .columns
165262 self .feature_matrix_and_labels = self ._clean_feature_matrix (
166263 feature_matrix , label_col_name = label_col_name
167264 )
168265
266+ @guide
169267 def generate_train_test_split (
170268 self ,
171269 test_size = None ,
@@ -191,25 +289,32 @@ def generate_train_test_split(
191289
192290 return
193291
292+ @guide
194293 def set_train_test_split (self , X_train , X_test , y_train , y_test ):
195294 self .X_train = X_train
196295 self .X_test = X_test
197296 self .y_train = y_train
198297 self .y_test = y_test
199298
299+ @guide
200300 def get_train_test_split (self ):
301+ if self .X_train is None or self .X_test is None or self .y_train is None or self .y_test is None :
302+ return None
201303 return self .X_train , self .X_test , self .y_train , self .y_test
202304
    def get_predefined_pipelines(self):
        """Placeholder: currently has no implementation and returns None."""
        pass
205307
308+ @guide
206309 def set_pipeline (self , pipeline , pipeline_hyperparameters = None ):
207310 self .pipeline = self ._get_mlpipeline (pipeline , pipeline_hyperparameters )
208311 self .pipeline_hyperparameters = pipeline_hyperparameters
209312
    @guide
    def get_pipeline(self):
        """Return the pipeline currently stored on this instance."""
        return self.pipeline
212-
316+
317+ @guide
213318 def fit (
214319 self , X = None , y = None , visual = False , ** kwargs
215320 ): # kwargs indicate the parameters of the current pipeline
@@ -228,6 +333,7 @@ def fit(
228333 if visual and outputs is not None :
229334 return dict (zip (visual_names , outputs ))
230335
336+ @guide
231337 def predict (self , X = None , visual = False , ** kwargs ):
232338 if X is None :
233339 X = self .X_test
@@ -244,6 +350,7 @@ def predict(self, X=None, visual=False, **kwargs):
244350
245351 return outputs
246352
353+ @guide
247354 def evaluate (self , X = None , y = None , metrics = None , show_plots = True ):
248355 if X is None :
249356 X = self .X_test
@@ -271,10 +378,6 @@ def evaluate(self, X=None, y=None, metrics=None, show_plots = True):
271378 return results
272379
273380
274- def _validate_step (self , ** kwargs ):
275- for key , value in kwargs :
276- assert (value is not None , f"{ key } has not been set or created" )
277-
278381 def _clean_feature_matrix (self , feature_matrix , label_col_name = "label" ):
279382 labels = feature_matrix .pop (label_col_name )
280383
@@ -473,9 +576,11 @@ def _get_outputs_spec(self, default=True):
473576 obj .set_labeling_function (name = "brake_pad_presence" )
474577
475578 obj .generate_label_times (num_samples = 35 , gap = "20d" )
476- obj .plot_label_times ()
579+ obj .get_label_times ()
580+
581+ obj .generate_train_test_split ()
477582
478- obj .generate_features (
583+ obj .generate_feature_matrix_and_labels (
479584 target_dataframe_name = "turbines" ,
480585 cutoff_time_in_index = True ,
481586 agg_primitives = ["count" , "sum" , "max" ],
0 commit comments