Skip to content

Commit a4313e4

Browse files
author
Sergey Feldman
committed
update autogluon as per best practice guidelines
1 parent 1d18442 commit a4313e4

File tree

2 files changed

+251
-7
lines changed

2 files changed

+251
-7
lines changed

03_autogluon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from utils import load_data
1010

1111

12-
SEC = 120
12+
SEC = 60 * 5  # passed as time_limits=SEC inside define_and_evaluate_autogluon_pipeline; presumably seconds (i.e. 5 minutes) -- TODO confirm unit against the AutoGluon docs
1313

1414

1515
def define_and_evaluate_autogluon_pipeline(X, y, random_state=0):
@@ -31,7 +31,7 @@ def define_and_evaluate_autogluon_pipeline(X, y, random_state=0):
3131
data_df_train,
3232
"y",
3333
time_limits=SEC,
34-
auto_stack=True,
34+
presets="best_quality",
3535
output_directory=".autogluon_temp",
3636
eval_metric=eval_metric,
3737
problem_type=problem_type,

make_figures.ipynb

Lines changed: 249 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
},
8282
{
8383
"cell_type": "code",
84-
"execution_count": 39,
84+
"execution_count": 10,
8585
"metadata": {},
8686
"outputs": [
8787
{
@@ -91,7 +91,7 @@
9191
"Number of datasets each algorithm does best on:\n",
9292
"Counter({'AutoGluon (sec=120)': 84, 'AutoGluon (sec=60)': 74, 'LightGBM (n_iter=25)': 74, 'LightGBM (n_iter=10)': 68, 'Logistic Regression': 64, 'Random Forest': 64, 'SVC': 35}) \n",
9393
"\n",
94-
"Average performance for each algorithm: model\n",
94+
"Average performance for each model\n",
9595
"AutoGluon (sec=120) 0.887491\n",
9696
"AutoGluon (sec=60) 0.886326\n",
9797
"LightGBM (n_iter=10) 0.886359\n",
@@ -101,7 +101,7 @@
101101
"SVC 0.852368\n",
102102
"Name: mean_auroc, dtype: float64 \n",
103103
"\n",
104-
"Median performance for each algorithm: model\n",
104+
"Median performance for each model\n",
105105
"AutoGluon (sec=120) 0.924359\n",
106106
"AutoGluon (sec=60) 0.925754\n",
107107
"LightGBM (n_iter=10) 0.924920\n",
@@ -124,8 +124,8 @@
124124
"\n",
125125
"print('Number of datasets each algorithm does best on:')\n",
126126
"print(Counter(winning_algorithms), '\\n')\n",
127-
"print('Average performance for each algorithm:', results_df.groupby('model')['mean_auroc'].mean(), '\\n')\n",
128-
"print('Median performance for each algorithm:', results_df.groupby('model')['mean_auroc'].median())"
127+
"print('Average performance for each', results_df.groupby('model')['mean_auroc'].mean(), '\\n')\n",
128+
"print('Median performance for each', results_df.groupby('model')['mean_auroc'].median())"
129129
]
130130
},
131131
{
@@ -242,6 +242,250 @@
242242
"g.set(xscale=\"log\")"
243243
]
244244
},
245+
{
246+
"cell_type": "code",
247+
"execution_count": 13,
248+
"metadata": {},
249+
"outputs": [
250+
{
251+
"data": {
252+
"text/plain": [
253+
"dataset\n",
254+
"iris 0.000000\n",
255+
"robot-nav-sensor-readings-2 0.000000\n",
256+
"robot-nav-sensor-readings-4 0.000000\n",
257+
"hayes-roth 0.000000\n",
258+
"banknote-authentication 0.000000\n",
259+
" ... \n",
260+
"thoracic-surgery 0.014894\n",
261+
"leukemia-haslinger 0.022436\n",
262+
"autoUniv-au7-cpd1-500 0.022964\n",
263+
"planning-relax 0.051938\n",
264+
"meta-data 0.324029\n",
265+
"Name: mean_auroc, Length: 142, dtype: float64"
266+
]
267+
},
268+
"execution_count": 13,
269+
"metadata": {},
270+
"output_type": "execute_result"
271+
}
272+
],
273+
"source": [
274+
"# Winner's margin per dataset: best mean_auroc minus the runner-up's, sorted ascending.\n",
"# np.ptp over the two largest values is exactly that difference, needs only one sort per group,\n",
"# and gives 0.0 for a dataset with a single result row instead of the IndexError that [-2] would raise.\n",
"results_df.groupby('dataset')['mean_auroc'].apply(lambda scores: np.ptp(np.sort(scores)[-2:])).sort_values()"
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": 24,
280+
"metadata": {},
281+
"outputs": [
282+
{
283+
"data": {
284+
"text/html": [
285+
"<div>\n",
286+
"<style scoped>\n",
287+
" .dataframe tbody tr th:only-of-type {\n",
288+
" vertical-align: middle;\n",
289+
" }\n",
290+
"\n",
291+
" .dataframe tbody tr th {\n",
292+
" vertical-align: top;\n",
293+
" }\n",
294+
"\n",
295+
" .dataframe thead th {\n",
296+
" text-align: right;\n",
297+
" }\n",
298+
"</style>\n",
299+
"<table border=\"1\" class=\"dataframe\">\n",
300+
" <thead>\n",
301+
" <tr style=\"text-align: right;\">\n",
302+
" <th></th>\n",
303+
" <th>auroc_split_1</th>\n",
304+
" <th>auroc_split_2</th>\n",
305+
" <th>auroc_split_3</th>\n",
306+
" <th>auroc_split_4</th>\n",
307+
" <th>model</th>\n",
308+
" <th>nrow</th>\n",
309+
" <th>ncol</th>\n",
310+
" <th>mv</th>\n",
311+
" <th>ir</th>\n",
312+
" <th>class</th>\n",
313+
" <th>mean_auroc</th>\n",
314+
" <th>min_auroc</th>\n",
315+
" <th>max_auroc</th>\n",
316+
" <th>std_auroc</th>\n",
317+
" <th>dataset</th>\n",
318+
" </tr>\n",
319+
" </thead>\n",
320+
" <tbody>\n",
321+
" <tr>\n",
322+
" <th>planning-relax</th>\n",
323+
" <td>0.648019</td>\n",
324+
" <td>0.594406</td>\n",
325+
" <td>0.358173</td>\n",
326+
" <td>0.531250</td>\n",
327+
" <td>SVC</td>\n",
328+
" <td>182.0</td>\n",
329+
" <td>13.0</td>\n",
330+
" <td>0.0</td>\n",
331+
" <td>0.714286</td>\n",
332+
" <td>2.0</td>\n",
333+
" <td>0.532962</td>\n",
334+
" <td>0.358173</td>\n",
335+
" <td>0.648019</td>\n",
336+
" <td>0.125920</td>\n",
337+
" <td>planning-relax</td>\n",
338+
" </tr>\n",
339+
" <tr>\n",
340+
" <th>planning-relax</th>\n",
341+
" <td>0.375291</td>\n",
342+
" <td>0.356643</td>\n",
343+
" <td>0.305288</td>\n",
344+
" <td>0.497596</td>\n",
345+
" <td>Logistic Regression</td>\n",
346+
" <td>182.0</td>\n",
347+
" <td>13.0</td>\n",
348+
" <td>0.0</td>\n",
349+
" <td>0.714286</td>\n",
350+
" <td>2.0</td>\n",
351+
" <td>0.383705</td>\n",
352+
" <td>0.305288</td>\n",
353+
" <td>0.497596</td>\n",
354+
" <td>0.081493</td>\n",
355+
" <td>planning-relax</td>\n",
356+
" </tr>\n",
357+
" <tr>\n",
358+
" <th>planning-relax</th>\n",
359+
" <td>0.341492</td>\n",
360+
" <td>0.403263</td>\n",
361+
" <td>0.268029</td>\n",
362+
" <td>0.413462</td>\n",
363+
" <td>Random Forest</td>\n",
364+
" <td>182.0</td>\n",
365+
" <td>13.0</td>\n",
366+
" <td>0.0</td>\n",
367+
" <td>0.714286</td>\n",
368+
" <td>2.0</td>\n",
369+
" <td>0.356561</td>\n",
370+
" <td>0.268029</td>\n",
371+
" <td>0.413462</td>\n",
372+
" <td>0.067042</td>\n",
373+
" <td>planning-relax</td>\n",
374+
" </tr>\n",
375+
" <tr>\n",
376+
" <th>planning-relax</th>\n",
377+
" <td>0.393939</td>\n",
378+
" <td>0.550117</td>\n",
379+
" <td>0.268029</td>\n",
380+
" <td>0.500000</td>\n",
381+
" <td>LightGBM (n_iter=10)</td>\n",
382+
" <td>182.0</td>\n",
383+
" <td>13.0</td>\n",
384+
" <td>0.0</td>\n",
385+
" <td>0.714286</td>\n",
386+
" <td>2.0</td>\n",
387+
" <td>0.428021</td>\n",
388+
" <td>0.268029</td>\n",
389+
" <td>0.550117</td>\n",
390+
" <td>0.124963</td>\n",
391+
" <td>planning-relax</td>\n",
392+
" </tr>\n",
393+
" <tr>\n",
394+
" <th>planning-relax</th>\n",
395+
" <td>0.493007</td>\n",
396+
" <td>0.589744</td>\n",
397+
" <td>0.341346</td>\n",
398+
" <td>0.500000</td>\n",
399+
" <td>LightGBM (n_iter=25)</td>\n",
400+
" <td>182.0</td>\n",
401+
" <td>13.0</td>\n",
402+
" <td>0.0</td>\n",
403+
" <td>0.714286</td>\n",
404+
" <td>2.0</td>\n",
405+
" <td>0.481024</td>\n",
406+
" <td>0.341346</td>\n",
407+
" <td>0.589744</td>\n",
408+
" <td>0.103011</td>\n",
409+
" <td>planning-relax</td>\n",
410+
" </tr>\n",
411+
" <tr>\n",
412+
" <th>planning-relax</th>\n",
413+
" <td>0.333333</td>\n",
414+
" <td>0.496503</td>\n",
415+
" <td>0.367788</td>\n",
416+
" <td>0.514423</td>\n",
417+
" <td>AutoGluon (sec=60)</td>\n",
418+
" <td>182.0</td>\n",
419+
" <td>13.0</td>\n",
420+
" <td>0.0</td>\n",
421+
" <td>0.714286</td>\n",
422+
" <td>2.0</td>\n",
423+
" <td>0.428012</td>\n",
424+
" <td>0.333333</td>\n",
425+
" <td>0.514423</td>\n",
426+
" <td>0.090827</td>\n",
427+
" <td>planning-relax</td>\n",
428+
" </tr>\n",
429+
" <tr>\n",
430+
" <th>planning-relax</th>\n",
431+
" <td>0.365967</td>\n",
432+
" <td>0.463869</td>\n",
433+
" <td>0.382212</td>\n",
434+
" <td>0.500000</td>\n",
435+
" <td>AutoGluon (sec=120)</td>\n",
436+
" <td>182.0</td>\n",
437+
" <td>13.0</td>\n",
438+
" <td>0.0</td>\n",
439+
" <td>0.714286</td>\n",
440+
" <td>2.0</td>\n",
441+
" <td>0.428012</td>\n",
442+
" <td>0.365967</td>\n",
443+
" <td>0.500000</td>\n",
444+
" <td>0.064331</td>\n",
445+
" <td>planning-relax</td>\n",
446+
" </tr>\n",
447+
" </tbody>\n",
448+
"</table>\n",
449+
"</div>"
450+
],
451+
"text/plain": [
452+
" auroc_split_1 auroc_split_2 auroc_split_3 auroc_split_4 \\\n",
453+
"planning-relax 0.648019 0.594406 0.358173 0.531250 \n",
454+
"planning-relax 0.375291 0.356643 0.305288 0.497596 \n",
455+
"planning-relax 0.341492 0.403263 0.268029 0.413462 \n",
456+
"planning-relax 0.393939 0.550117 0.268029 0.500000 \n",
457+
"planning-relax 0.493007 0.589744 0.341346 0.500000 \n",
458+
"planning-relax 0.333333 0.496503 0.367788 0.514423 \n",
459+
"planning-relax 0.365967 0.463869 0.382212 0.500000 \n",
460+
"\n",
461+
" model nrow ncol mv ir class \\\n",
462+
"planning-relax SVC 182.0 13.0 0.0 0.714286 2.0 \n",
463+
"planning-relax Logistic Regression 182.0 13.0 0.0 0.714286 2.0 \n",
464+
"planning-relax Random Forest 182.0 13.0 0.0 0.714286 2.0 \n",
465+
"planning-relax LightGBM (n_iter=10) 182.0 13.0 0.0 0.714286 2.0 \n",
466+
"planning-relax LightGBM (n_iter=25) 182.0 13.0 0.0 0.714286 2.0 \n",
467+
"planning-relax AutoGluon (sec=60) 182.0 13.0 0.0 0.714286 2.0 \n",
468+
"planning-relax AutoGluon (sec=120) 182.0 13.0 0.0 0.714286 2.0 \n",
469+
"\n",
470+
" mean_auroc min_auroc max_auroc std_auroc dataset \n",
471+
"planning-relax 0.532962 0.358173 0.648019 0.125920 planning-relax \n",
472+
"planning-relax 0.383705 0.305288 0.497596 0.081493 planning-relax \n",
473+
"planning-relax 0.356561 0.268029 0.413462 0.067042 planning-relax \n",
474+
"planning-relax 0.428021 0.268029 0.550117 0.124963 planning-relax \n",
475+
"planning-relax 0.481024 0.341346 0.589744 0.103011 planning-relax \n",
476+
"planning-relax 0.428012 0.333333 0.514423 0.090827 planning-relax \n",
477+
"planning-relax 0.428012 0.365967 0.500000 0.064331 planning-relax "
478+
]
479+
},
480+
"execution_count": 24,
481+
"metadata": {},
482+
"output_type": "execute_result"
483+
}
484+
],
485+
"source": [
486+
"results_df.loc['planning-relax']"
487+
]
488+
},
245489
{
246490
"cell_type": "code",
247491
"execution_count": null,

0 commit comments

Comments
 (0)