Skip to content

Commit ab8e7af

Browse files
committed
better
1 parent e90e709 commit ab8e7af

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

app.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,7 @@
2323

2424
# --- Sidebar configuration with hover tooltips ---
2525
st.sidebar.title("Configuration")
26-
# Fullscreen option for plots
27-
fullscreen = st.sidebar.checkbox(
28-
"Fullscreen plots",
29-
False,
30-
help="Toggle to display plots in full container width/height for detail."
31-
)
32-
26+
# --- Sidebar configuration with hover tooltips ---
3327
st.sidebar.title("Configuration")
3428
# Dataset selection
3529
dataset_name = st.sidebar.selectbox(
@@ -216,30 +210,42 @@ def get_data(name):
216210
st.dataframe(metrics_df, use_container_width=True)
217211
# Plot decision boundaries
218212
x_vis = X_train_pre[:, :2]
213+
# Create mesh grid once based on x_vis
214+
x_min, x_max = x_vis[:,0].min() - 1, x_vis[:,0].max() + 1
215+
y_min, y_max = x_vis[:,1].min() - 1, x_vis[:,1].max() + 1
216+
xx, yy = np.meshgrid(
217+
np.linspace(x_min, x_max, 200),
218+
np.linspace(y_min, y_max, 200)
219+
)
219220
for _, row in metrics_df.iterrows():
220221
name = row["Model"]
221222
exp = st.expander(f"Decision Boundary: {name}")
222223
with exp:
223-
est_vis = clone(models[name])
224-
if name in ["Local Outlier Factor"]:
225-
est_vis.fit(x_vis)
224+
# use columns to restrict plot width
225+
col1, _ = st.columns([1, 2])
226+
# zoom toggle
227+
zoom = col1.checkbox("Enlarge plot", key=f"zoom_{name}")
228+
fig_w, fig_h = (6, 4) if zoom else (3, 2)
229+
# train on 2D for visualization
230+
model_vis = clone(models[name])
231+
if name == "Local Outlier Factor":
232+
model_vis.fit(x_vis)
226233
else:
227-
est_vis.fit(x_vis, y_train if name not in ["Isolation Forest", "One-Class SVM"] else None)
228-
xx, yy = np.meshgrid(
229-
np.linspace(x_vis[:,0].min()-1, x_vis[:,0].max()+1, 200),
230-
np.linspace(x_vis[:,1].min()-1, x_vis[:,1].max()+1, 200)
231-
)
232-
Z_raw = est_vis.predict(np.c_[xx.ravel(), yy.ravel()])
234+
fit_args = (x_vis, y_train) if name not in ["Isolation Forest", "One-Class SVM"] else (x_vis, None)
235+
model_vis.fit(*fit_args)
236+
# predict on grid
237+
Z_pred = model_vis.predict(np.c_[xx.ravel(), yy.ravel()])
238+
# map anomalies
233239
if name in ["Isolation Forest", "One-Class SVM", "Local Outlier Factor"]:
234-
Z = (Z_raw>0).astype(int).reshape(xx.shape)
240+
Z = (Z_pred > 0).astype(int).reshape(xx.shape)
235241
else:
236-
Z = Z_raw.reshape(xx.shape)
237-
fig_w, fig_h = (6, 4) if fullscreen else (3, 2)
242+
Z = Z_pred.reshape(xx.shape)
243+
# plot
238244
plt.figure(figsize=(fig_w, fig_h))
239245
plt.contourf(xx, yy, Z, alpha=0.3)
240246
plt.scatter(x_vis[:,0], x_vis[:,1], c=y_train, edgecolor='k', s=20)
241247
plt.title(name)
242248
plt.xlabel("Component 1")
243249
plt.ylabel("Component 2")
244-
st.pyplot(plt)
250+
col1.pyplot(plt)
245251

0 commit comments

Comments
 (0)