Skip to content

Commit e008a65

Browse files
author
Christian Jorgensen
committed
Adding example
1 parent 02a9edf commit e008a65

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
"""
5+
Multioutput PCovC
6+
=================
7+
"""
8+
# %%
9+
#
10+
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
14+
from sklearn.datasets import load_digits
15+
from sklearn.preprocessing import StandardScaler
16+
from sklearn.decomposition import PCA
17+
from sklearn.linear_model import LogisticRegressionCV
18+
from sklearn.multioutput import MultiOutputClassifier
19+
20+
from skmatter.decomposition import PCovC
21+
22+
plt.rcParams["image.cmap"] = "tab10"
23+
plt.rcParams["scatter.edgecolors"] = "k"
24+
# %%
25+
#
26+
#
27+
X, y = load_digits(return_X_y=True)
28+
x_scaler = StandardScaler()
29+
X_scaled = StandardScaler().fit_transform(X)
30+
31+
np.unique(y)
32+
# %%
33+
# Let's begin by trying to make a PCovC map to separate the digits.
34+
# This is a one-label, ten-class classification problem.
35+
pca = PCA(n_components=2)
36+
T_pca = pca.fit_transform(X_scaled, y)
37+
38+
pcovc = PCovC(n_components=2, mixing=0.5)
39+
T_pcovc = pcovc.fit_transform(X_scaled, y)
40+
41+
fig, axs = plt.subplots(1, 2, figsize=(10, 6))
42+
43+
scat_pca = axs[0].scatter(T_pca[:, 0], T_pca[:, 1], c=y)
44+
scat_pcovc = axs[1].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=y)
45+
fig.colorbar(scat_pca, ax=axs, orientation="horizontal")
46+
47+
# %%
48+
# Next, let's try a two-label classification problem, with both labels
49+
# being binary classification tasks.
50+
51+
is_even = (y % 2).reshape(-1, 1)
52+
is_less_than_five = (y < 5).reshape(-1, 1)
53+
54+
y2 = np.hstack([is_even, is_less_than_five])
55+
y2.shape
56+
# %%
57+
# Here, we can build a map that considers both of these labels simultaneously.
58+
59+
clf = MultiOutputClassifier(estimator=LogisticRegressionCV())
60+
pcovc = PCovC(n_components=2, mixing=0.5, classifier=clf)
61+
62+
T_pcovc = pcovc.fit_transform(X_scaled, y2)
63+
64+
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
65+
cmap1 = "Set1"
66+
cmap2 = "Set2"
67+
cmap3 = "tab10"
68+
69+
labels_list = [["Even", "Odd"], [">= 5", "< 5"]]
70+
71+
for i, c, cmap in zip(range(3), [is_even, is_less_than_five, y], [cmap1, cmap2, cmap3]):
72+
73+
scat_pca = axs[0, i].scatter(T_pca[:, 0], T_pca[:, 1], c=c, cmap=cmap)
74+
axs[1, i].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=c, cmap=cmap)
75+
76+
if i == 0 or i == 1:
77+
handles, _ = scat_pca.legend_elements()
78+
labels = labels_list[i]
79+
axs[0, i].legend(handles, labels)
80+
print(labels)
81+
print(i)
82+
print(handles)
83+
84+
85+
axs[0, 0].set_title("Even/Odd")
86+
axs[0, 1].set_title("Greater/Less than 5")
87+
axs[0, 2].set_title("Digit")
88+
89+
axs[0, 0].set_ylabel("PCA")
90+
axs[1, 0].set_ylabel("PCovC")
91+
fig.colorbar(scat_pca, ax=axs, orientation="horizontal")
92+
# %%
93+
# Let's try a more complicated example:
94+
95+
num_holes = np.array(
96+
[0 if i in [1, 2, 3, 5, 7] else 1 if i in [0, 4, 6, 9] else 2 for i in y]
97+
).reshape(-1, 1)
98+
99+
y3 = np.hstack([is_even, num_holes])
100+
# %%
101+
# Now, we have a two-label classification
102+
# problem, with one binary label and one label with three
103+
# possible classes
104+
clf = MultiOutputClassifier(estimator=LogisticRegressionCV())
105+
pcovc = PCovC(n_components=2, mixing=0.5, classifier=clf)
106+
107+
T_pcovc = pcovc.fit_transform(X_scaled, y3)
108+
109+
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
110+
cmap1 = "Set1"
111+
cmap2 = "Set3"
112+
cmap3 = "tab10"
113+
114+
labels_list = [["Even", "Odd"], ["0", "1", "2"]]
115+
116+
for i, c, cmap in zip(range(3), [is_even, num_holes, y], [cmap1, cmap2, cmap3]):
117+
118+
scat_pca = axs[0, i].scatter(T_pca[:, 0], T_pca[:, 1], c=c, cmap=cmap)
119+
axs[1, i].scatter(T_pcovc[:, 0], T_pcovc[:, 1], c=c, cmap=cmap)
120+
121+
if i == 0 or i == 1:
122+
handles, _ = scat_pca.legend_elements()
123+
labels = labels_list[i]
124+
axs[0, i].legend(handles, labels)
125+
print(labels)
126+
print(i)
127+
print(handles)
128+
129+
130+
axs[0, 0].set_title("Even/Odd")
131+
axs[0, 1].set_title("Number of Holes")
132+
axs[0, 2].set_title("Digit")
133+
134+
axs[0, 0].set_ylabel("PCA")
135+
axs[1, 0].set_ylabel("PCovC")
136+
fig.colorbar(scat_pca, ax=axs, orientation="horizontal")

0 commit comments

Comments
 (0)