|
23 | 23 |
|
24 | 24 | # --- Sidebar configuration with hover tooltips --- |
25 | 25 | st.sidebar.title("Configuration") |
26 | | -# Fullscreen option for plots |
27 | | -fullscreen = st.sidebar.checkbox( |
28 | | - "Fullscreen plots", |
29 | | - False, |
30 | | - help="Toggle to display plots in full container width/height for detail." |
31 | | -) |
32 | | - |
| 26 | +# --- Sidebar configuration with hover tooltips --- |
33 | 27 | st.sidebar.title("Configuration") |
34 | 28 | # Dataset selection |
35 | 29 | dataset_name = st.sidebar.selectbox( |
@@ -216,30 +210,42 @@ def get_data(name): |
216 | 210 | st.dataframe(metrics_df, use_container_width=True) |
217 | 211 | # Plot decision boundaries |
218 | 212 | x_vis = X_train_pre[:, :2] |
| 213 | +# Create mesh grid once based on x_vis |
| 214 | +x_min, x_max = x_vis[:,0].min() - 1, x_vis[:,0].max() + 1 |
| 215 | +y_min, y_max = x_vis[:,1].min() - 1, x_vis[:,1].max() + 1 |
| 216 | +xx, yy = np.meshgrid( |
| 217 | + np.linspace(x_min, x_max, 200), |
| 218 | + np.linspace(y_min, y_max, 200) |
| 219 | +) |
219 | 220 | for _, row in metrics_df.iterrows(): |
220 | 221 | name = row["Model"] |
221 | 222 | exp = st.expander(f"Decision Boundary: {name}") |
222 | 223 | with exp: |
223 | | - est_vis = clone(models[name]) |
224 | | - if name in ["Local Outlier Factor"]: |
225 | | - est_vis.fit(x_vis) |
| 224 | + # use columns to restrict plot width |
| 225 | + col1, _ = st.columns([1, 2]) |
| 226 | + # zoom toggle |
| 227 | + zoom = col1.checkbox("Enlarge plot", key=f"zoom_{name}") |
| 228 | + fig_w, fig_h = (6, 4) if zoom else (3, 2) |
| 229 | + # train on 2D for visualization |
| 230 | + model_vis = clone(models[name]) |
| 231 | + if name == "Local Outlier Factor": |
| 232 | + model_vis.fit(x_vis) |
226 | 233 | else: |
227 | | - est_vis.fit(x_vis, y_train if name not in ["Isolation Forest", "One-Class SVM"] else None) |
228 | | - xx, yy = np.meshgrid( |
229 | | - np.linspace(x_vis[:,0].min()-1, x_vis[:,0].max()+1, 200), |
230 | | - np.linspace(x_vis[:,1].min()-1, x_vis[:,1].max()+1, 200) |
231 | | - ) |
232 | | - Z_raw = est_vis.predict(np.c_[xx.ravel(), yy.ravel()]) |
| 234 | + fit_args = (x_vis, y_train) if name not in ["Isolation Forest", "One-Class SVM"] else (x_vis, None) |
| 235 | + model_vis.fit(*fit_args) |
| 236 | + # predict on grid |
| 237 | + Z_pred = model_vis.predict(np.c_[xx.ravel(), yy.ravel()]) |
| 238 | + # map anomalies |
233 | 239 | if name in ["Isolation Forest", "One-Class SVM", "Local Outlier Factor"]: |
234 | | - Z = (Z_raw>0).astype(int).reshape(xx.shape) |
| 240 | + Z = (Z_pred > 0).astype(int).reshape(xx.shape) |
235 | 241 | else: |
236 | | - Z = Z_raw.reshape(xx.shape) |
237 | | - fig_w, fig_h = (6, 4) if fullscreen else (3, 2) |
| 242 | + Z = Z_pred.reshape(xx.shape) |
| 243 | + # plot |
238 | 244 | plt.figure(figsize=(fig_w, fig_h)) |
239 | 245 | plt.contourf(xx, yy, Z, alpha=0.3) |
240 | 246 | plt.scatter(x_vis[:,0], x_vis[:,1], c=y_train, edgecolor='k', s=20) |
241 | 247 | plt.title(name) |
242 | 248 | plt.xlabel("Component 1") |
243 | 249 | plt.ylabel("Component 2") |
244 | | - st.pyplot(plt) |
| 250 | + col1.pyplot(plt) |
245 | 251 |
|
0 commit comments