3333    default = Path ("." ),
3434    help = "Directory to save plots (default: current directory)" ,
3535)
36+ parser .add_argument (
37+     "--font-scale" ,
38+     type = float ,
39+     default = 1.6 ,
40+     help = "Scale factor for fonts and markers (default: 1.6)" ,
41+ )
42+ parser .add_argument (
43+     "--dpi" ,
44+     type = int ,
45+     default = 320 ,
46+     help = "PNG export DPI (default: 320)" ,
47+ )
48+ parser .add_argument (
49+     "--style" ,
50+     type = str ,
51+     choices = ["points" , "lines" , "both" ],
52+     default = "points" ,
53+     help = "Plot style for modes: points, lines, or both (default: points)" ,
54+ )
55+ parser .add_argument (
56+     "--max-modes" ,
57+     type = int ,
58+     default = None ,
59+     help = "If set, plot only the top N modes by mean of the current metric" ,
60+ )
61+ parser .add_argument (
62+     "--xtick-rotation" ,
63+     type = float ,
64+     default = 75.0 ,
65+     help = "Rotation angle for x tick labels (default: 75)" ,
66+ )
67+ parser .add_argument (
68+     "--side-margin" ,
69+     type = float ,
70+     default = 0.0 ,
71+     help = "Shrink x-limits inward by this many x units per side (default: 0)" ,
72+ )
73+ parser .add_argument (
74+     "--side-expand" ,
75+     type = float ,
76+     default = 0.25 ,
77+     help = "Expand x-limits outward by this many x units per side (default: 0.25)" ,
78+ )
3679args  =  parser .parse_args ()
3780summary_path  =  args .summary 
3881
@@ -128,19 +171,57 @@ def _mm(values):
128171
129172
130173def  plot_metric (metric : str , out_path : Path ):
131-     fig , ax  =  plt .subplots (figsize = (14 ,  6 ))
174+     fig , ax  =  plt .subplots (figsize = (18 ,  8 ))
132175
133176    x  =  range (len (cats ))
134177
135-     # Overlay each mode as points  
178+     # Determine modes to plot, optionally limiting to top-N by mean of metric  
136179    all_modes  =  sorted ({m  for  c  in  cats  for  m  in  cat_by_mode .get (c , {}).keys ()})
137180    if  len (all_modes ) >  0 :
181+         def  _mean (values ):
182+             vals  =  [v  for  v  in  values  if  v  is  not None ]
183+             return  sum (vals ) /  len (vals ) if  vals  else  float ("nan" )
184+ 
185+         if  args .max_modes  is  not None  and  args .max_modes  >  0  and  len (all_modes ) >  args .max_modes :
186+             mode_means  =  []
187+             for  mode  in  all_modes :
188+                 vals  =  [cat_by_mode .get (c , {}).get (mode , {}).get (metric ) for  c  in  cats ]
189+                 mode_means .append ((mode , _mean (vals )))
190+             # Accuracy: higher is better; latency/tokens: lower is better 
191+             ascending  =  metric  !=  "accuracy" 
192+             mode_means  =  sorted (
193+                 mode_means ,
194+                 key = lambda  kv : (float ("inf" ) if  (kv [1 ] !=  kv [1 ]) else  kv [1 ]),
195+                 reverse = not  ascending ,
196+             )
197+             all_modes  =  [m  for  m , _  in  mode_means [: args .max_modes ]]
198+ 
138199        palette  =  colormaps .get_cmap ("tab10" ).resampled (len (all_modes ))
200+         linestyles  =  ["-" , "--" , "-." , ":" ]
139201        for  i , mode  in  enumerate (all_modes ):
140-             ys  =  []
141-             for  c  in  cats :
142-                 ys .append (cat_by_mode .get (c , {}).get (mode , {}).get (metric ))
143-             ax .scatter (x , ys , s = 20 , color = palette .colors [i ], label = mode , alpha = 0.8 )
202+             ys  =  [cat_by_mode .get (c , {}).get (mode , {}).get (metric ) for  c  in  cats ]
203+             if  args .style  in  ("lines" , "both" ):
204+                 ax .plot (
205+                     x ,
206+                     ys ,
207+                     color = palette .colors [i ],
208+                     linestyle = linestyles [i  %  len (linestyles )],
209+                     linewidth = 1.4  *  args .font_scale ,
210+                     alpha = 0.6 ,
211+                     zorder = 2 ,
212+                 )
213+             if  args .style  in  ("points" , "both" ):
214+                 ax .scatter (
215+                     x ,
216+                     ys ,
217+                     s = 60  *  args .font_scale ,
218+                     color = palette .colors [i ],
219+                     label = mode ,
220+                     alpha = 0.85 ,
221+                     edgecolors = "white" ,
222+                     linewidths = 0.5  *  args .font_scale ,
223+                     zorder = 3 ,
224+                 )
144225
145226    # Overlay router per-category metric as diamonds, if provided 
146227    if  s_router  is  not None :
@@ -153,24 +234,95 @@ def plot_metric(metric: str, out_path: Path):
153234                router_x .append (idx )
154235                router_vals .append (v )
155236        if  router_vals :
237+             # Connect router points with a line and draw larger diamond markers 
238+             ax .plot (
239+                 router_x ,
240+                 router_vals ,
241+                 color = "tab:red" ,
242+                 linestyle = "-" ,
243+                 linewidth = 2.0  *  args .font_scale ,
244+                 alpha = 0.85 ,
245+                 zorder = 4 ,
246+             )
156247            ax .scatter (
157248                router_x ,
158249                router_vals ,
159-                 s = 50 ,
250+                 s = 90   *   args . font_scale ,
160251                color = "tab:red" ,
161252                marker = "D" ,
162253                label = "router" ,
163-                 zorder = 4 ,
254+                 zorder = 5 ,
255+                 edgecolors = "white" ,
256+                 linewidths = 0.6  *  args .font_scale ,
164257            )
165258
166259    ax .set_xticks (list (x ))
167-     ax .set_xticklabels (cats , rotation = 60 , ha = "right" )
260+     ax .set_xticklabels (
261+         cats ,
262+         rotation = args .xtick_rotation ,
263+         ha = "right" ,
264+         fontsize = int (14  *  args .font_scale ),
265+     )
266+     # Control horizontal fit by expanding/shrinking x-limits around the first/last category 
267+     if  len (cats ) >  0 :
268+         n  =  len (cats )
269+         # Base categorical extents 
270+         base_left  =  - 0.5 
271+         base_right  =  n  -  0.5 
272+         # Apply outward expansion first, then inward margin 
273+         expand  =  max (0.0 , float (args .side_expand ))
274+         max_margin  =  0.49 
275+         margin  =  max (0.0 , min (float (args .side_margin ), max_margin ))
276+         left_xlim  =  base_left  -  expand  +  margin 
277+         right_xlim  =  base_right  +  expand  -  margin 
278+         if  right_xlim  >  left_xlim :
279+             ax .set_xlim (left_xlim , right_xlim )
168280    ylabel  =  metric .replace ("_" , " " )
169-     ax .set_ylabel (ylabel )
170-     ax .set_title (f"Per-category { ylabel }  )
171-     ax .legend (ncol = 3 , fontsize = 8 )
172-     plt .tight_layout ()
173-     plt .savefig (out_path , dpi = 200 , bbox_inches = "tight" )
281+     ax .set_ylabel (ylabel , fontsize = int (18  *  args .font_scale ))
282+     ax .set_title (f"Per-category { ylabel }  , fontsize = int (22  *  args .font_scale ))
283+     ax .tick_params (axis = "both" , which = "major" , labelsize = int (14  *  args .font_scale ))
284+ 
285+     # Build a figure-level legend below the axes and reserve space to prevent overlap 
286+     handles , labels  =  ax .get_legend_handles_labels ()
287+     if  handles :
288+         num_series  =  len (handles )
289+         # Force exactly 2 legend rows; compute columns accordingly 
290+         legend_rows  =  2 
291+         legend_ncol  =  max (1 , (num_series  +  legend_rows  -  1 ) //  legend_rows )
292+         num_rows  =  legend_rows 
293+         scale  =  (args .font_scale  /  1.6 )
294+         # Reserve generous space for long rotated tick labels and multi-row legend 
295+         bottom_reserved  =  (0.28  *  scale ) +  (0.12  *  num_rows  *  scale )
296+         bottom_reserved  =  max (0.24 , min (0.60 , bottom_reserved ))
297+         fig .subplots_adjust (left = 0.01 , right = 0.999 , top = 0.92 , bottom = bottom_reserved )
298+         # Align the legend box width with the axes width 
299+         pos  =  ax .get_position ()
300+         fig .legend (
301+             handles ,
302+             labels ,
303+             loc = "lower left" ,
304+             bbox_to_anchor = (pos .x0 , 0.02 , pos .width , 0.001 ),
305+             bbox_transform = fig .transFigure ,
306+             ncol = legend_ncol ,
307+             mode = "expand" ,
308+             fontsize = int (14  *  args .font_scale ),
309+             markerscale = 1.6  *  args .font_scale ,
310+             frameon = False ,
311+             borderaxespad = 0.0 ,
312+             columnspacing = 0.8  *  args .font_scale ,
313+             handlelength = 2.2 ,
314+         )
315+     else :
316+         fig .subplots_adjust (left = 0.01 , right = 0.999 , top = 0.92 , bottom = 0.14 )
317+     ax .grid (axis = "y" , linestyle = ":" , linewidth = 0.8 , alpha = 0.3 )
318+     # Eliminate additional automatic horizontal padding 
319+     ax .margins (x = 0.0 )
320+     # Layout handled via subplots_adjust above to avoid legend overlap 
321+     # Save both PNG and PDF variants 
322+     png_path  =  out_path .with_suffix (".png" )
323+     pdf_path  =  out_path .with_suffix (".pdf" )
324+     plt .savefig (png_path , dpi = int (args .dpi ), bbox_inches = "tight" )
325+     plt .savefig (pdf_path , bbox_inches = "tight" )
174326    plt .close (fig )
175327
176328
0 commit comments