Skip to content

Commit 7761bca

Browse files
committed
limit range of y-axis in convergence plots
1 parent 195447e commit 7761bca

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

batchglm/benchmark/nb_glm/convergence.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,25 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
7171
df = val.to_dataframe(val_name)
7272
df = df.reset_index()
7373

74+
ylim = np.max(df.loc[df.global_step.eq(1), val_name]) * 2
75+
7476
plot = (pn.ggplot(df)
7577
+ pn.aes(x="time_elapsed", y=val_name, group=groupby_col, color=groupby_col)
7678
+ pn.geom_line()
7779
+ pn.geom_vline(xintercept=df.loc[[np.argmin(df[val_name])]].time_elapsed.values[0], color="black")
7880
+ pn.geom_hline(yintercept=np.min(df[val_name]), alpha=0.5)
81+
+ pn.ylim(0, ylim)
7982
)
8083
if scale_y_log10:
8184
plot = plot + pn.scale_y_log10()
8285
plot.save(os.path.join(plot_dir, name_prefix + ".time.svg"), format="svg")
8386

8487
plot = (pn.ggplot(df)
85-
+ pn.aes(x="global_step", y=val_name, group=groupby_col, color=groupby_col)
88+
+ pn.aes(x="global_step", y=val_name, group=groupby_col, color=groupby_col, ymax=ylim)
8689
+ pn.geom_line()
8790
+ pn.geom_vline(xintercept=df.loc[[np.argmin(df[val_name])]].global_step.values[0], color="black")
8891
+ pn.geom_hline(yintercept=np.min(df[val_name]), alpha=0.5)
92+
+ pn.ylim(0, ylim)
8993
)
9094
if scale_y_log10:
9195
plot = plot + pn.scale_y_log10()

0 commit comments

Comments
 (0)