Skip to content

Commit 39c5c06

Browse files
committed
FIX: legend inside plot in tuto + rename group into partition in tuto
1 parent 7672b2e commit 39c5c06

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

examples/mondrian/1-quickstart/plot_main-tutorial-mondrian-regression.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@
4444
np.random.seed(0)
4545
X = np.linspace(0, 10, n_points).reshape(-1, 1)
4646
group_size = n_points // 10
47-
groups_list = []
47+
partition_list = []
4848
for i in range(10):
49-
groups_list.append(np.array([i] * group_size))
50-
groups = np.concatenate(groups_list)
49+
partition_list.append(np.array([i] * group_size))
50+
partition = np.concatenate(partition_list)
5151

5252
noise_0_1 = np.random.normal(0, 0.1, group_size)
5353
noise_1_2 = np.random.normal(0, 0.5, group_size)
@@ -62,25 +62,25 @@
6262

6363
y = np.concatenate(
6464
[
65-
np.sin(X[groups == 0, 0] * 2) + noise_0_1,
66-
np.sin(X[groups == 1, 0] * 2) + noise_1_2,
67-
np.sin(X[groups == 2, 0] * 2) + noise_2_3,
68-
np.sin(X[groups == 3, 0] * 2) + noise_3_4,
69-
np.sin(X[groups == 4, 0] * 2) + noise_4_5,
70-
np.sin(X[groups == 5, 0] * 2) + noise_5_6,
71-
np.sin(X[groups == 6, 0] * 2) + noise_6_7,
72-
np.sin(X[groups == 7, 0] * 2) + noise_7_8,
73-
np.sin(X[groups == 8, 0] * 2) + noise_8_9,
74-
np.sin(X[groups == 9, 0] * 2) + noise_9_10,
65+
np.sin(X[partition == 0, 0] * 2) + noise_0_1,
66+
np.sin(X[partition == 1, 0] * 2) + noise_1_2,
67+
np.sin(X[partition == 2, 0] * 2) + noise_2_3,
68+
np.sin(X[partition == 3, 0] * 2) + noise_3_4,
69+
np.sin(X[partition == 4, 0] * 2) + noise_4_5,
70+
np.sin(X[partition == 5, 0] * 2) + noise_5_6,
71+
np.sin(X[partition == 6, 0] * 2) + noise_6_7,
72+
np.sin(X[partition == 7, 0] * 2) + noise_7_8,
73+
np.sin(X[partition == 8, 0] * 2) + noise_8_9,
74+
np.sin(X[partition == 9, 0] * 2) + noise_9_10,
7575
], axis=0
7676
)
7777

7878

7979
##############################################################################
80-
# We plot the dataset with the groups as colors.
80+
# We plot the dataset with the partition as colors.
8181

8282

83-
plt.scatter(X, y, c=groups)
83+
plt.scatter(X, y, c=partition)
8484
plt.show()
8585

8686

@@ -91,14 +91,14 @@
9191
X_train_temp, X_test, y_train_temp, y_test = train_test_split(
9292
X, y, test_size=0.2, random_state=0
9393
)
94-
groups_train_temp, groups_test, _, _ = train_test_split(
95-
groups, y, test_size=0.2, random_state=0
94+
partition_train_temp, partition_test, _, _ = train_test_split(
95+
partition, y, test_size=0.2, random_state=0
9696
)
9797
X_cal, X_train, y_cal, y_train = train_test_split(
9898
X_train_temp, y_train_temp, test_size=0.5, random_state=0
9999
)
100-
groups_cal, groups_train, _, _ = train_test_split(
101-
groups_train_temp, y_train_temp, test_size=0.5, random_state=0
100+
partition_cal, partition_train, _, _ = train_test_split(
101+
partition_train_temp, y_train_temp, test_size=0.5, random_state=0
102102
)
103103

104104

@@ -107,11 +107,11 @@
107107

108108

109109
f, ax = plt.subplots(1, 3, figsize=(15, 5))
110-
ax[0].scatter(X_train, y_train, c=groups_train)
110+
ax[0].scatter(X_train, y_train, c=partition_train)
111111
ax[0].set_title("Train set")
112-
ax[1].scatter(X_cal, y_cal, c=groups_cal)
112+
ax[1].scatter(X_cal, y_cal, c=partition_cal)
113113
ax[1].set_title("Calibration set")
114-
ax[2].scatter(X_test, y_test, c=groups_test)
114+
ax[2].scatter(X_test, y_test, c=partition_test)
115115
ax[2].set_title("Test set")
116116
plt.show()
117117

@@ -131,7 +131,7 @@
131131
mapie_regressor = MapieRegressor(rf, cv="prefit")
132132
mondrian_regressor = MondrianCP(MapieRegressor(rf, cv="prefit"))
133133
mapie_regressor.fit(X_cal, y_cal)
134-
mondrian_regressor.fit(X_cal, y_cal, groups=groups_cal)
134+
mondrian_regressor.fit(X_cal, y_cal, partition=partition_cal)
135135

136136

137137
##############################################################################
@@ -140,22 +140,23 @@
140140

141141
_, y_pss_split = mapie_regressor.predict(X_test, alpha=.1)
142142
_, y_pss_mondrian = mondrian_regressor.predict(
143-
X_test, groups=groups_test, alpha=.1
143+
X_test, partition=partition_test, alpha=.1
144144
)
145145

146146

147147
##############################################################################
148-
# 6. Compare the coverage by groups, plot both methods side by side.
148+
# 6. Compare the coverage by partition, plot both methods side by side.
149149

150150

151151
coverages = {}
152-
for group in np.unique(groups_test):
152+
for group in np.unique(partition_test):
153153
coverages[group] = {}
154154
coverages[group]["split"] = regression_coverage_score_v2(
155-
y_test[groups_test == group], y_pss_split[groups_test == group]
155+
y_test[partition_test == group], y_pss_split[partition_test == group]
156156
)
157157
coverages[group]["mondrian"] = regression_coverage_score_v2(
158-
y_test[groups_test == group], y_pss_mondrian[groups_test == group]
158+
y_test[partition_test == group],
159+
y_pss_mondrian[partition_test == group]
159160
)
160161

161162

@@ -178,4 +179,5 @@
178179
plt.hlines(0.9, -1, 21, label="90% coverage", color="black", linestyle="--")
179180
plt.ylabel("Coverage")
180181
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
182+
plt.tight_layout()
181183
plt.show()

0 commit comments

Comments
 (0)