Skip to content

Commit cb1c07c

Browse files
authored
Add who dataset, examples and tests
Adding a publically-available example for the who dataset #149
1 parent cdb6195 commit cb1c07c

File tree

9 files changed

+3040
-2
lines changed

9 files changed

+3040
-2
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ sphinx_rtd_theme
99
tqdm
1010
traitlets>=5.0
1111
jinja2 < 3.1
12+
pandas
Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "2a30aece",
6+
"metadata": {},
7+
"source": [
8+
"# Feature Selection on the WHO Dataset"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": null,
14+
"id": "c6857fae",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import pandas as pd\n",
19+
"from matplotlib import pyplot as plt\n",
20+
"import numpy as np\n",
21+
"from tqdm.auto import tqdm\n",
22+
"from sklearn.kernel_ridge import KernelRidge\n",
23+
"from sklearn.model_selection import train_test_split\n",
24+
"from skcosmo.preprocessing import StandardFlexibleScaler\n",
25+
"from skcosmo.feature_selection import PCovFPS, PCovCUR, FPS, CUR\n",
26+
"from skcosmo.datasets import load_who_dataset"
27+
]
28+
},
29+
{
30+
"cell_type": "markdown",
31+
"id": "de5f2f17",
32+
"metadata": {},
33+
"source": [
34+
"## Load the Dataset"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": null,
40+
"id": "b816f2fb",
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"df = load_who_dataset()['data']\n",
45+
"df"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": null,
51+
"id": "472af9a2",
52+
"metadata": {
53+
"code_folding": []
54+
},
55+
"outputs": [],
56+
"source": [
57+
"columns = np.array([\n",
58+
" \"SP.POP.TOTL\",\n",
59+
" \"SH.TBS.INCD\",\n",
60+
" \"SH.IMM.MEAS\",\n",
61+
" \"SE.XPD.TOTL.GD.ZS\",\n",
62+
" \"SH.DYN.AIDS.ZS\",\n",
63+
" \"SH.IMM.IDPT\",\n",
64+
" \"SH.XPD.CHEX.GD.ZS\",\n",
65+
" \"SN.ITK.DEFC.ZS\",\n",
66+
" \"NY.GDP.PCAP.CD\",\n",
67+
"])\n",
68+
"\n",
69+
"column_names = np.array([\n",
70+
" \"Population\",\n",
71+
" \"Tuberculosis\",\n",
72+
" \"Immunization, measles\",\n",
73+
" \"Educ. Expenditure\",\n",
74+
" \"HIV\",\n",
75+
" \"Immunization, DPT\",\n",
76+
" \"Health Expenditure\",\n",
77+
" \"Undernourishment\",\n",
78+
" \"GDP per capita\",\n",
79+
"])\n",
80+
"\n",
81+
"columns = columns[[8, 4, 5, 6, 1, 0, 7, 3, 2]].tolist()\n",
82+
"column_names = column_names[[8, 4, 5, 6, 1, 0, 7, 3, 2]].tolist()"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"id": "a06715d8",
89+
"metadata": {
90+
"code_folding": []
91+
},
92+
"outputs": [],
93+
"source": [
94+
"X_raw = np.array(df[columns]) \n",
95+
"\n",
96+
"# We are taking the logarithm of the population and GDP to avoid extreme distributions\n",
97+
"log_scaled = ['SP.POP.TOTL', 'NY.GDP.PCAP.CD']\n",
98+
"for ls in log_scaled:\n",
99+
" print(X_raw[:, columns.index(ls)].min(), X_raw[:, columns.index(ls)].max())\n",
100+
" if ls in columns:\n",
101+
" X_raw[:, columns.index(ls)] = np.log10(\n",
102+
" X_raw[:, columns.index(ls)]\n",
103+
" )\n",
104+
"y_raw = np.array(df[\"SP.DYN.LE00.IN\"]) # [np.where(df['Year']==2000)[0]])\n",
105+
"y_raw = y_raw.reshape(-1, 1)\n",
106+
"X_raw.shape"
107+
]
108+
},
109+
{
110+
"cell_type": "markdown",
111+
"id": "f8cccebd",
112+
"metadata": {},
113+
"source": [
114+
"## Scale and Center the Features and Targets"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"id": "43241e40",
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"x_scaler = StandardFlexibleScaler(column_wise=True)\n",
125+
"X = x_scaler.fit_transform(X_raw)\n",
126+
"\n",
127+
"y_scaler = StandardFlexibleScaler(column_wise=True)\n",
128+
"y = y_scaler.fit_transform(y_raw)\n",
129+
"\n",
130+
"n_components = 2\n",
131+
"\n",
132+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True, random_state=0)"
133+
]
134+
},
135+
{
136+
"cell_type": "markdown",
137+
"id": "e623dc38",
138+
"metadata": {},
139+
"source": [
140+
"## Provide an estimated target for the feature selector"
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": null,
146+
"id": "3d307bdc",
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
150+
"kernel_params = {\"kernel\": \"rbf\", \"gamma\": 0.08858667904100832}\n",
151+
"krr = KernelRidge(alpha=0.006158482110660267, **kernel_params)\n",
152+
"\n",
153+
"yp_train = krr.fit(X_train, y_train).predict(X_train)"
154+
]
155+
},
156+
{
157+
"cell_type": "markdown",
158+
"id": "bb6adcbb",
159+
"metadata": {},
160+
"source": [
161+
"## Compute the Selections for Each Selector Type"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": null,
167+
"id": "73b012f9",
168+
"metadata": {},
169+
"outputs": [],
170+
"source": [
171+
"n_select = X.shape[1]"
172+
]
173+
},
174+
{
175+
"cell_type": "markdown",
176+
"id": "d54fd7e0",
177+
"metadata": {},
178+
"source": [
179+
"### PCov-CUR"
180+
]
181+
},
182+
{
183+
"cell_type": "code",
184+
"execution_count": null,
185+
"id": "40469566",
186+
"metadata": {
187+
"scrolled": false
188+
},
189+
"outputs": [],
190+
"source": [
191+
"pcur = PCovCUR(n_to_select=n_select, progress_bar=True, mixing=0.0)\n",
192+
"pcur.fit(X_train, yp_train)"
193+
]
194+
},
195+
{
196+
"cell_type": "markdown",
197+
"id": "74feb992",
198+
"metadata": {},
199+
"source": [
200+
"### PCov-FPS"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": null,
206+
"id": "17eb69d7",
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"pfps = PCovFPS(n_to_select=n_select, progress_bar=True, mixing=0.0, initialize=pcur.selected_idx_[0])\n",
211+
"pfps.fit(X_train, yp_train)"
212+
]
213+
},
214+
{
215+
"cell_type": "markdown",
216+
"id": "2d7c1762",
217+
"metadata": {},
218+
"source": [
219+
"### CUR"
220+
]
221+
},
222+
{
223+
"cell_type": "code",
224+
"execution_count": null,
225+
"id": "ef80f649",
226+
"metadata": {},
227+
"outputs": [],
228+
"source": [
229+
"cur = CUR(n_to_select=n_select, progress_bar=True)\n",
230+
"cur.fit(X_train, y_train)"
231+
]
232+
},
233+
{
234+
"cell_type": "markdown",
235+
"id": "29536065",
236+
"metadata": {},
237+
"source": [
238+
"### FPS"
239+
]
240+
},
241+
{
242+
"cell_type": "code",
243+
"execution_count": null,
244+
"id": "e4c934cb",
245+
"metadata": {},
246+
"outputs": [],
247+
"source": [
248+
"fps = FPS(n_to_select=n_select, progress_bar=True, initialize=cur.selected_idx_[0])\n",
249+
"fps.fit(X_train, y_train)"
250+
]
251+
},
252+
{
253+
"cell_type": "markdown",
254+
"id": "275587cd",
255+
"metadata": {},
256+
"source": [
257+
"### (For Comparison) Recurisive Feature Addition"
258+
]
259+
},
260+
{
261+
"cell_type": "code",
262+
"execution_count": null,
263+
"id": "1e5510bf",
264+
"metadata": {},
265+
"outputs": [],
266+
"source": [
267+
"class RecursiveFeatureAddition:\n",
268+
" def __init__(self, n_to_select):\n",
269+
" self.n_to_select = n_to_select\n",
270+
" self.selected_idx_ = np.zeros(n_to_select, dtype=int)\n",
271+
" def fit(self, X, y):\n",
272+
" remaining = np.arange(X.shape[1])\n",
273+
" for n in range(self.n_to_select):\n",
274+
" errors = np.zeros(len(remaining))\n",
275+
" for i, pp in enumerate(remaining):\n",
276+
" krr.fit(\n",
277+
" X[:, [*self.selected_idx_[:n], pp]], y\n",
278+
" )\n",
279+
" errors[i] = krr.score(X[:, [*self.selected_idx_[:n], pp]], y)\n",
280+
" self.selected_idx_[n] = remaining[np.argmax(errors)]\n",
281+
" remaining = np.array(np.delete(remaining, np.argmax(errors)), dtype=int)\n",
282+
" return self\n",
283+
"rfa = RecursiveFeatureAddition(n_select).fit(X_train, y_train)"
284+
]
285+
},
286+
{
287+
"cell_type": "markdown",
288+
"id": "5975fde7",
289+
"metadata": {},
290+
"source": [
291+
"## Plot our Results"
292+
]
293+
},
294+
{
295+
"cell_type": "code",
296+
"execution_count": null,
297+
"id": "a6b7a203",
298+
"metadata": {
299+
"code_folding": [],
300+
"scrolled": false
301+
},
302+
"outputs": [],
303+
"source": [
304+
"fig, axes = plt.subplots(2, 1,figsize=(3.75, 5), gridspec_kw=dict(height_ratios=(1,1.5)), sharex=True, dpi=150)\n",
305+
"ns = np.arange(1, n_select, dtype=int)\n",
306+
"\n",
307+
"all_errors = {}\n",
308+
"for selector, color, linestyle, label in zip(\n",
309+
" [cur, fps, pcur, pfps, rfa],\n",
310+
" [\"red\", \"lightcoral\", \"blue\", \"dodgerblue\", \"black\"],\n",
311+
" [\"solid\", \"solid\", \"solid\", \"solid\", \"dashed\"],\n",
312+
" [\n",
313+
" \"CUR\",\n",
314+
" \"FPS\",\n",
315+
" \"PCov-CUR\\n\"+r\"($\\alpha=0.0$)\",\n",
316+
" \"PCov-FPS\\n\"+r\"($\\alpha=0.0$)\",\n",
317+
" \"Recursive\\nFeature\\nSelection\",\n",
318+
" ], \n",
319+
"):\n",
320+
" if label not in all_errors:\n",
321+
" errors = np.zeros(len(ns))\n",
322+
" for i, n in enumerate(ns):\n",
323+
" krr.fit(X_train[:, selector.selected_idx_[:n]], y_train)\n",
324+
" errors[i] = krr.score(X_test[:, selector.selected_idx_[:n]], y_test)\n",
325+
" all_errors[label] = errors\n",
326+
" axes[0].plot(ns, all_errors[label], c=color, label=label, linestyle=linestyle)\n",
327+
" axes[1].plot(ns, selector.selected_idx_[:max(ns)], c=color, marker='.', linestyle=linestyle)\n",
328+
"\n",
329+
"axes[1].set_xlabel(r\"$n_{select}$\")\n",
330+
"axes[1].set_xticks(range(1, n_select))\n",
331+
"axes[0].set_ylabel(r\"R$^2$\")\n",
332+
"axes[1].set_yticks(np.arange(X.shape[1]))\n",
333+
"axes[1].set_yticklabels(column_names, rotation=30, fontsize=10)\n",
334+
"axes[0].legend(ncol=2, fontsize=8, bbox_to_anchor=(0.5, 1.0), loc='lower center')\n",
335+
"axes[1].invert_yaxis()\n",
336+
"axes[1].grid(axis='y', alpha=0.5)\n",
337+
"plt.tight_layout()\n",
338+
"plt.show()"
339+
]
340+
}
341+
],
342+
"metadata": {
343+
"kernelspec": {
344+
"display_name": "Python 3 (ipykernel)",
345+
"language": "python",
346+
"name": "python3"
347+
},
348+
"language_info": {
349+
"codemirror_mode": {
350+
"name": "ipython",
351+
"version": 3
352+
},
353+
"file_extension": ".py",
354+
"mimetype": "text/x-python",
355+
"name": "python",
356+
"nbconvert_exporter": "python",
357+
"pygments_lexer": "ipython3",
358+
"version": "3.10.4"
359+
},
360+
"toc": {
361+
"base_numbering": 1,
362+
"nav_menu": {},
363+
"number_sections": true,
364+
"sideBar": true,
365+
"skip_h1_title": false,
366+
"title_cell": "Table of Contents",
367+
"title_sidebar": "Contents",
368+
"toc_cell": false,
369+
"toc_position": {},
370+
"toc_section_display": true,
371+
"toc_window_display": false
372+
}
373+
},
374+
"nbformat": 4,
375+
"nbformat_minor": 5
376+
}

0 commit comments

Comments
 (0)