1+ import warnings
12import numpy as np
23
34from sklearn import clone
1718from sklearn .linear_model ._base import LinearClassifierMixin
1819from sklearn .utils .multiclass import check_classification_targets , type_of_target
1920
20- from skmatter .preprocessing import KernelNormalizer
21+ from skmatter .preprocessing import KernelNormalizer , StandardFlexibleScaler
2122from skmatter .utils import check_cl_fit
2223from skmatter .decomposition import _BaseKPCov
2324
@@ -99,6 +100,9 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
99100 constructed, with ``sklearn.linear_model.LogisticRegression()`` models used for each
100101 label.
101102
103+ scale_z : bool, default=False
104+ Whether to scale Z prior to eigendecomposition.
105+
102106 kernel : {"linear", "poly", "rbf", "sigmoid", "precomputed"} or callable, default="linear"
103107 Kernel.
104108
@@ -129,6 +133,14 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
129133 and for matrix inversions.
130134 Must be of range [0.0, infinity).
131135
136+ z_mean_tol : float, default=1e-12
137+ Tolerance for the column means of Z.
138+ Must be of range [0.0, infinity).
139+
140+ z_var_tol : float, default=1.5
141+ Tolerance for the column variances of Z.
142+ Must be of range [0.0, infinity).
143+
132144 n_jobs : int, default=None
133145 The number of parallel jobs to run.
134146 :obj:`None` means 1 unless in a :obj:`joblib.parallel_backend` context.
@@ -185,14 +197,17 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
185197 The data used to fit the model. This attribute is used to build kernels
186198 from new data.
187199
200+ scale_z : bool
201+ Whether Z is being scaled prior to eigendecomposition.
202+
188203 Examples
189204 --------
190205 >>> import numpy as np
191206 >>> from skmatter.decomposition import KernelPCovC
192207 >>> from sklearn.preprocessing import StandardScaler
193208 >>> X = np.array([[-2, 3, -1, 0], [2, 0, -3, 1], [3, 0, -1, 3], [2, -2, 1, 0]])
194209 >>> X = StandardScaler().fit_transform(X)
195- >>> Y = np.array([[2], [0], [1], [2] ])
210+ >>> Y = np.array([2, 0, 1, 2 ])
196211 >>> kpcovc = KernelPCovC(
197212 ... mixing=0.1,
198213 ... n_components=2,
@@ -218,6 +233,7 @@ def __init__(
218233 n_components = None ,
219234 svd_solver = "auto" ,
220235 classifier = None ,
236+ scale_z = False ,
221237 kernel = "linear" ,
222238 gamma = None ,
223239 degree = 3 ,
@@ -226,6 +242,8 @@ def __init__(
226242 center = False ,
227243 fit_inverse_transform = False ,
228244 tol = 1e-12 ,
245+ z_mean_tol = 1e-12 ,
246+ z_var_tol = 1.5 ,
229247 n_jobs = None ,
230248 iterated_power = "auto" ,
231249 random_state = None ,
@@ -247,6 +265,9 @@ def __init__(
247265 fit_inverse_transform = fit_inverse_transform ,
248266 )
249267 self .classifier = classifier
268+ self .scale_z = scale_z
269+ self .z_mean_tol = z_mean_tol
270+ self .z_var_tol = z_var_tol
250271
251272 def fit (self , X , Y , W = None ):
252273 r"""Fit the model with X and Y.
@@ -368,6 +389,25 @@ def fit(self, X, Y, W=None):
368389 W = self .z_classifier_ .coef_ .T
369390
370391 Z = K @ W
392+ if self .scale_z :
393+ Z = StandardFlexibleScaler ().fit_transform (Z )
394+
395+ z_means_ = np .mean (Z , axis = 0 )
396+ z_vars_ = np .var (Z , axis = 0 )
397+
398+ if np .max (np .abs (z_means_ )) > self .z_mean_tol :
399+ warnings .warn (
400+ "This class does not automatically center Z, and the column means "
401+ "of Z are greater than the supplied tolerance. We recommend scaling "
402+ "Z (and the weights) by setting `scale_z=True`."
403+ )
404+
405+ if np .max (z_vars_ ) > self .z_var_tol :
406+ warnings .warn (
407+ "This class does not automatically scale Z, and the column variances "
408+ "of Z are greater than the supplied tolerance. We recommend scaling "
409+ "Z (and the weights) by setting `scale_z=True`."
410+ )
371411
372412 self ._fit (K , Z , W )
373413
0 commit comments