Skip to content

Commit f56ea06

Browse files
committed
Implementing Rosy's suggestions to code
1 parent df8fa2e commit f56ea06

File tree

8 files changed

+56
-34
lines changed

8 files changed

+56
-34
lines changed

docs/src/references/pcovr_decomposition.rst renamed to docs/src/references/decomposition.rst

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Principal Covariates Regression (PCovR)
2-
================================================================
1+
Principal Covariates Regression (PCovR) and Classification (PCovC)
2+
==================================================================
33

44
.. _PCovR-api:
55

@@ -20,6 +20,26 @@ PCovR
2020
.. automethod:: inverse_transform
2121
.. automethod:: score
2222

23+
.. _PCovC-api:
24+
25+
PCovC
26+
-----
27+
28+
.. autoclass:: skmatter.decomposition.PCovC
29+
:show-inheritance:
30+
:special-members:
31+
32+
.. automethod:: fit
33+
34+
.. automethod:: _fit_feature_space
35+
.. automethod:: _fit_sample_space
36+
37+
.. automethod:: transform
38+
.. automethod:: predict
39+
.. automethod:: inverse_transform
40+
.. automethod:: decision_function
41+
.. automethod:: score
42+
2343
.. _KPCovR-api:
2444

2545
Kernel PCovR

docs/src/references/index.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ API Reference
1111
selection
1212
linear_models
1313
clustering
14-
pcovc_decomposition
15-
pcovr_decomposition
14+
decomposition
1615
metrics
1716
neighbors
1817
datasets

docs/src/references/pcovc_decomposition.rst

Lines changed: 0 additions & 22 deletions
This file was deleted.

examples/pcovc/README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
PCovC
2+
=====

src/skmatter/decomposition/_pcov.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
self.random_state = random_state
3939
self.whiten = whiten
4040

41-
def _fit_utils(self, X):
41+
def fit(self, X):
4242
"""Contains the common functionality for the PCovR and PCovC fit methods,
4343
but leaves the rest of the functionality to the subclass.
4444
"""

src/skmatter/decomposition/_pcovc.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,10 @@ class PCovC(LinearClassifierMixin, _BasePCov):
149149
The linear classifier passed for fitting.
150150
151151
z_classifier_ : estimator object
152-
The linear classifier fit between X and Y.
152+
The linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`.
153153
154154
classifier_ : estimator object
155-
The linear classifier fit between T and Y.
155+
The linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}`.
156156
157157
pxt_ : ndarray of size :math:`({n_{features}, n_{components}})`
158158
the projector, or weights, from the input space :math:`\mathbf{X}`
@@ -239,13 +239,28 @@ def fit(self, X, Y, W=None):
239239
W : numpy.ndarray, shape (n_features, n_properties)
240240
Classification weights, optional when classifier=`precomputed`. If
241241
not passed, it is assumed that the weights will be taken from a
242-
linear classifier fit between X and Y
242+
linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`
243+
244+
Notes
245+
-----
246+
Note the relationship between :math:`\mathbf{X}`, :math:`\mathbf{Y}`,
247+
:math:`\mathbf{Z}`, and :math:`\mathbf{W}`. The classification weights
248+
:math:`\mathbf{W}`, obtained through a linear classifier fit between
249+
:math:`\mathbf{X}` and :math:`\mathbf{Y}`, are used to compute:
250+
251+
.. math::
252+
\mathbf{Z} = \mathbf{X} \mathbf{W}
253+
254+
Next, :math:`\mathbf{Z}` is used in either `_fit_feature_space` or
255+
`_fit_sample_space` as our approximation of :math:`\mathbf{Y}`.
256+
Finally, we refit a classifier on :math:`\mathbf{T}` and :math:`\mathbf{Y}`
257+
to obtain :math:`\mathbf{P}_{XZ}` and :math:`\mathbf{P}_{TZ}`
243258
"""
244259
X, Y = validate_data(self, X, Y, y_numeric=False)
245260
check_classification_targets(Y)
246261
self.classes_ = np.unique(Y)
247262

248-
super()._fit_utils(X)
263+
super().fit(X)
249264

250265
compatible_classifiers = (
251266
LogisticRegression,

src/skmatter/decomposition/_pcovr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def fit(self, X, Y, W=None):
229229
passed, it is assumed that `W = np.linalg.lstsq(X, Y, self.tol)[0]`
230230
"""
231231
X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)
232-
super()._fit_utils(X)
232+
super().fit(X)
233233

234234
compatible_regressors = (LinearRegression, Ridge, RidgeCV)
235235

tests/test_pcovc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,10 @@ def test_prefit_classifier(self):
461461
def test_precomputed_classification(self):
462462
classifier = LogisticRegression()
463463
classifier.fit(self.X, self.Y)
464-
Yhat = classifier.predict(self.X)
464+
465465
W = classifier.coef_.T.reshape(self.X.shape[1], -1)
466466
pcovc1 = self.model(mixing=0.5, classifier="precomputed", n_components=1)
467-
pcovc1.fit(self.X, Yhat, W)
467+
pcovc1.fit(self.X, self.Y, W)
468468
t1 = pcovc1.transform(self.X)
469469

470470
pcovc2 = self.model(mixing=0.5, classifier=classifier, n_components=1)
@@ -473,6 +473,14 @@ def test_precomputed_classification(self):
473473

474474
self.assertTrue(np.linalg.norm(t1 - t2) < self.error_tol)
475475

476+
# Now check for match when W is not passed:
477+
pcovc3 = self.model(mixing=0.5, classifier="precomputed", n_components=1)
478+
pcovc3.fit(self.X, self.Y)
479+
t3 = pcovc3.transform(self.X)
480+
481+
self.assertTrue(np.linalg.norm(t3 - t2) < self.error_tol)
482+
self.assertTrue(np.linalg.norm(t3 - t1) < self.error_tol)
483+
476484
def test_classifier_modifications(self):
477485
classifier = LogisticRegression()
478486
pcovc = self.model(mixing=0.5, classifier=classifier)

0 commit comments

Comments
 (0)