Skip to content

Commit 609bbcc

Browse files
committed
combined set_labeling_function and generate_labeling_times
1 parent 2db555a commit 609bbcc

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

zephyr_ml/core.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(self):
108108
# tuple of 2 arrays: producers and attributes
109109
self.step_order = [
110110
([self.create_entityset, self.set_entityset], [self.get_entityset]),
111-
([self.set_labeling_function], [self.get_labeling_function]),
111+
# ([self.set_labeling_function], [self.get_labeling_function]),
112112
([self.generate_label_times], [self.get_label_times]),
113113
([self.generate_feature_matrix_and_labels, self.set_feature_matrix_and_labels], [self.get_feature_matrix_and_labels]),
114114
([self.generate_train_test_split, self.set_train_test_split], [self.get_train_test_split]),
@@ -186,37 +186,49 @@ def get_entityset(self):
186186
def get_predefined_labeling_functions(self):
187187
return get_labeling_functions()
188188

189-
@guide
190-
def set_labeling_function(self, name=None, func=None):
191-
if name is not None:
192-
labeling_fn_map = get_labeling_functions_map()
193-
if name in labeling_fn_map:
194-
self.labeling_function = labeling_fn_map[name]
195-
return
196-
else:
197-
raise ValueError(
198-
f"Unrecognized name argument:{name}. Call get_predefined_labeling_functions to view predefined labeling functions"
199-
)
200-
elif func is not None:
201-
if callable(func):
202-
self.labeling_function = func
203-
return
204-
else:
205-
raise ValueError(f"Custom function is not callable")
206-
raise ValueError("No labeling function given.")
189+
# @guide
190+
# def set_labeling_function(self, name=None, func=None):
191+
# if name is not None:
192+
# labeling_fn_map = get_labeling_functions_map()
193+
# if name in labeling_fn_map:
194+
# self.labeling_function = labeling_fn_map[name]
195+
# return
196+
# else:
197+
# raise ValueError(
198+
# f"Unrecognized name argument:{name}. Call get_predefined_labeling_functions to view predefined labeling functions"
199+
# )
200+
# elif func is not None:
201+
# if callable(func):
202+
# self.labeling_function = func
203+
# return
204+
# else:
205+
# raise ValueError(f"Custom function is not callable")
206+
# raise ValueError("No labeling function given.")
207207

208-
@guide
209-
def get_labeling_function(self):
210-
return self.labeling_function
208+
# @guide
209+
# def get_labeling_function(self):
210+
# return self.labeling_function
211211

212212
@guide
213213
def generate_label_times(
214-
self, num_samples=-1, subset=None, column_map={}, verbose=False, **kwargs
214+
self, labeling_fn, num_samples=-1, subset=None, column_map={}, verbose=False, **kwargs
215215
):
216-
assert self.entityset is not None
217-
assert self.labeling_function is not None
216+
assert self.entityset is not None, "entityset has not been set"
217+
218+
if isinstance(labeling_fn, str): # get predefined labeling function
219+
labeling_fn_map = get_labeling_functions_map()
220+
if labeling_fn in labeling_fn_map:
221+
labeling_fn = labeling_fn_map[labeling_fn]
222+
else:
223+
raise ValueError(
224+
f"Unrecognized name argument:{labeling_fn}. Call get_predefined_labeling_functions to view predefined labeling functions"
225+
)
218226

219-
labeling_function, df, meta = self.labeling_function(self.entityset, column_map)
227+
228+
assert callable(labeling_fn), "Labeling function is not callable"
229+
230+
231+
labeling_function, df, meta = labeling_fn(self.entityset, column_map)
220232

221233
data = df
222234
if isinstance(subset, float) or isinstance(subset, int):
@@ -332,8 +344,6 @@ def set_and_fit_pipeline(
332344
if y is None:
333345
y = self.y_train
334346

335-
print(X)
336-
print(y)
337347

338348
if visual:
339349
outputs_spec, visual_names = self._get_outputs_spec(False)

0 commit comments

Comments
 (0)