@@ -9,7 +9,8 @@
 import pandas as pd
 from scipy.io import arff
 from sklearn.svm import SVC
-from sklearn.linear_model import RidgeClassifier, LogisticRegression
+from sklearn.linear_model import LogisticRegression
+from sklearn.ensemble import RandomForestClassifier
 from sklearn.model_selection import GridSearchCV, cross_val_score, StratifiedKFold
 from sklearn.preprocessing import MinMaxScaler
 from sklearn.ensemble import BaggingClassifier
@@ -74,12 +75,10 @@ def define_and_evaluate_pipelines(X, y, random_state=0):
         "logistic__C": [1e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 1e1, 1e2],
     }
 
-    # bagged ridge
-    pipeline3 = BaggingClassifier(
-        Pipeline([("scaler", MinMaxScaler()), ("ridge", RidgeClassifier(random_state=random_state)),])
-    )
+    # random forest
+    pipeline3 = RandomForestClassifier(random_state=random_state)
     param_grid3 = {
-        "base_estimator__ridge__alpha": [1e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 1e1, 1e2],
+        "max_depth": [1, 2, 4, 8, 16, 32, None],
     }
 
     nested_scores1 = evaluate_pipeline_helper(X, y, pipeline1, param_grid1, random_state=random_state)
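
Note (not part of the commit): the grid keys change here because scikit-learn scopes tuning parameters by estimator path — the old bagged ridge needed the nested `base_estimator__ridge__alpha` name to reach through the `BaggingClassifier` and `Pipeline`, while the bare `RandomForestClassifier` exposes `max_depth` directly. `evaluate_pipeline_helper` is not shown in this diff; the following is only a sketch of the nested cross-validation it presumably performs, with the function body, fold counts, and `n_jobs` all assumptions:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, cross_val_score, StratifiedKFold

# Assumed shape of the helper: GridSearchCV tunes hyperparameters on the
# inner folds, cross_val_score scores the tuned model on the outer folds.
def evaluate_pipeline_helper(X, y, pipeline, param_grid, random_state=0):
    inner_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
    outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
    grid = GridSearchCV(pipeline, param_grid, cv=inner_cv, n_jobs=-1)
    return cross_val_score(grid, X, y, cv=outer_cv)

# usage mirroring the commit:
# pipeline3 = RandomForestClassifier(random_state=0)
# param_grid3 = {"max_depth": [1, 2, 4, 8, 16, 32, None]}
# nested_scores3 = evaluate_pipeline_helper(X, y, pipeline3, param_grid3)
```

Under this scheme each outer fold scores a model whose `max_depth` was chosen only from that fold's training data, so the nested scores reflect tuned-model generalization rather than a single train/test split.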
|
@@ -117,3 +116,20 @@ def define_and_evaluate_pipelines(X, y, random_state=0):
     times.append(elapsed)
     print("done. elapsed:", elapsed)
 
+#
+results1 = np.array(results1)
+results2 = np.array(results2)
+results3 = np.array(results3)
+evaluated_datasets = np.array(evaluated_datasets)
+times = np.array(times)
+
+# remove things with exactly 1.0 score as it means it's not interesting
+
+
+# save everything to disk so we can make plots elsewhere
+with open("results/01_compare_baseline_models.pickle", "wb") as f:
+    pickle.dump((results1, results2, results3, evaluated_datasets, times), f)
+
+
+# find all the datasets
+
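
Note (not part of the commit): the pickle stores a 5-tuple in the same order it is dumped, so a downstream plotting script can unpack it positionally. A minimal loading sketch — the file path comes from the commit, while the aggregation lines assume each `resultsN` array has one row of outer-fold scores per dataset:

```python
import pickle

with open("results/01_compare_baseline_models.pickle", "rb") as f:
    results1, results2, results3, evaluated_datasets, times = pickle.load(f)

# per-dataset mean nested-CV score for each model
# (assumed layout: rows = datasets, columns = outer folds)
print(results1.mean(axis=1))
print(results2.mean(axis=1))
print(results3.mean(axis=1))
```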