Skip to content

Commit 3029480

Browse files
committed
chore: refine figure plotting
1 parent 0cf1ecb commit 3029480

File tree

1 file changed

+166
-14
lines changed

1 file changed

+166
-14
lines changed

bench/bench_plot.py

Lines changed: 166 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,49 @@
3333
default=Path("."),
3434
help="Directory to save plots (default: current directory)",
3535
)
36+
parser.add_argument(
37+
"--font-scale",
38+
type=float,
39+
default=1.6,
40+
help="Scale factor for fonts and markers (default: 1.6)",
41+
)
42+
parser.add_argument(
43+
"--dpi",
44+
type=int,
45+
default=320,
46+
help="PNG export DPI (default: 320)",
47+
)
48+
parser.add_argument(
49+
"--style",
50+
type=str,
51+
choices=["points", "lines", "both"],
52+
default="points",
53+
help="Plot style for modes: points, lines, or both (default: points)",
54+
)
55+
parser.add_argument(
56+
"--max-modes",
57+
type=int,
58+
default=None,
59+
help="If set, plot only the top N modes by mean of the current metric",
60+
)
61+
parser.add_argument(
62+
"--xtick-rotation",
63+
type=float,
64+
default=75.0,
65+
help="Rotation angle for x tick labels (default: 75)",
66+
)
67+
parser.add_argument(
68+
"--side-margin",
69+
type=float,
70+
default=0.0,
71+
help="Shrink x-limits inward by this many x units per side (default: 0)",
72+
)
73+
parser.add_argument(
74+
"--side-expand",
75+
type=float,
76+
default=0.25,
77+
help="Expand x-limits outward by this many x units per side (default: 0.25)",
78+
)
3679
args = parser.parse_args()
3780
summary_path = args.summary
3881

@@ -128,19 +171,57 @@ def _mm(values):
128171

129172

130173
def plot_metric(metric: str, out_path: Path):
131-
fig, ax = plt.subplots(figsize=(14, 6))
174+
fig, ax = plt.subplots(figsize=(18, 8))
132175

133176
x = range(len(cats))
134177

135-
# Overlay each mode as points
178+
# Determine modes to plot, optionally limiting to top-N by mean of metric
136179
all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {}).keys()})
137180
if len(all_modes) > 0:
181+
def _mean(values):
182+
vals = [v for v in values if v is not None]
183+
return sum(vals) / len(vals) if vals else float("nan")
184+
185+
if args.max_modes is not None and args.max_modes > 0 and len(all_modes) > args.max_modes:
186+
mode_means = []
187+
for mode in all_modes:
188+
vals = [cat_by_mode.get(c, {}).get(mode, {}).get(metric) for c in cats]
189+
mode_means.append((mode, _mean(vals)))
190+
# Accuracy: higher is better; latency/tokens: lower is better
191+
ascending = metric != "accuracy"
192+
mode_means = sorted(
193+
mode_means,
194+
key=lambda kv: (float("inf") if (kv[1] != kv[1]) else kv[1]),
195+
reverse=not ascending,
196+
)
197+
all_modes = [m for m, _ in mode_means[: args.max_modes]]
198+
138199
palette = colormaps.get_cmap("tab10").resampled(len(all_modes))
200+
linestyles = ["-", "--", "-.", ":"]
139201
for i, mode in enumerate(all_modes):
140-
ys = []
141-
for c in cats:
142-
ys.append(cat_by_mode.get(c, {}).get(mode, {}).get(metric))
143-
ax.scatter(x, ys, s=20, color=palette.colors[i], label=mode, alpha=0.8)
202+
ys = [cat_by_mode.get(c, {}).get(mode, {}).get(metric) for c in cats]
203+
if args.style in ("lines", "both"):
204+
ax.plot(
205+
x,
206+
ys,
207+
color=palette.colors[i],
208+
linestyle=linestyles[i % len(linestyles)],
209+
linewidth=1.4 * args.font_scale,
210+
alpha=0.6,
211+
zorder=2,
212+
)
213+
if args.style in ("points", "both"):
214+
ax.scatter(
215+
x,
216+
ys,
217+
s=60 * args.font_scale,
218+
color=palette.colors[i],
219+
label=mode,
220+
alpha=0.85,
221+
edgecolors="white",
222+
linewidths=0.5 * args.font_scale,
223+
zorder=3,
224+
)
144225

