Skip to content

Commit 84937cc

Browse files
committed
added figsize
1 parent 93aa155 commit 84937cc

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

livelossplot/outputs/matplotlib_plot.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None,
2525
before_plots: Optional[Callable[[plt.Figure, np.ndarray, int], None]] = None,
2626
after_plots: Optional[Callable[[plt.Figure], None]] = None,
27+
figsize: Optional[Tuple[int, int]] = None,
2728
):
2829
"""
2930
Args:
@@ -36,6 +37,7 @@ def __init__(
3637
after_subplot: function which will be called after every subplot
3738
before_plots: function which will be called before all subplots
3839
after_plots: function which will be called after all subplots
40+
figsize: optional tuple to explicitly set figure size (overrides cell_size calculation)
3941
"""
4042
self.cell_size = cell_size
4143
self.max_cols = max_cols
@@ -47,6 +49,7 @@ def __init__(
4749
self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
4850
self._before_plots = before_plots if before_plots else self._default_before_plots
4951
self._after_plots = after_plots if after_plots else self._default_after_plots
52+
self.figsize = figsize
5053

5154
def send(self, logger: MainLogger):
5255
"""Draw figures with metrics and show"""
@@ -87,9 +90,12 @@ def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_gr
8790
num_of_log_groups: number of log groups
8891
"""
8992
clear_output(wait=True)
90-
figsize_x = self.max_cols * self.cell_size[0]
91-
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
92-
fig.set_size_inches(figsize_x, figsize_y)
93+
if self.figsize is not None:
94+
fig.set_size_inches(*self.figsize)
95+
else:
96+
figsize_x = self.max_cols * self.cell_size[0]
97+
figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
98+
fig.set_size_inches(figsize_x, figsize_y)
9399
if num_of_log_groups < axes.size:
94100
for idx, ax in enumerate(axes[-1]):
95101
if idx >= (num_of_log_groups + len(self.extra_plots)) % self.max_cols:

livelossplot/plot_losses.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import warnings
2-
from typing import Type, TypeVar, List, Union
2+
from typing import Type, TypeVar, List, Union, Optional, Tuple
33

44
import livelossplot
55
from livelossplot.main_logger import MainLogger
66
from livelossplot import outputs
7+
from livelossplot.outputs.matplotlib_plot import MatplotlibPlot
78

89
BO = TypeVar('BO', bound=outputs.BaseOutput)
910

@@ -12,10 +13,12 @@ class PlotLosses:
1213
"""
1314
Class collect metrics from the training engine and send it to plugins, when send is called
1415
"""
16+
1517
def __init__(
1618
self,
1719
outputs: List[Union[Type[BO], str]] = ['MatplotlibPlot', 'ExtremaPrinter'],
1820
mode: str = 'notebook',
21+
figsize: Optional[Tuple[int, int]] = None,
1922
**kwargs
2023
):
2124
"""
@@ -24,12 +27,16 @@ def __init__(
2427
or strings for livelossplot built-in output methods with default parameters
2528
mode: Options: 'notebook' or 'script' - some of outputs need to change some behaviors,
2629
depending on the working environment
30+
figsize: tuple of (width, height) in inches for the figure
2731
**kwargs: key-arguments which are passed to MainLogger constructor
2832
"""
2933
self.logger = MainLogger(**kwargs)
3034
self.outputs = [getattr(livelossplot.outputs, out)() if isinstance(out, str) else out for out in outputs]
3135
for out in self.outputs:
3236
out.set_output_mode(mode)
37+
if figsize is not None and isinstance(out, MatplotlibPlot):
38+
print(f"Setting figsize to {figsize}")
39+
out.figsize = figsize
3340

3441
def update(self, *args, **kwargs):
3542
"""update logs with arguments that will be passed to main logger"""

0 commit comments

Comments
 (0)