3333 default = Path ("." ),
3434 help = "Directory to save plots (default: current directory)" ,
3535)
def _register_plot_options(arg_parser):
    """Attach the plot-appearance command-line options to ``arg_parser``.

    Each entry pairs a flag with the keyword arguments handed to
    ``add_argument``; the options are registered in the order listed so the
    generated ``--help`` output keeps the same ordering.
    """
    option_table = (
        (
            "--font-scale",
            {
                "type": float,
                "default": 1.6,
                "help": "Scale factor for fonts and markers (default: 1.6)",
            },
        ),
        (
            "--dpi",
            {
                "type": int,
                "default": 320,
                "help": "PNG export DPI (default: 320)",
            },
        ),
        (
            "--style",
            {
                "type": str,
                "choices": ["points", "lines", "both"],
                "default": "points",
                "help": "Plot style for modes: points, lines, or both (default: points)",
            },
        ),
        (
            "--max-modes",
            {
                "type": int,
                "default": None,
                "help": "If set, plot only the top N modes by mean of the current metric",
            },
        ),
        (
            "--xtick-rotation",
            {
                "type": float,
                "default": 75.0,
                "help": "Rotation angle for x tick labels (default: 75)",
            },
        ),
        (
            "--side-margin",
            {
                "type": float,
                "default": 0.0,
                "help": "Shrink x-limits inward by this many x units per side (default: 0)",
            },
        ),
        (
            "--side-expand",
            {
                "type": float,
                "default": 0.25,
                "help": "Expand x-limits outward by this many x units per side (default: 0.25)",
            },
        ),
    )
    for flag, settings in option_table:
        arg_parser.add_argument(flag, **settings)


_register_plot_options(parser)
# Parse once at import time; the plotting code reads ``args`` as a module global.
args = parser.parse_args()
summary_path = args.summary
3881
@@ -128,19 +171,57 @@ def _mm(values):
128171
129172
def _select_top_modes(modes, categories, values_by_category, metric, limit):
    """Return at most ``limit`` modes, best first, ranked by mean ``metric``.

    Args:
        modes: Candidate mode names; their incoming order breaks ties.
        categories: Category keys to average over.
        values_by_category: Nested mapping ``category -> mode -> metric -> value``.
            Missing levels and ``None`` values are ignored.
        metric: Metric key. ``"accuracy"`` is ranked highest-first; every
            other metric is ranked lowest-first.
        limit: Maximum number of modes to return.

    Returns:
        A list of mode names. Modes whose mean cannot be computed (no values,
        or a NaN mean) are always placed after every mode that has data, so
        they are never preferred over real measurements.
    """
    higher_is_better = metric == "accuracy"
    ranked = []
    without_data = []
    for mode in modes:
        observed = [
            value
            for value in (
                values_by_category.get(cat, {}).get(mode, {}).get(metric)
                for cat in categories
            )
            if value is not None
        ]
        mean = sum(observed) / len(observed) if observed else None
        # NaN is the only value that compares unequal to itself.
        if mean is None or mean != mean:
            without_data.append(mode)
        else:
            ranked.append((mode, mean))
    # list.sort is stable even with reverse=True, so ties keep incoming order.
    ranked.sort(key=lambda pair: pair[1], reverse=higher_is_better)
    return ([mode for mode, _ in ranked] + without_data)[:limit]


def plot_metric(metric: str, out_path: Path):
    """Plot per-category values of ``metric`` for each mode and save the figure.

    Reads the module-level ``cats``, ``cat_by_mode`` and ``args``. Every mode is
    overlaid as points, lines, or both according to ``args.style``; when
    ``args.max_modes`` is a positive number smaller than the number of modes,
    only the best modes by mean ``metric`` are drawn.

    Args:
        metric: Key looked up in each ``cat_by_mode[category][mode]`` mapping.
        out_path: Output location; its suffix is replaced to write both a
            ``.png`` and a ``.pdf`` file.
    """
    fig, ax = plt.subplots(figsize=(18, 8))

    x = range(len(cats))

    # Determine modes to plot, optionally limiting to top-N by mean of metric
    all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {}).keys()})
    if len(all_modes) > 0:
        if args.max_modes is not None and 0 < args.max_modes < len(all_modes):
            all_modes = _select_top_modes(
                all_modes, cats, cat_by_mode, metric, args.max_modes
            )

        palette = colormaps.get_cmap("tab10").resampled(len(all_modes))
        linestyles = ["-", "--", "-.", ":"]
        draw_lines = args.style in ("lines", "both")
        draw_points = args.style in ("points", "both")
        for i, mode in enumerate(all_modes):
            ys = [cat_by_mode.get(c, {}).get(mode, {}).get(metric) for c in cats]
            if draw_lines:
                line_kwargs = {}
                if not draw_points:
                    # In lines-only style no scatter is drawn, so the line
                    # itself must carry the legend label for this mode.
                    line_kwargs["label"] = mode
                ax.plot(
                    x,
                    ys,
                    color=palette.colors[i],
                    linestyle=linestyles[i % len(linestyles)],
                    linewidth=1.4 * args.font_scale,
                    alpha=0.6,
                    zorder=2,
                    **line_kwargs,
                )
            if draw_points:
                ax.scatter(
                    x,
                    ys,
                    s=60 * args.font_scale,
                    color=palette.colors[i],
                    label=mode,
                    alpha=0.85,
                    edgecolors="white",
                    linewidths=0.5 * args.font_scale,
                    zorder=3,
                )
