@@ -78,7 +78,7 @@ def __init__(
7878 classifier_abbreviation = "LCPL" ,
7979 )
8080
81- def fit (self , X , y ):
81+ def fit (self , X , y , sample_weight = None ):
8282 """
8383 Fit a local classifier per level.
8484
@@ -90,14 +90,17 @@ def fit(self, X, y):
9090 converted into a sparse ``csc_matrix``.
9191 y : array-like of shape (n_samples, n_levels)
9292 The target values, i.e., hierarchical class labels for classification.
93+ sample_weight : array-like of shape (n_samples,), default=None
94+ Array of weights that are assigned to individual samples.
95+ If not provided, then each sample is given unit weight.
9396
9497 Returns
9598 -------
9699 self : object
97100 Fitted estimator.
98101 """
99102 # Execute common methods necessary before fitting
100- super ()._pre_fit (X , y )
103+ super ()._pre_fit (X , y , sample_weight )
101104
102105 # Fit local classifiers in DAG
103106 super ().fit (X , y )
@@ -232,17 +235,22 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
232235 def _fit_classifier (self , level , separator ):
233236 classifier = self .local_classifiers_ [level ]
234237
235- X , y = self ._remove_empty_leaves (separator , self .X_ , self .y_ [:, level ])
238+ X , y , sample_weight = self ._remove_empty_leaves (
239+ separator , self .X_ , self .y_ [:, level ], self .sample_weight_
240+ )
236241
237242 unique_y = np .unique (y )
238243 if len (unique_y ) == 1 and self .replace_classifiers :
239244 classifier = ConstantClassifier ()
240- classifier .fit (X , y )
245+ classifier .fit (X , y , sample_weight )
241246 return classifier
242247
243248 @staticmethod
244- def _remove_empty_leaves (separator , X , y ):
249+ def _remove_empty_leaves (separator , X , y , sample_weight ):
245250 # Detect rows where leaves are not empty
246251 leaves = np .array ([str (i ).split (separator )[- 1 ] for i in y ])
247252 mask = leaves != ""
248- return X [mask ], y [mask ]
253+ X = X [mask ]
254+ y = y [mask ]
255+ sample_weight = sample_weight [mask ] if sample_weight is not None else None
256+ return X , y , sample_weight
0 commit comments