@@ -108,7 +108,7 @@ def __init__(self):
108108 # tuple of 2 arrays: producers and attributes
109109 self .step_order = [
110110 ([self .create_entityset , self .set_entityset ], [self .get_entityset ]),
111- ([self .set_labeling_function ], [self .get_labeling_function ]),
111+ # ([self.set_labeling_function], [self.get_labeling_function]),
112112 ([self .generate_label_times ], [self .get_label_times ]),
113113 ([self .generate_feature_matrix_and_labels , self .set_feature_matrix_and_labels ], [self .get_feature_matrix_and_labels ]),
114114 ([self .generate_train_test_split , self .set_train_test_split ], [self .get_train_test_split ]),
@@ -186,37 +186,49 @@ def get_entityset(self):
186186 def get_predefined_labeling_functions (self ):
187187 return get_labeling_functions ()
188188
189- @guide
190- def set_labeling_function (self , name = None , func = None ):
191- if name is not None :
192- labeling_fn_map = get_labeling_functions_map ()
193- if name in labeling_fn_map :
194- self .labeling_function = labeling_fn_map [name ]
195- return
196- else :
197- raise ValueError (
198- f"Unrecognized name argument:{ name } . Call get_predefined_labeling_functions to view predefined labeling functions"
199- )
200- elif func is not None :
201- if callable (func ):
202- self .labeling_function = func
203- return
204- else :
205- raise ValueError (f"Custom function is not callable" )
206- raise ValueError ("No labeling function given." )
189+ # @guide
190+ # def set_labeling_function(self, name=None, func=None):
191+ # if name is not None:
192+ # labeling_fn_map = get_labeling_functions_map()
193+ # if name in labeling_fn_map:
194+ # self.labeling_function = labeling_fn_map[name]
195+ # return
196+ # else:
197+ # raise ValueError(
198+ # f"Unrecognized name argument:{name}. Call get_predefined_labeling_functions to view predefined labeling functions"
199+ # )
200+ # elif func is not None:
201+ # if callable(func):
202+ # self.labeling_function = func
203+ # return
204+ # else:
205+ # raise ValueError(f"Custom function is not callable")
206+ # raise ValueError("No labeling function given.")
207207
208- @guide
209- def get_labeling_function (self ):
210- return self .labeling_function
208+ # @guide
209+ # def get_labeling_function(self):
210+ # return self.labeling_function
211211
212212 @guide
213213 def generate_label_times (
214- self , num_samples = - 1 , subset = None , column_map = {}, verbose = False , ** kwargs
214+ self , labeling_fn , num_samples = - 1 , subset = None , column_map = {}, verbose = False , ** kwargs
215215 ):
216- assert self .entityset is not None
217- assert self .labeling_function is not None
216+ assert self .entityset is not None , "entityset has not been set"
217+
218+ if isinstance (labeling_fn , str ): # get predefined labeling function
219+ labeling_fn_map = get_labeling_functions_map ()
220+ if labeling_fn in labeling_fn_map :
221+ labeling_fn = labeling_fn_map [labeling_fn ]
222+ else :
223+ raise ValueError (
224+ f"Unrecognized name argument:{ labeling_fn } . Call get_predefined_labeling_functions to view predefined labeling functions"
225+ )
218226
219- labeling_function , df , meta = self .labeling_function (self .entityset , column_map )
227+
228+ assert callable (labeling_fn ), "Labeling function is not callable"
229+
230+
231+ labeling_function , df , meta = labeling_fn (self .entityset , column_map )
220232
221233 data = df
222234 if isinstance (subset , float ) or isinstance (subset , int ):
@@ -332,8 +344,6 @@ def set_and_fit_pipeline(
332344 if y is None :
333345 y = self .y_train
334346
335- print (X )
336- print (y )
337347
338348 if visual :
339349 outputs_spec , visual_names = self ._get_outputs_spec (False )
0 commit comments