144225
145226 # Overlay router per-category metric as diamonds, if provided
146227 if s_router is not None :
@@ -153,24 +234,95 @@ def plot_metric(metric: str, out_path: Path):
153234 router_x .append (idx )
154235 router_vals .append (v )
155236 if router_vals :
237+ # Connect router points with a line and draw larger diamond markers
238+ ax .plot (
239+ router_x ,
240+ router_vals ,
241+ color = "tab:red" ,
242+ linestyle = "-" ,
243+ linewidth = 2.0 * args .font_scale ,
244+ alpha = 0.85 ,
245+ zorder = 4 ,
246+ )
156247 ax .scatter (
157248 router_x ,
158249 router_vals ,
159- s = 50 ,
250+ s = 90 * args . font_scale ,
160251 color = "tab:red" ,
161252 marker = "D" ,
162253 label = "router" ,
163- zorder = 4 ,
254+ zorder = 5 ,
255+ edgecolors = "white" ,
256+ linewidths = 0.6 * args .font_scale ,
164257 )
165258
166259 ax .set_xticks (list (x ))
167- ax .set_xticklabels (cats , rotation = 60 , ha = "right" )
260+ ax .set_xticklabels (
261+ cats ,
262+ rotation = args .xtick_rotation ,
263+ ha = "right" ,
264+ fontsize = int (14 * args .font_scale ),
265+ )
266+ # Control horizontal fit by expanding/shrinking x-limits around the first/last category
267+ if len (cats ) > 0 :
268+ n = len (cats )
269+ # Base categorical extents
270+ base_left = - 0.5
271+ base_right = n - 0.5
272+ # Apply outward expansion first, then inward margin
273+ expand = max (0.0 , float (args .side_expand ))
274+ max_margin = 0.49
275+ margin = max (0.0 , min (float (args .side_margin ), max_margin ))
276+ left_xlim = base_left - expand + margin
277+ right_xlim = base_right + expand - margin
278+ if right_xlim > left_xlim :
279+ ax .set_xlim (left_xlim , right_xlim )
168280 ylabel = metric .replace ("_" , " " )
169- ax .set_ylabel (ylabel )
170- ax .set_title (f"Per-category { ylabel } per-mode values" )
171- ax .legend (ncol = 3 , fontsize = 8 )
172- plt .tight_layout ()
173- plt .savefig (out_path , dpi = 200 , bbox_inches = "tight" )
281+ ax .set_ylabel (ylabel , fontsize = int (18 * args .font_scale ))
282+ ax .set_title (f"Per-category { ylabel } per-mode values" , fontsize = int (22 * args .font_scale ))
283+ ax .tick_params (axis = "both" , which = "major" , labelsize = int (14 * args .font_scale ))
284+
285+ # Build a figure-level legend below the axes and reserve space to prevent overlap
286+ handles , labels = ax .get_legend_handles_labels ()
287+ if handles :
288+ num_series = len (handles )
289+ # Force exactly 2 legend rows; compute columns accordingly
290+ legend_rows = 2
291+ legend_ncol = max (1 , (num_series + legend_rows - 1 ) // legend_rows )
292+ num_rows = legend_rows
293+ scale = (args .font_scale / 1.6 )
294+ # Reserve generous space for long rotated tick labels and multi-row legend
295+ bottom_reserved = (0.28 * scale ) + (0.12 * num_rows * scale )
296+ bottom_reserved = max (0.24 , min (0.60 , bottom_reserved ))
297+ fig .subplots_adjust (left = 0.01 , right = 0.999 , top = 0.92 , bottom = bottom_reserved )
298+ # Align the legend box width with the axes width
299+ pos = ax .get_position ()
300+ fig .legend (
301+ handles ,
302+ labels ,
303+ loc = "lower left" ,
304+ bbox_to_anchor = (pos .x0 , 0.02 , pos .width , 0.001 ),
305+ bbox_transform = fig .transFigure ,
306+ ncol = legend_ncol ,
307+ mode = "expand" ,
308+ fontsize = int (14 * args .font_scale ),
309+ markerscale = 1.6 * args .font_scale ,
310+ frameon = False ,
311+ borderaxespad = 0.0 ,
312+ columnspacing = 0.8 * args .font_scale ,
313+ handlelength = 2.2 ,
314+ )
315+ else :
316+ fig .subplots_adjust (left = 0.01 , right = 0.999 , top = 0.92 , bottom = 0.14 )
317+ ax .grid (axis = "y" , linestyle = ":" , linewidth = 0.8 , alpha = 0.3 )
318+ # Eliminate additional automatic horizontal padding
319+ ax .margins (x = 0.0 )
320+ # Layout handled via subplots_adjust above to avoid legend overlap
321+ # Save both PNG and PDF variants
322+ png_path = out_path .with_suffix (".png" )
323+ pdf_path = out_path .with_suffix (".pdf" )
324+ plt .savefig (png_path , dpi = int (args .dpi ), bbox_inches = "tight" )
325+ plt .savefig (pdf_path , bbox_inches = "tight" )
174326 plt .close (fig )
175327
176328
0 commit comments