Skip to content

Commit 2f5e925

Browse files
committed
Add static help and signal processing
1 parent 102229e commit 2f5e925

File tree

4 files changed

+116
-40
lines changed

4 files changed

+116
-40
lines changed

zephyr_ml/core.py

Lines changed: 102 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
get_labeling_functions_map,
77
LABELING_FUNCTIONS,
88
)
9+
from zephyr_ml.feature_engineering import process_signals
910
import composeml as cp
1011
from inspect import getfullargspec
1112
import featuretools as ft
@@ -19,7 +20,7 @@
1920
import logging
2021
import matplotlib.pyplot as plt
2122
from functools import wraps
22-
23+
import inspect
2324
DEFAULT_METRICS = [
2425
"sklearn.metrics.accuracy_score",
2526
"sklearn.metrics.precision_score",
@@ -36,15 +37,18 @@ class GuideHandler:
3637

3738
def __init__(self, producers_and_getters, set_methods):
3839
self.cur_term = 0
40+
self.current_step = -1
3941
self.producers_and_getters = producers_and_getters
4042
self.set_methods = set_methods
4143

4244
self.producer_to_step_map = {}
4345
self.getter_to_step_map = {}
46+
4447
self.terms = []
45-
48+
self.skipped = []
4649
for idx, (producers, getters) in enumerate(self.producers_and_getters):
4750
self.terms.append(-1)
51+
self.skipped.append(False)
4852

4953
for prod in producers:
5054
self.producer_to_step_map[prod.__name__] = idx
@@ -98,6 +102,8 @@ def perform_producer_step(self, method, *method_args, **method_kwargs):
98102
def try_log_skipping_steps_warning(self, name, next_step):
99103
steps_skipped = self.get_steps_in_between(self.current_step, next_step)
100104
if len(steps_skipped) > 0:
105+
for step in range(self.current_step + 1, next_step):
106+
self.skipped[step] = True
101107
necc_steps = self.join_steps(steps_skipped)
102108
LOGGER.warning(f"Performing {name}. You are skipping the following steps:\n{necc_steps}")
103109

@@ -115,7 +121,7 @@ def try_log_using_stale_warning(self, name, next_step):
115121

116122
def try_log_making_stale_warning(self, name, next_step):
117123
next_next_step = next_step + 1
118-
prod_steps = f"{next_next_step}. {" or ".join(self.producers_and_getters[next_next_step][0])}"
124+
prod_steps = f"{next_next_step}. {' or '.join(self.producers_and_getters[next_next_step][0])}"
119125
# add later set methods
120126
get_steps = self.join_steps(self.get_get_steps_in_between(next_step, self.current_step + 1))
121127

@@ -133,7 +139,7 @@ def try_log_inconsistent_warning(self, name, next_step):
133139
starting at or before {latest_up_to_date}")
134140

135141
def log_get_inconsistent_warning(self, name, next_step):
136-
prod_steps = f"{next_step}. {" or ".join(self.producers_and_getters[next_step][0])}"
142+
prod_steps = f"{next_step}. {' or '.join(self.producers_and_getters[next_step][0])}"
137143
latest_up_to_date = self.get_last_up_to_date(next_step)
138144
LOGGER.warning(f"Unable to perform {name} because {prod_steps} has not been run yet. Run steps starting at or before {latest_up_to_date} ")
139145

@@ -186,6 +192,10 @@ def try_perform_stale_or_inconsistent_producer_step(self, method, *method_args,
186192
if self.terms[next_step-1] == -1: #inconsistent
187193
self.try_log_inconsistent_warning(name, next_step)
188194
else:
195+
# need to include a case where performing using stale data that was skipped in current iteration
196+
# overwrite current iteration's ?
197+
# no not possible b/c if there is a current iteration after this step, it must have updated this step's iteration
198+
#
189199
self.try_log_using_stale_warning(name, next_step)
190200
res = self.perform_producer_step(method, *method_args, **method_kwargs)
191201
return res
@@ -274,7 +284,6 @@ def __init__(self):
274284
# tuple of 2 arrays: producers and attributes
275285
self.step_order = [
276286
([self.generate_entityset, self.set_entityset], [self.get_entityset]),
277-
# ([self.set_labeling_function], [self.get_labeling_function]),
278287
([self.generate_label_times, self.set_label_times], [self.get_label_times]),
279288
([self.generate_feature_matrix_and_labels, self.set_feature_matrix_and_labels], [self.get_feature_matrix_and_labels]),
280289
([self.generate_train_test_split, self.set_train_test_split], [self.get_train_test_split]),
def GET_ENTITYSET_TYPES(self):
    """
    Return the supported entityset types (PI/SCADA/Vibrations).

    Maps each entityset type name to a dict holding the type name
    ("obj") and a whitespace-normalized description ("desc") taken
    from the corresponding validation function's docstring.
    """
    return {
        es_type: {"obj": es_type, "desc": " ".join(val_fn.__doc__.split())}
        for es_type, val_fn in VALIDATE_DATA_FUNCTIONS.items()
    }
308+
309+
def GET_LABELING_FUNCTIONS(self):
    """Return the available labeling functions, delegating to the
    module-level get_labeling_functions helper."""
    labeling_fns = get_labeling_functions()
    return labeling_fns
311+
312+
def GET_EVALUATION_METRICS(self):
    """
    Return the default evaluation metrics.

    Maps each metric name in DEFAULT_METRICS to a dict holding the
    resolved ML primitive ("obj") and its description ("desc") read
    from the primitive's metadata.
    """
    # NOTE(review): assumes every resolved primitive's metadata carries a
    # "description" entry — confirm for any custom metric primitives.
    metrics_info = {}
    for metric_name in DEFAULT_METRICS:
        resolved = self._get_ml_primitive(metric_name)
        metrics_info[metric_name] = {
            "obj": resolved,
            "desc": resolved.metadata["description"],
        }
    return metrics_info
295318

296319
@guide
297-
def generate_entityset(self, dfs, es_type, custom_kwargs_mapping=None):
320+
def generate_entityset(self, dfs, es_type, custom_kwargs_mapping=None,
321+
signal_dataframe_name = None, signal_column = None, signal_transformations = None,
322+
signal_aggregations = None, signal_window_size = None, signal_replace_dataframe = False, **sigpro_kwargs):
298323
"""
299324
Generate an entityset
300325
@@ -309,6 +334,16 @@ def generate_entityset(self, dfs, es_type, custom_kwargs_mapping=None):
309334
their relationships
310335
"""
311336
entityset = _create_entityset(dfs, es_type, custom_kwargs_mapping)
337+
338+
#perform signal processing
339+
if signal_dataframe_name is not None and signal_column is not None:
340+
if signal_transformations is None:
341+
signal_transformations = []
342+
if signal_aggregations is None:
343+
signal_aggregations = []
344+
process_signals(entityset, signal_dataframe_name, signal_column, signal_transformations,
345+
signal_aggregations, signal_window_size, signal_replace_dataframe, **sigpro_kwargs)
346+
312347
self.entityset = entityset
313348
return self.entityset
314349

@@ -335,8 +370,7 @@ def get_entityset(self):
335370
return self.entityset
336371

337372

338-
def GET_LABELING_FUNCTIONS(self):
339-
return get_labeling_functions()
373+
340374

341375
# @guide
342376
# def set_labeling_function(self, name=None, func=None):
@@ -425,9 +459,37 @@ def get_label_times(self, visualize = True):
425459
return self.label_times
426460

427461
@guide
428-
def generate_feature_matrix_and_labels(self, **kwargs):
462+
def generate_feature_matrix_and_labels(self, target_dataframe_name = None, instance_ids = None,
463+
agg_primitives = None, trans_primitives = None, groupby_trans_primitives = None,
464+
allowed_paths = None, max_depth = 2, ignore_dataframes = None, ignore_columns=None,
465+
primitive_options=None, seed_features=None,
466+
drop_contains=None, drop_exact=None, where_primitives=None, max_features=-1,
467+
cutoff_time_in_index=False, save_progress=None, features_only=False, training_window=None,
468+
approximate=None, chunk_size=None, n_jobs=1, dask_kwargs=None, verbose=False, return_types=None,
469+
progress_callback=None, include_cutoff_time=True,
470+
471+
signal_dataframe_name = None, signal_column = None, signal_transformations = None,
472+
signal_aggregations = None, signal_window_size = None, signal_replace_dataframe = False, **sigpro_kwargs):
473+
474+
# perform signal processing
475+
if signal_dataframe_name is not None and signal_column is not None:
476+
if signal_transformations is None:
477+
signal_transformations = []
478+
if signal_aggregations is None:
479+
signal_aggregations = []
480+
process_signals(self.entityset, signal_dataframe_name, signal_column, signal_transformations,
481+
signal_aggregations, signal_window_size, signal_replace_dataframe, **sigpro_kwargs)
482+
429483
feature_matrix, features = ft.dfs(
430-
entityset=self.entityset, cutoff_time=self.label_times, **kwargs
484+
entityset=self.entityset, cutoff_time=self.label_times,
485+
target_dataframe_name = target_dataframe_name, instance_ids =instance_ids,
486+
agg_primitives = agg_primitives, trans_primitives = trans_primitives, groupby_trans_primitives = groupby_trans_primitives,
487+
allowed_paths = allowed_paths, max_depth = max_depth, ignore_dataframes = ignore_dataframes, ignore_columns=ignore_columns,
488+
primitive_options=primitive_options, seed_features=seed_features,
489+
drop_contains=drop_contains, drop_exact=drop_exact, where_primitives=where_primitives, max_features=max_features,
490+
cutoff_time_in_index=cutoff_time_in_index, save_progress=save_progress, features_only=features_only, training_window=training_window,
491+
approximate=approximate, chunk_size=chunk_size, n_jobs=n_jobs, dask_kwargs=dask_kwargs, verbose=verbose, return_types=return_types,
492+
progress_callback=progress_callback, include_cutoff_time=include_cutoff_time,
431493
)
432494
self.feature_matrix_and_labels = self._clean_feature_matrix(feature_matrix)
433495
self.features = features
@@ -546,7 +608,8 @@ def predict(self, X=None, visual=False, **kwargs):
546608

547609
return outputs
548610

549-
611+
612+
550613

551614
@guide
552615
def evaluate(self, X=None, y=None,metrics=None, global_args = None, local_args = None, global_mapping = None, local_mapping = None):
@@ -656,6 +719,7 @@ def _get_outputs_spec(self, default=True):
656719

657720
if __name__ == "__main__":
658721
obj = Zephyr()
722+
print(obj.GET_EVALUATION_METRICS())
659723
alarms_df = pd.DataFrame(
660724
{
661725
"COD_ELEMENT": [0, 0],
@@ -791,40 +855,40 @@ def _get_outputs_spec(self, default=True):
791855
}
792856
)
793857

794-
obj.create_entityset(
795-
{
796-
"alarms": alarms_df,
797-
"stoppages": stoppages_df,
798-
"notifications": notifications_df,
799-
"work_orders": work_orders_df,
800-
"turbines": turbines_df,
801-
"pidata": pidata_df,
802-
},
803-
"pidata",
804-
)
858+
# obj.create_entityset(
859+
# {
860+
# "alarms": alarms_df,
861+
# "stoppages": stoppages_df,
862+
# "notifications": notifications_df,
863+
# "work_orders": work_orders_df,
864+
# "turbines": turbines_df,
865+
# "pidata": pidata_df,
866+
# },
867+
# "pidata",
868+
# )
805869

806870
# obj.set_entityset(entityset_path = "/Users/raymondpan/zephyr/Zephyr-repo/brake_pad_es", es_type = 'scada')
807871

808872
# obj.set_labeling_function(name="brake_pad_presence")
809873

810-
obj.generate_label_times(labeling_fn="brake_pad_presence", num_samples=10, gap="20d")
811-
# print(obj.get_label_times())
874+
# obj.generate_label_times(labeling_fn="brake_pad_presence", num_samples=10, gap="20d")
875+
# # print(obj.get_label_times())
812876

813877

814-
obj.generate_feature_matrix_and_labels(
815-
target_dataframe_name="turbines",
816-
cutoff_time_in_index=True,
817-
agg_primitives=["count", "sum", "max"],
818-
verbose = True
819-
)
878+
# obj.generate_feature_matrix_and_labels(
879+
# target_dataframe_name="turbines",
880+
# cutoff_time_in_index=True,
881+
# agg_primitives=["count", "sum", "max"],
882+
# verbose = True
883+
# )
820884

821-
print(obj.get_feature_matrix_and_labels)
885+
# print(obj.get_feature_matrix_and_labels)
822886

823-
obj.generate_train_test_split()
824-
add_primitives_path(
825-
path="/Users/raymondpan/zephyr/Zephyr-repo/zephyr_ml/primitives/jsons"
826-
)
827-
obj.set_and_fit_pipeline()
887+
# obj.generate_train_test_split()
888+
# add_primitives_path(
889+
# path="/Users/raymondpan/zephyr/Zephyr-repo/zephyr_ml/primitives/jsons"
890+
# )
891+
# obj.set_and_fit_pipeline()
828892

829893

830-
obj.evaluate()
894+
# obj.evaluate()

zephyr_ml/entityset.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,28 @@ def _validate_data(dfs, es_type, es_kwargs):
208208

209209

210210
def validate_scada_data(dfs, new_kwargs_mapping=None):
    """
    SCADA data is signal data from the Original Equipment Manufacturer Supervisory Control
    And Data Acquisition (OEM-SCADA) system, a signal data source.
    """
    # The docstring above is surfaced verbatim to users (via __doc__), so it
    # doubles as the public description of this entityset type.
    # Resolve the entity kwargs for the "scada" type, then validate the
    # supplied dataframes against them before handing the kwargs back.
    scada_kwargs = get_mapped_kwargs("scada", new_kwargs_mapping)
    _validate_data(dfs, "scada", scada_kwargs)
    return scada_kwargs
214218

215219

216220
def validate_pidata_data(dfs, new_kwargs_mapping=None):
    """
    PI data is signal data from the operator's historical Plant Information (PI) system.
    """
    # The docstring above is surfaced verbatim to users (via __doc__), so it
    # doubles as the public description of this entityset type.
    # Resolve the entity kwargs for the "pidata" type, validate the supplied
    # dataframes against them, and return the resolved kwargs.
    pidata_kwargs = get_mapped_kwargs("pidata", new_kwargs_mapping)
    _validate_data(dfs, "pidata", pidata_kwargs)
    return pidata_kwargs
220227

221228

222229
def validate_vibrations_data(dfs, new_kwargs_mapping=None):
230+
"""
231+
Vibrations data is vibrations data collected on Planetary gearboxes in turbines.
232+
"""
223233
entities = ["vibrations"]
224234

225235
pidata_kwargs, scada_kwargs = {}, {}

zephyr_ml/feature_engineering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def process_signals(es, signal_dataframe_name, signal_column, transformations, aggregations,
5-
window_size, replace_dataframe=False, **kwargs):
5+
window_size = None, replace_dataframe=False, **kwargs):
66
'''
77
Process signals using SigPro.
88

zephyr_ml/labeling/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
def get_labeling_functions():
    """
    Map each labeling function's name to a dict holding the function
    object ("obj") and the first line of its docstring ("desc").
    """
    # assumes every entry in LABELING_FUNCTIONS has a docstring — a None
    # __doc__ would raise AttributeError here; TODO confirm
    functions = {}
    for labeling_fn in LABELING_FUNCTIONS:
        first_doc_line = labeling_fn.__doc__.split("\n")[0]
        functions[labeling_fn.__name__] = {"obj": labeling_fn, "desc": first_doc_line}
    return functions
3434

@@ -41,6 +41,8 @@ def get_labeling_functions_map():
4141
return functions
4242

4343

44+
45+
4446
def get_helper_functions():
4547
functions = {}
4648
for function in UTIL_FUNCTIONS:

0 commit comments

Comments
 (0)