Skip to content

Commit 2a18688

Browse files
author
Sergey Feldman
committed
reqs
1 parent aa2c633 commit 2a18688

File tree

2 files changed

+168
-127
lines changed

2 files changed

+168
-127
lines changed

make_figures.ipynb

Lines changed: 167 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,13 @@
270270
}
271271
],
272272
"source": [
273+
"# datasets where there is a big gap between the best and second-best mean AUROC across models\n",
273274
"results_df.groupby('dataset')['mean_auroc'].apply(lambda x: np.sort(x)[-1] - np.sort(x)[-2]).sort_values()[-10:]"
274275
]
275276
},
276277
{
277278
"cell_type": "code",
278-
"execution_count": 15,
279+
"execution_count": 22,
279280
"metadata": {},
280281
"outputs": [
281282
{
@@ -318,172 +319,213 @@
318319
" </thead>\n",
319320
" <tbody>\n",
320321
" <tr>\n",
321-
" <th>blogger</th>\n",
322-
" <td>0.669118</td>\n",
323-
" <td>0.838235</td>\n",
324-
" <td>0.783088</td>\n",
325-
" <td>0.647059</td>\n",
322+
" <th>connectionist-vowel</th>\n",
323+
" <td>0.989168</td>\n",
324+
" <td>0.979898</td>\n",
325+
" <td>0.978576</td>\n",
326+
" <td>0.989061</td>\n",
326327
" <td>SVC</td>\n",
327-
" <td>100.0</td>\n",
328-
" <td>6.0</td>\n",
328+
" <td>990.0</td>\n",
329+
" <td>14.0</td>\n",
329330
" <td>0.0</td>\n",
330-
" <td>0.68</td>\n",
331-
" <td>2.0</td>\n",
332-
" <td>0.734375</td>\n",
333-
" <td>0.647059</td>\n",
334-
" <td>0.838235</td>\n",
335-
" <td>0.091365</td>\n",
336-
" <td>blogger</td>\n",
331+
" <td>0.090909</td>\n",
332+
" <td>11.0</td>\n",
333+
" <td>0.984176</td>\n",
334+
" <td>0.978576</td>\n",
335+
" <td>0.989168</td>\n",
336+
" <td>0.005729</td>\n",
337+
" <td>connectionist-vowel</td>\n",
337338
" </tr>\n",
338339
" <tr>\n",
339-
" <th>blogger</th>\n",
340-
" <td>0.691176</td>\n",
341-
" <td>0.794118</td>\n",
342-
" <td>0.753676</td>\n",
343-
" <td>0.617647</td>\n",
340+
" <th>connectionist-vowel</th>\n",
341+
" <td>0.962657</td>\n",
342+
" <td>0.953214</td>\n",
343+
" <td>0.960909</td>\n",
344+
" <td>0.957383</td>\n",
344345
" <td>Logistic Regression</td>\n",
345-
" <td>100.0</td>\n",
346-
" <td>6.0</td>\n",
346+
" <td>990.0</td>\n",
347+
" <td>14.0</td>\n",
347348
" <td>0.0</td>\n",
348-
" <td>0.68</td>\n",
349-
" <td>2.0</td>\n",
350-
" <td>0.714154</td>\n",
351-
" <td>0.617647</td>\n",
352-
" <td>0.794118</td>\n",
353-
" <td>0.077023</td>\n",
354-
" <td>blogger</td>\n",
349+
" <td>0.090909</td>\n",
350+
" <td>11.0</td>\n",
351+
" <td>0.958541</td>\n",
352+
" <td>0.953214</td>\n",
353+
" <td>0.962657</td>\n",
354+
" <td>0.004174</td>\n",
355+
" <td>connectionist-vowel</td>\n",
355356
" </tr>\n",
356357
" <tr>\n",
357-
" <th>blogger</th>\n",
358-
" <td>0.830882</td>\n",
359-
" <td>0.779412</td>\n",
360-
" <td>0.871324</td>\n",
361-
" <td>0.779412</td>\n",
358+
" <th>connectionist-vowel</th>\n",
359+
" <td>0.998838</td>\n",
360+
" <td>0.998621</td>\n",
361+
" <td>0.998387</td>\n",
362+
" <td>0.999406</td>\n",
362363
" <td>Random Forest</td>\n",
363-
" <td>100.0</td>\n",
364-
" <td>6.0</td>\n",
364+
" <td>990.0</td>\n",
365+
" <td>14.0</td>\n",
365366
" <td>0.0</td>\n",
366-
" <td>0.68</td>\n",
367-
" <td>2.0</td>\n",
368-
" <td>0.815257</td>\n",
369-
" <td>0.779412</td>\n",
370-
" <td>0.871324</td>\n",
371-
" <td>0.044562</td>\n",
372-
" <td>blogger</td>\n",
367+
" <td>0.090909</td>\n",
368+
" <td>11.0</td>\n",
369+
" <td>0.998813</td>\n",
370+
" <td>0.998387</td>\n",
371+
" <td>0.999406</td>\n",
372+
" <td>0.000436</td>\n",
373+
" <td>connectionist-vowel</td>\n",
373374
" </tr>\n",
374375
" <tr>\n",
375-
" <th>blogger</th>\n",
376-
" <td>0.720588</td>\n",
377-
" <td>0.816176</td>\n",
378-
" <td>0.915441</td>\n",
379-
" <td>0.676471</td>\n",
376+
" <th>connectionist-vowel</th>\n",
377+
" <td>0.996457</td>\n",
378+
" <td>0.997028</td>\n",
379+
" <td>0.995780</td>\n",
380+
" <td>0.996505</td>\n",
380381
" <td>LightGBM (n_iter=10)</td>\n",
381-
" <td>100.0</td>\n",
382-
" <td>6.0</td>\n",
382+
" <td>990.0</td>\n",
383+
" <td>14.0</td>\n",
383384
" <td>0.0</td>\n",
384-
" <td>0.68</td>\n",
385-
" <td>2.0</td>\n",
386-
" <td>0.782169</td>\n",
387-
" <td>0.676471</td>\n",
388-
" <td>0.915441</td>\n",
389-
" <td>0.106274</td>\n",
390-
" <td>blogger</td>\n",
385+
" <td>0.090909</td>\n",
386+
" <td>11.0</td>\n",
387+
" <td>0.996443</td>\n",
388+
" <td>0.995780</td>\n",
389+
" <td>0.997028</td>\n",
390+
" <td>0.000512</td>\n",
391+
" <td>connectionist-vowel</td>\n",
391392
" </tr>\n",
392393
" <tr>\n",
393-
" <th>blogger</th>\n",
394-
" <td>0.764706</td>\n",
395-
" <td>0.794118</td>\n",
396-
" <td>0.694853</td>\n",
397-
" <td>0.625000</td>\n",
394+
" <th>connectionist-vowel</th>\n",
395+
" <td>0.561180</td>\n",
396+
" <td>0.997013</td>\n",
397+
" <td>0.995780</td>\n",
398+
" <td>0.996505</td>\n",
398399
" <td>LightGBM (n_iter=25)</td>\n",
399-
" <td>100.0</td>\n",
400-
" <td>6.0</td>\n",
400+
" <td>990.0</td>\n",
401+
" <td>14.0</td>\n",
401402
" <td>0.0</td>\n",
402-
" <td>0.68</td>\n",
403-
" <td>2.0</td>\n",
404-
" <td>0.719669</td>\n",
405-
" <td>0.625000</td>\n",
406-
" <td>0.794118</td>\n",
407-
" <td>0.075606</td>\n",
408-
" <td>blogger</td>\n",
403+
" <td>0.090909</td>\n",
404+
" <td>11.0</td>\n",
405+
" <td>0.887619</td>\n",
406+
" <td>0.561180</td>\n",
407+
" <td>0.997013</td>\n",
408+
" <td>0.217627</td>\n",
409+
" <td>connectionist-vowel</td>\n",
409410
" </tr>\n",
410411
" <tr>\n",
411-
" <th>blogger</th>\n",
412-
" <td>0.801471</td>\n",
413-
" <td>0.823529</td>\n",
414-
" <td>0.908088</td>\n",
415-
" <td>0.757353</td>\n",
412+
" <th>connectionist-vowel</th>\n",
413+
" <td>0.999374</td>\n",
414+
" <td>0.999606</td>\n",
415+
" <td>0.999748</td>\n",
416+
" <td>0.999676</td>\n",
416417
" <td>AutoGluon (sec=60)</td>\n",
417-
" <td>100.0</td>\n",
418-
" <td>6.0</td>\n",
418+
" <td>990.0</td>\n",
419+
" <td>14.0</td>\n",
419420
" <td>0.0</td>\n",
420-
" <td>0.68</td>\n",
421-
" <td>2.0</td>\n",
422-
" <td>0.822610</td>\n",
423-
" <td>0.757353</td>\n",
424-
" <td>0.908088</td>\n",
425-
" <td>0.063279</td>\n",
426-
" <td>blogger</td>\n",
421+
" <td>0.090909</td>\n",
422+
" <td>11.0</td>\n",
423+
" <td>0.999601</td>\n",
424+
" <td>0.999374</td>\n",
425+
" <td>0.999748</td>\n",
426+
" <td>0.000162</td>\n",
427+
" <td>connectionist-vowel</td>\n",
427428
" </tr>\n",
428429
" <tr>\n",
429-
" <th>blogger</th>\n",
430-
" <td>0.801471</td>\n",
431-
" <td>0.823529</td>\n",
432-
" <td>0.908088</td>\n",
433-
" <td>0.786765</td>\n",
430+
" <th>connectionist-vowel</th>\n",
431+
" <td>0.999392</td>\n",
432+
" <td>0.999893</td>\n",
433+
" <td>0.999946</td>\n",
434+
" <td>0.999711</td>\n",
434435
" <td>AutoGluon (sec=120)</td>\n",
435-
" <td>100.0</td>\n",
436-
" <td>6.0</td>\n",
436+
" <td>990.0</td>\n",
437+
" <td>14.0</td>\n",
437438
" <td>0.0</td>\n",
438-
" <td>0.68</td>\n",
439-
" <td>2.0</td>\n",
440-
" <td>0.829963</td>\n",
441-
" <td>0.786765</td>\n",
442-
" <td>0.908088</td>\n",
443-
" <td>0.054231</td>\n",
444-
" <td>blogger</td>\n",
439+
" <td>0.090909</td>\n",
440+
" <td>11.0</td>\n",
441+
" <td>0.999735</td>\n",
442+
" <td>0.999392</td>\n",
443+
" <td>0.999946</td>\n",
444+
" <td>0.000250</td>\n",
445+
" <td>connectionist-vowel</td>\n",
445446
" </tr>\n",
446447
" </tbody>\n",
447448
"</table>\n",
448449
"</div>"
449450
],
450451
"text/plain": [
451-
" auroc_split_1 auroc_split_2 auroc_split_3 auroc_split_4 \\\n",
452-
"blogger 0.669118 0.838235 0.783088 0.647059 \n",
453-
"blogger 0.691176 0.794118 0.753676 0.617647 \n",
454-
"blogger 0.830882 0.779412 0.871324 0.779412 \n",
455-
"blogger 0.720588 0.816176 0.915441 0.676471 \n",
456-
"blogger 0.764706 0.794118 0.694853 0.625000 \n",
457-
"blogger 0.801471 0.823529 0.908088 0.757353 \n",
458-
"blogger 0.801471 0.823529 0.908088 0.786765 \n",
452+
" auroc_split_1 auroc_split_2 auroc_split_3 \\\n",
453+
"connectionist-vowel 0.989168 0.979898 0.978576 \n",
454+
"connectionist-vowel 0.962657 0.953214 0.960909 \n",
455+
"connectionist-vowel 0.998838 0.998621 0.998387 \n",
456+
"connectionist-vowel 0.996457 0.997028 0.995780 \n",
457+
"connectionist-vowel 0.561180 0.997013 0.995780 \n",
458+
"connectionist-vowel 0.999374 0.999606 0.999748 \n",
459+
"connectionist-vowel 0.999392 0.999893 0.999946 \n",
459460
"\n",
460-
" model nrow ncol mv ir class mean_auroc \\\n",
461-
"blogger SVC 100.0 6.0 0.0 0.68 2.0 0.734375 \n",
462-
"blogger Logistic Regression 100.0 6.0 0.0 0.68 2.0 0.714154 \n",
463-
"blogger Random Forest 100.0 6.0 0.0 0.68 2.0 0.815257 \n",
464-
"blogger LightGBM (n_iter=10) 100.0 6.0 0.0 0.68 2.0 0.782169 \n",
465-
"blogger LightGBM (n_iter=25) 100.0 6.0 0.0 0.68 2.0 0.719669 \n",
466-
"blogger AutoGluon (sec=60) 100.0 6.0 0.0 0.68 2.0 0.822610 \n",
467-
"blogger AutoGluon (sec=120) 100.0 6.0 0.0 0.68 2.0 0.829963 \n",
461+
" auroc_split_4 model nrow ncol mv \\\n",
462+
"connectionist-vowel 0.989061 SVC 990.0 14.0 0.0 \n",
463+
"connectionist-vowel 0.957383 Logistic Regression 990.0 14.0 0.0 \n",
464+
"connectionist-vowel 0.999406 Random Forest 990.0 14.0 0.0 \n",
465+
"connectionist-vowel 0.996505 LightGBM (n_iter=10) 990.0 14.0 0.0 \n",
466+
"connectionist-vowel 0.996505 LightGBM (n_iter=25) 990.0 14.0 0.0 \n",
467+
"connectionist-vowel 0.999676 AutoGluon (sec=60) 990.0 14.0 0.0 \n",
468+
"connectionist-vowel 0.999711 AutoGluon (sec=120) 990.0 14.0 0.0 \n",
468469
"\n",
469-
" min_auroc max_auroc std_auroc dataset \n",
470-
"blogger 0.647059 0.838235 0.091365 blogger \n",
471-
"blogger 0.617647 0.794118 0.077023 blogger \n",
472-
"blogger 0.779412 0.871324 0.044562 blogger \n",
473-
"blogger 0.676471 0.915441 0.106274 blogger \n",
474-
"blogger 0.625000 0.794118 0.075606 blogger \n",
475-
"blogger 0.757353 0.908088 0.063279 blogger \n",
476-
"blogger 0.786765 0.908088 0.054231 blogger "
470+
" ir class mean_auroc min_auroc max_auroc \\\n",
471+
"connectionist-vowel 0.090909 11.0 0.984176 0.978576 0.989168 \n",
472+
"connectionist-vowel 0.090909 11.0 0.958541 0.953214 0.962657 \n",
473+
"connectionist-vowel 0.090909 11.0 0.998813 0.998387 0.999406 \n",
474+
"connectionist-vowel 0.090909 11.0 0.996443 0.995780 0.997028 \n",
475+
"connectionist-vowel 0.090909 11.0 0.887619 0.561180 0.997013 \n",
476+
"connectionist-vowel 0.090909 11.0 0.999601 0.999374 0.999748 \n",
477+
"connectionist-vowel 0.090909 11.0 0.999735 0.999392 0.999946 \n",
478+
"\n",
479+
" std_auroc dataset \n",
480+
"connectionist-vowel 0.005729 connectionist-vowel \n",
481+
"connectionist-vowel 0.004174 connectionist-vowel \n",
482+
"connectionist-vowel 0.000436 connectionist-vowel \n",
483+
"connectionist-vowel 0.000512 connectionist-vowel \n",
484+
"connectionist-vowel 0.217627 connectionist-vowel \n",
485+
"connectionist-vowel 0.000162 connectionist-vowel \n",
486+
"connectionist-vowel 0.000250 connectionist-vowel "
477487
]
478488
},
479-
"execution_count": 15,
489+
"execution_count": 22,
480490
"metadata": {},
481491
"output_type": "execute_result"
482492
}
483493
],
484494
"source": [
485495
"# interesting datasets: meta-data, planning-relax, blogger, autoUniv-au7-cpd1-500, leukemia-haslinger, thoracic-surgery, hill-valley-without-noise\n",
486-
"results_df.loc['blogger']"
496+
"results_df.loc['connectionist-vowel']"
497+
]
498+
},
499+
{
500+
"cell_type": "code",
501+
"execution_count": 23,
502+
"metadata": {},
503+
"outputs": [
504+
{
505+
"data": {
506+
"text/plain": [
507+
"dataset \n",
508+
"connectionist-vowel connectionist-vowel -0.108823\n",
509+
"blogger blogger -0.062500\n",
510+
"colon32 colon32 -0.022917\n",
511+
"hill-valley-with-noise hill-valley-with-noise -0.021507\n",
512+
"parkinsons parkinsons -0.010714\n",
513+
"bupa bupa -0.009703\n",
514+
"autoUniv-au7-300-drift-au7-cpd1-800 autoUniv-au7-300-drift-au7-cpd1-800 -0.008352\n",
515+
"volcanoes-b5 volcanoes-b5 -0.005339\n",
516+
"robot-failure-lp5 robot-failure-lp5 -0.005209\n",
517+
"habermans-survival habermans-survival -0.005079\n",
518+
"Name: mean_auroc, dtype: float64"
519+
]
520+
},
521+
"execution_count": 23,
522+
"metadata": {},
523+
"output_type": "execute_result"
524+
}
525+
],
526+
"source": [
527+
"# datasets where LightGBM with 25 evals does worse than with 10 evals\n",
528+
"results_df.groupby('dataset').apply(lambda x: x[x.model == 'LightGBM (n_iter=25)']['mean_auroc'] - x[x.model == 'LightGBM (n_iter=10)']['mean_auroc']).sort_values()[0:10]"
487529
]
488530
},
489531
{

requirements.in

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@ mxnet_cu102 # insert your own cuda 3 digit version OR just 'mxnet' if don't hav
66
autogluon
77
scikit-learn==0.23.2
88
lightgbm>=3.1.1
9-
optuna
10-
scikit-optimize
9+
hyperopt

0 commit comments

Comments
 (0)