@@ -24,6 +24,7 @@ def __init__(
2424 after_subplot : Optional [Callable [[plt .Axes , str , str ], None ]] = None ,
2525 before_plots : Optional [Callable [[plt .Figure , np .ndarray , int ], None ]] = None ,
2626 after_plots : Optional [Callable [[plt .Figure ], None ]] = None ,
27+ figsize : Optional [Tuple [int , int ]] = None ,
2728 ):
2829 """
2930 Args:
@@ -36,6 +37,7 @@ def __init__(
3637 after_subplot: function which will be called after every subplot
3738 before_plots: function which will be called before all subplots
3839 after_plots: function which will be called after all subplots
40+ figsize: optional tuple to explicitly set figure size (overrides cell_size calculation)
3941 """
4042 self .cell_size = cell_size
4143 self .max_cols = max_cols
@@ -47,6 +49,7 @@ def __init__(
4749 self ._after_subplot = after_subplot if after_subplot else self ._default_after_subplot
4850 self ._before_plots = before_plots if before_plots else self ._default_before_plots
4951 self ._after_plots = after_plots if after_plots else self ._default_after_plots
52+ self .figsize = figsize
5053
5154 def send (self , logger : MainLogger ):
5255 """Draw figures with metrics and show"""
@@ -87,9 +90,12 @@ def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_gr
8790 num_of_log_groups: number of log groups
8891 """
8992 clear_output (wait = True )
90- figsize_x = self .max_cols * self .cell_size [0 ]
91- figsize_y = ((num_of_log_groups + 1 ) // self .max_cols + 1 ) * self .cell_size [1 ]
92- fig .set_size_inches (figsize_x , figsize_y )
93+ if self .figsize is not None :
94+ fig .set_size_inches (* self .figsize )
95+ else :
96+ figsize_x = self .max_cols * self .cell_size [0 ]
97+ figsize_y = ((num_of_log_groups + 1 ) // self .max_cols + 1 ) * self .cell_size [1 ]
98+ fig .set_size_inches (figsize_x , figsize_y )
9399 if num_of_log_groups < axes .size :
94100 for idx , ax in enumerate (axes [- 1 ]):
95101 if idx >= (num_of_log_groups + len (self .extra_plots )) % self .max_cols :
0 commit comments