66 get_labeling_functions_map ,
77 LABELING_FUNCTIONS ,
88)
9+ from zephyr_ml .feature_engineering import process_signals
910import composeml as cp
1011from inspect import getfullargspec
1112import featuretools as ft
1920import logging
2021import matplotlib .pyplot as plt
2122from functools import wraps
22-
23+ import inspect
2324DEFAULT_METRICS = [
2425 "sklearn.metrics.accuracy_score" ,
2526 "sklearn.metrics.precision_score" ,
@@ -36,15 +37,18 @@ class GuideHandler:
3637
3738 def __init__ (self , producers_and_getters , set_methods ):
3839 self .cur_term = 0
40+ self .current_step = - 1
3941 self .producers_and_getters = producers_and_getters
4042 self .set_methods = set_methods
4143
4244 self .producer_to_step_map = {}
4345 self .getter_to_step_map = {}
46+
4447 self .terms = []
45-
48+ self . skipped = []
4649 for idx , (producers , getters ) in enumerate (self .producers_and_getters ):
4750 self .terms .append (- 1 )
51+ self .skipped .append (False )
4852
4953 for prod in producers :
5054 self .producer_to_step_map [prod .__name__ ] = idx
@@ -98,6 +102,8 @@ def perform_producer_step(self, method, *method_args, **method_kwargs):
98102 def try_log_skipping_steps_warning (self , name , next_step ):
99103 steps_skipped = self .get_steps_in_between (self .current_step , next_step )
100104 if len (steps_skipped ) > 0 :
105+ for step in range (self .current_step + 1 , next_step ):
106+ self .skipped [step ] = True
101107 necc_steps = self .join_steps (steps_skipped )
102108 LOGGER .warning (f"Performing { name } . You are skipping the following steps:\n { necc_steps } " )
103109
@@ -115,7 +121,7 @@ def try_log_using_stale_warning(self, name, next_step):
115121
116122 def try_log_making_stale_warning (self , name , next_step ):
117123 next_next_step = next_step + 1
118- prod_steps = f"{ next_next_step } . { " or " .join (self .producers_and_getters [next_next_step ][0 ])} "
124+ prod_steps = f"{ next_next_step } . { ' or ' .join (self .producers_and_getters [next_next_step ][0 ])} "
119125 # add later set methods
120126 get_steps = self .join_steps (self .get_get_steps_in_between (next_step , self .current_step + 1 ))
121127
@@ -133,7 +139,7 @@ def try_log_inconsistent_warning(self, name, next_step):
133139 starting at or before { latest_up_to_date } " )
134140
135141 def log_get_inconsistent_warning (self , name , next_step ):
136- prod_steps = f"{ next_step } . { " or " .join (self .producers_and_getters [next_step ][0 ])} "
142+ prod_steps = f"{ next_step } . { ' or ' .join (self .producers_and_getters [next_step ][0 ])} "
137143 latest_up_to_date = self .get_last_up_to_date (next_step )
138144 LOGGER .warning (f"Unable to perform { name } because { prod_steps } has not been run yet. Run steps starting at or before { latest_up_to_date } " )
139145
@@ -186,6 +192,10 @@ def try_perform_stale_or_inconsistent_producer_step(self, method, *method_args,
186192 if self .terms [next_step - 1 ] == - 1 : #inconsistent
187193 self .try_log_inconsistent_warning (name , next_step )
188194 else :
195+ # need to include a case where performing using stale data that was skipped in current iteration
196+ # overwrite current iteration's ?
197+ # no not possible b/c if there is a current iteration after this step, it must have updated this step's iteration
198+ #
189199 self .try_log_using_stale_warning (name , next_step )
190200 res = self .perform_producer_step (method , * method_args , ** method_kwargs )
191201 return res
@@ -274,7 +284,6 @@ def __init__(self):
274284 # tuple of 2 arrays: producers and attributes
275285 self .step_order = [
276286 ([self .generate_entityset , self .set_entityset ], [self .get_entityset ]),
277- # ([self.set_labeling_function], [self.get_labeling_function]),
278287 ([self .generate_label_times , self .set_label_times ], [self .get_label_times ]),
279288 ([self .generate_feature_matrix_and_labels , self .set_feature_matrix_and_labels ], [self .get_feature_matrix_and_labels ]),
280289 ([self .generate_train_test_split , self .set_train_test_split ], [self .get_train_test_split ]),
@@ -291,10 +300,26 @@ def GET_ENTITYSET_TYPES(self):
291300 """
292301 Returns the supported entityset types (PI/SCADA/Vibrations) and the required dataframes and their columns
293302 """
294- return VALIDATE_DATA_FUNCTIONS .keys ()
303+ info_map = {}
304+ for es_type , val_fn in VALIDATE_DATA_FUNCTIONS .items ():
305+ info_map [es_type ] = {"obj" : es_type , "desc" : " " .join ((val_fn .__doc__ .split ()))}
306+
307+ return info_map
308+
    def GET_LABELING_FUNCTIONS(self):
        """Return the available labeling functions.

        Thin wrapper around the module-level ``get_labeling_functions`` helper;
        returns whatever mapping/collection that helper produces.
        """
        return get_labeling_functions()
311+
312+ def GET_EVALUATION_METRICS (self ):
313+ info_map = {}
314+ for metric in DEFAULT_METRICS :
315+ primitive = self ._get_ml_primitive (metric )
316+ info_map [metric ] = {"obj" : primitive , "desc" : primitive .metadata ["description" ] }
317+ return info_map
295318
296319 @guide
297- def generate_entityset (self , dfs , es_type , custom_kwargs_mapping = None ):
320+ def generate_entityset (self , dfs , es_type , custom_kwargs_mapping = None ,
321+ signal_dataframe_name = None , signal_column = None , signal_transformations = None ,
322+ signal_aggregations = None , signal_window_size = None , signal_replace_dataframe = False , ** sigpro_kwargs ):
298323 """
299324 Generate an entityset
300325
@@ -309,6 +334,16 @@ def generate_entityset(self, dfs, es_type, custom_kwargs_mapping=None):
309334 their relationships
310335 """
311336 entityset = _create_entityset (dfs , es_type , custom_kwargs_mapping )
337+
338+ #perform signal processing
339+ if signal_dataframe_name is not None and signal_column is not None :
340+ if signal_transformations is None :
341+ signal_transformations = []
342+ if signal_aggregations is None :
343+ signal_aggregations = []
344+ process_signals (entityset , signal_dataframe_name , signal_column , signal_transformations ,
345+ signal_aggregations , signal_window_size , signal_replace_dataframe , ** sigpro_kwargs )
346+
312347 self .entityset = entityset
313348 return self .entityset
314349
@@ -335,8 +370,7 @@ def get_entityset(self):
335370 return self .entityset
336371
337372
338- def GET_LABELING_FUNCTIONS (self ):
339- return get_labeling_functions ()
373+
340374
341375 # @guide
342376 # def set_labeling_function(self, name=None, func=None):
@@ -425,9 +459,37 @@ def get_label_times(self, visualize = True):
425459 return self .label_times
426460
427461 @guide
428- def generate_feature_matrix_and_labels (self , ** kwargs ):
462+ def generate_feature_matrix_and_labels (self , target_dataframe_name = None , instance_ids = None ,
463+ agg_primitives = None , trans_primitives = None , groupby_trans_primitives = None ,
464+ allowed_paths = None , max_depth = 2 , ignore_dataframes = None , ignore_columns = None ,
465+ primitive_options = None , seed_features = None ,
466+ drop_contains = None , drop_exact = None , where_primitives = None , max_features = - 1 ,
467+ cutoff_time_in_index = False , save_progress = None , features_only = False , training_window = None ,
468+ approximate = None , chunk_size = None , n_jobs = 1 , dask_kwargs = None , verbose = False , return_types = None ,
469+ progress_callback = None , include_cutoff_time = True ,
470+
471+ signal_dataframe_name = None , signal_column = None , signal_transformations = None ,
472+ signal_aggregations = None , signal_window_size = None , signal_replace_dataframe = False , ** sigpro_kwargs ):
473+
474+ # perform signal processing
475+ if signal_dataframe_name is not None and signal_column is not None :
476+ if signal_transformations is None :
477+ signal_transformations = []
478+ if signal_aggregations is None :
479+ signal_aggregations = []
480+ process_signals (self .entityset , signal_dataframe_name , signal_column , signal_transformations ,
481+ signal_aggregations , signal_window_size , signal_replace_dataframe , ** sigpro_kwargs )
482+
429483 feature_matrix , features = ft .dfs (
430- entityset = self .entityset , cutoff_time = self .label_times , ** kwargs
484+ entityset = self .entityset , cutoff_time = self .label_times ,
485+ target_dataframe_name = target_dataframe_name , instance_ids = instance_ids ,
486+ agg_primitives = agg_primitives , trans_primitives = trans_primitives , groupby_trans_primitives = groupby_trans_primitives ,
487+ allowed_paths = allowed_paths , max_depth = max_depth , ignore_dataframes = ignore_dataframes , ignore_columns = ignore_columns ,
488+ primitive_options = primitive_options , seed_features = seed_features ,
489+ drop_contains = drop_contains , drop_exact = drop_exact , where_primitives = where_primitives , max_features = max_features ,
490+ cutoff_time_in_index = cutoff_time_in_index , save_progress = save_progress , features_only = features_only , training_window = training_window ,
491+ approximate = approximate , chunk_size = chunk_size , n_jobs = n_jobs , dask_kwargs = dask_kwargs , verbose = verbose , return_types = return_types ,
492+ progress_callback = progress_callback , include_cutoff_time = include_cutoff_time ,
431493 )
432494 self .feature_matrix_and_labels = self ._clean_feature_matrix (feature_matrix )
433495 self .features = features
@@ -546,7 +608,8 @@ def predict(self, X=None, visual=False, **kwargs):
546608
547609 return outputs
548610
549-
611+
612+
550613
551614 @guide
552615 def evaluate (self , X = None , y = None ,metrics = None , global_args = None , local_args = None , global_mapping = None , local_mapping = None ):
@@ -656,6 +719,7 @@ def _get_outputs_spec(self, default=True):
656719
657720if __name__ == "__main__" :
658721 obj = Zephyr ()
722+ print (obj .GET_EVALUATION_METRICS ())
659723 alarms_df = pd .DataFrame (
660724 {
661725 "COD_ELEMENT" : [0 , 0 ],
@@ -791,40 +855,40 @@ def _get_outputs_spec(self, default=True):
791855 }
792856 )
793857
794- obj .create_entityset (
795- {
796- "alarms" : alarms_df ,
797- "stoppages" : stoppages_df ,
798- "notifications" : notifications_df ,
799- "work_orders" : work_orders_df ,
800- "turbines" : turbines_df ,
801- "pidata" : pidata_df ,
802- },
803- "pidata" ,
804- )
858+ # obj.create_entityset(
859+ # {
860+ # "alarms": alarms_df,
861+ # "stoppages": stoppages_df,
862+ # "notifications": notifications_df,
863+ # "work_orders": work_orders_df,
864+ # "turbines": turbines_df,
865+ # "pidata": pidata_df,
866+ # },
867+ # "pidata",
868+ # )
805869
806870 # obj.set_entityset(entityset_path = "/Users/raymondpan/zephyr/Zephyr-repo/brake_pad_es", es_type = 'scada')
807871
808872 # obj.set_labeling_function(name="brake_pad_presence")
809873
810- obj .generate_label_times (labeling_fn = "brake_pad_presence" , num_samples = 10 , gap = "20d" )
811- # print(obj.get_label_times())
874+ # obj.generate_label_times(labeling_fn="brake_pad_presence", num_samples=10, gap="20d")
875+ # # print(obj.get_label_times())
812876
813877
814- obj .generate_feature_matrix_and_labels (
815- target_dataframe_name = "turbines" ,
816- cutoff_time_in_index = True ,
817- agg_primitives = ["count" , "sum" , "max" ],
818- verbose = True
819- )
878+ # obj.generate_feature_matrix_and_labels(
879+ # target_dataframe_name="turbines",
880+ # cutoff_time_in_index=True,
881+ # agg_primitives=["count", "sum", "max"],
882+ # verbose = True
883+ # )
820884
821- print (obj .get_feature_matrix_and_labels )
885+ # print(obj.get_feature_matrix_and_labels)
822886
823- obj .generate_train_test_split ()
824- add_primitives_path (
825- path = "/Users/raymondpan/zephyr/Zephyr-repo/zephyr_ml/primitives/jsons"
826- )
827- obj .set_and_fit_pipeline ()
887+ # obj.generate_train_test_split()
888+ # add_primitives_path(
889+ # path="/Users/raymondpan/zephyr/Zephyr-repo/zephyr_ml/primitives/jsons"
890+ # )
891+ # obj.set_and_fit_pipeline()
828892
829893
830- obj .evaluate ()
894+ # obj.evaluate()
0 commit comments