@@ -40,6 +40,7 @@ def __init__(
4040 edge_list : str = None ,
4141 replace_classifiers : bool = True ,
4242 n_jobs : int = 1 ,
43+ bert : bool = False ,
4344 ):
4445 """
4546 Initialize a local classifier per parent node.
@@ -61,6 +62,8 @@ def __init__(
6162 n_jobs : int, default=1
6263 The number of jobs to run in parallel. Only :code:`fit` is parallelized.
6364 If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
65+ bert : bool, default=False
66+ If True, skip scikit-learn's checks and sample_weight passing for BERT.
6467 """
6568 super ().__init__ (
6669 local_classifier = local_classifier ,
@@ -69,6 +72,7 @@ def __init__(
6972 replace_classifiers = replace_classifiers ,
7073 n_jobs = n_jobs ,
7174 classifier_abbreviation = "LCPPN" ,
75+ bert = bert ,
7276 )
7377
7478 def fit (self , X , y , sample_weight = None ):
@@ -128,7 +132,10 @@ def predict(self, X):
128132 check_is_fitted (self )
129133
130134 # Input validation
131- X = check_array (X , accept_sparse = "csr" )
135+ if not self .bert :
136+ X = check_array (X , accept_sparse = "csr" )
137+ else :
138+ X = np .array (X )
132139
133140 # Initialize array that holds predictions
134141 y = np .empty ((X .shape [0 ], self .max_levels_ ), dtype = self .dtype_ )
@@ -203,7 +210,10 @@ def _fit_classifier(self, node):
203210 unique_y = np .unique (y )
204211 if len (unique_y ) == 1 and self .replace_classifiers :
205212 classifier = ConstantClassifier ()
206- classifier .fit (X , y , sample_weight )
213+ if not self .bert :
214+ classifier .fit (X , y , sample_weight )
215+ else :
216+ classifier .fit (X , y )
207217 return classifier
208218
209219 def _fit_digraph (self , local_mode : bool = False , use_joblib : bool = False ):
0 commit comments