145226
# Overlay router per-category metric as diamonds, if provided
146227
if s_router is not None:
@@ -153,24 +234,95 @@ def plot_metric(metric: str, out_path: Path):
153234
router_x.append(idx)
154235
router_vals.append(v)
155236
if router_vals:
237+
# Connect router points with a line and draw larger diamond markers
238+
ax.plot(
239+
router_x,
240+
router_vals,
241+
color="tab:red",
242+
linestyle="-",
243+
linewidth=2.0 * args.font_scale,
244+
alpha=0.85,
245+
zorder=4,
246+
)
156247
ax.scatter(
157248
router_x,
158249
router_vals,
159-
s=50,
250+
s=90 * args.font_scale,
160251
color="tab:red",
161252
marker="D",
162253
label="router",
163-
zorder=4,
254+
zorder=5,
255+
edgecolors="white",
256+
linewidths=0.6 * args.font_scale,
164257
)
165258

166259
ax.set_xticks(list(x))
167-
ax.set_xticklabels(cats, rotation=60, ha="right")
260+
ax.set_xticklabels(
261+
cats,
262+
rotation=args.xtick_rotation,
263+
ha="right",
264+
fontsize=int(14 * args.font_scale),
265+
)
266+
# Control horizontal fit by expanding/shrinking x-limits around the first/last category
267+
if len(cats) > 0:
268+
n = len(cats)
269+
# Base categorical extents
270+
base_left = -0.5
271+
base_right = n - 0.5
272+
# Apply outward expansion first, then inward margin
273+
expand = max(0.0, float(args.side_expand))
274+
max_margin = 0.49
275+
margin = max(0.0, min(float(args.side_margin), max_margin))
276+
left_xlim = base_left - expand + margin
277+
right_xlim = base_right + expand - margin
278+
if right_xlim > left_xlim:
279+
ax.set_xlim(left_xlim, right_xlim)
168280
ylabel = metric.replace("_", " ")
169-
ax.set_ylabel(ylabel)
170-
ax.set_title(f"Per-category {ylabel} per-mode values")
171-
ax.legend(ncol=3, fontsize=8)
172-
plt.tight_layout()
173-
plt.savefig(out_path, dpi=200, bbox_inches="tight")
281+
ax.set_ylabel(ylabel, fontsize=int(18 * args.font_scale))
282+
ax.set_title(f"Per-category {ylabel} per-mode values", fontsize=int(22 * args.font_scale))
283+
ax.tick_params(axis="both", which="major", labelsize=int(14 * args.font_scale))
284+
285+
# Build a figure-level legend below the axes and reserve space to prevent overlap
286+
handles, labels = ax.get_legend_handles_labels()
287+
if handles:
288+
num_series = len(handles)
289+
# Force exactly 2 legend rows; compute columns accordingly
290+
legend_rows = 2
291+
legend_ncol = max(1, (num_series + legend_rows - 1) // legend_rows)
292+
num_rows = legend_rows
293+
scale = (args.font_scale / 1.6)
294+
# Reserve generous space for long rotated tick labels and multi-row legend
295+
bottom_reserved = (0.28 * scale) + (0.12 * num_rows * scale)
296+
bottom_reserved = max(0.24, min(0.60, bottom_reserved))
297+
fig.subplots_adjust(left=0.01, right=0.999, top=0.92, bottom=bottom_reserved)
298+
# Align the legend box width with the axes width
299+
pos = ax.get_position()
300+
fig.legend(
301+
handles,
302+
labels,
303+
loc="lower left",
304+
bbox_to_anchor=(pos.x0, 0.02, pos.width, 0.001),
305+
bbox_transform=fig.transFigure,
306+
ncol=legend_ncol,
307+
mode="expand",
308+
fontsize=int(14 * args.font_scale),
309+
markerscale=1.6 * args.font_scale,
310+
frameon=False,
311+
borderaxespad=0.0,
312+
columnspacing=0.8 * args.font_scale,
313+
handlelength=2.2,
314+
)
315+
else:
316+
fig.subplots_adjust(left=0.01, right=0.999, top=0.92, bottom=0.14)
317+
ax.grid(axis="y", linestyle=":", linewidth=0.8, alpha=0.3)
318+
# Eliminate additional automatic horizontal padding
319+
ax.margins(x=0.0)
320+
# Layout handled via subplots_adjust above to avoid legend overlap
321+
# Save both PNG and PDF variants
322+
png_path = out_path.with_suffix(".png")
323+
pdf_path = out_path.with_suffix(".pdf")
324+
plt.savefig(png_path, dpi=int(args.dpi), bbox_inches="tight")
325+
plt.savefig(pdf_path, bbox_inches="tight")
174326
plt.close(fig)
175327

176328

0 commit comments

Comments
 (0)