Skip to content

Commit 02ce5d8

Browse files
authored
Merge pull request #8 from wiseodd/override-rcparams
Add defaults to `nrows` and `ncols`; add a way to override `rcParams`
2 parents 1a59953 + db32970 commit 02ce5d8

File tree

8 files changed

+143
-46
lines changed

8 files changed

+143
-46
lines changed

README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ import pub_ready_plots
5252
with pub_ready_plots.get_context(
5353
width_frac=1, # between 0 and 1
5454
height_frac=0.15, # between 0 and 1
55-
nrows=1, # depending on your subplots
56-
ncols=2, # depending on your subplots
5755
layout="icml", # or "iclr", "neurips", "poster-portrait", "poster-landscape"
5856
single_col=False, # only works for the "icml" layout
57+
nrows=1, # depending on your subplots, default = 1
58+
ncols=2, # depending on your subplots, default = 1
5959
sharey=True, # Additional keyword args for `plt.subplots`
60+
override_rc_params={"lines.linewidth": 4.123} # Overriding rcParams
6061
) as (fig, axs):
6162
# Do whatever you want with `fig` and `axs`
6263
...
@@ -74,9 +75,14 @@ Then in your LaTeX file, include the plot as follows:
7475
> [!IMPORTANT]
7576
> The argument `width=\linewidth` is **crucial**!
7677
78+
That's it! But you should use TikZ more.
79+
Anyway, see the full, runnable example in [`examples/simple_plot.py`](https://github.com/wiseodd/pub-ready-plots/blob/master/examples/simple_plot.py)
80+
7781
> [!TIP]
78-
> That's it! But you should use TikZ more.
79-
> Anyway, see the full, runnable example in [`examples/simple_plot.py`](https://github.com/wiseodd/pub-ready-plots/blob/master/examples/simple_plot.py)
82+
> I recommend using this library in conjunction with
83+
> [pypalettes]<https://github.com/JosephBARBIERDARNAL/pypalettes>
84+
> to avoid the generic blue-orange Matplotlib colors.
85+
> Distinguish your plots from others!
8086
8187
## Using your own styles
8288

examples/advanced_usage.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,37 @@
22

33
import numpy as np
44
import matplotlib.pyplot as plt
5+
from matplotlib.axes import Axes
56

7+
########################################################################################
8+
# User-specified rcParams override
9+
########################################################################################
10+
with pub_ready_plots.get_context(
11+
width_frac=1,
12+
height_frac=0.15,
13+
layout="iclr",
14+
override_rc_params={"lines.linewidth": 5}, # Pass your style overrides here!
15+
) as (fig, ax):
16+
assert isinstance(ax, Axes)
17+
18+
x = np.linspace(-1, 1, 100)
19+
20+
ax.plot(x, np.sin(x))
21+
ax.set_title("Sine")
22+
ax.set_xlabel(r"$x$")
23+
ax.set_ylabel(r"$\mathrm{sin}(x)$")
24+
25+
fig.savefig("advanced_usage_1.pdf")
26+
27+
28+
########################################################################################
29+
# Manual, most-flexible way to use this library
30+
########################################################################################
631
rc_params, fig_width_in, fig_height_in = pub_ready_plots.get_mpl_rcParams(
7-
width_frac=1, # between 0 and 1
8-
height_frac=0.15, # between 0 and 1
9-
layout="poster-portrait", # or "iclr", "neurips", "poster-portrait", "poster-landscape"
10-
single_col=False, # only works for the "icml" layout
32+
width_frac=1,
33+
height_frac=0.15,
34+
layout="poster-portrait",
35+
single_col=False,
1136
)
1237

1338
# You can update `rc_params` further before feeding it to `plt`, e.g.
@@ -32,4 +57,4 @@
3257
axs[1].set_xlabel(r"$x$")
3358
axs[1].set_ylabel(r"$\mathrm{cos}(x)$")
3459

35-
fig.savefig("advanced_usage.pdf")
60+
fig.savefig("advanced_usage_2.pdf")

examples/simple_plot.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,32 @@
11
import pub_ready_plots
2+
23
import numpy as np
4+
from matplotlib.axes import Axes
35

6+
########################################################################################
7+
# Single plot (i.e. no subplots)
8+
########################################################################################
9+
with pub_ready_plots.get_context(
10+
width_frac=1, # between 0 and 1
11+
height_frac=0.15, # between 0 and 1
12+
layout="iclr", # or "iclr", "neurips", "poster-portrait", "poster-landscape"
13+
) as (fig, ax):
14+
# Just like in `plt.subplots`, `ax` is a matplotlib Axes if
15+
# nrows & ncols are not specified (both default to 1).
16+
assert isinstance(ax, Axes)
17+
18+
x = np.linspace(-1, 1, 100)
19+
20+
ax.plot(x, np.sin(x))
21+
ax.set_title("Sine")
22+
ax.set_xlabel(r"$x$")
23+
ax.set_ylabel(r"$\mathrm{sin}(x)$")
24+
25+
fig.savefig("simple_plot.pdf")
26+
27+
########################################################################################
28+
# Multiple subplots
29+
########################################################################################
430
with pub_ready_plots.get_context(
531
width_frac=1, # between 0 and 1
632
height_frac=0.15, # between 0 and 1
@@ -10,6 +36,9 @@
1036
single_col=False, # only works for the "icml" layout
1137
sharey=True, # Additional keyword args for `plt.subplots`
1238
) as (fig, axs):
39+
# If `nrows` or `ncols` are not 1, `axs` is a NumPy array containing Axes'
40+
assert isinstance(axs, np.ndarray)
41+
1342
x = np.linspace(-1, 1, 100)
1443

1544
axs[0].plot(x, np.sin(x))
@@ -22,4 +51,4 @@
2251
axs[1].set_xlabel(r"$x$")
2352
axs[1].set_ylabel(r"$\mathrm{cos}(x)$")
2453

25-
fig.savefig("simple_plot.pdf")
54+
fig.savefig("simple_subplots.pdf")

pdm.lock

Lines changed: 22 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pub_ready_plots/pub_ready_plots.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
def get_context(
1414
width_frac: float,
1515
height_frac: float,
16-
nrows: int,
17-
ncols: int,
1816
layout: str = "neurips",
1917
single_col: bool = False,
18+
nrows: int = 1,
19+
ncols: int = 1,
20+
override_rc_params: dict[str, Any] = dict(),
2021
**kwargs,
2122
) -> Generator[tuple[Figure, Union[Axes, ndarray[Any, Any]]], None, None]:
2223
rc_params, fig_width_in, fig_height_in = get_mpl_rcParams(
2324
width_frac, height_frac, layout, single_col
2425
)
26+
rc_params.update(override_rc_params)
2527

2628
with plt.rc_context(rc_params):
2729
fig, axs = plt.subplots(nrows, ncols, constrained_layout=True, **kwargs)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "pub-ready-plots"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "Easy publication-ready matplotlib plots for ML papers and posters."
55
authors = [{ name = "Agustinus Kristiadi", email = "agustinus@kristia.de" }]
66
classifiers = [

tests/test_context.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,47 @@
1+
from matplotlib.axes import Axes
12
from pub_ready_plots.pub_ready_plots import get_context, get_mpl_rcParams
23
import numpy as np
34

45

5-
def test_correct_func():
6+
def test_single_subplot():
7+
with get_context(width_frac=0.5, height_frac=0.15, layout="iclr") as (fig, axs):
8+
real_rc_params, fig_width_in, fig_height_in = get_mpl_rcParams(
9+
width_frac=0.5, height_frac=0.15, layout="iclr"
10+
)
11+
12+
assert np.allclose(
13+
fig.get_size_inches(), (fig_width_in, fig_height_in), atol=0.001
14+
)
15+
assert isinstance(axs, Axes)
16+
17+
18+
def test_multi_subplots():
619
nrows, ncols = 3, 2
7-
with get_context(0.5, 0.15, nrows, ncols, "iclr") as (fig, axs):
20+
with get_context(
21+
width_frac=0.5, height_frac=0.15, nrows=nrows, ncols=ncols, layout="iclr"
22+
) as (fig, axs):
823
real_rc_params, fig_width_in, fig_height_in = get_mpl_rcParams(
9-
0.5, 0.15, "iclr"
24+
width_frac=0.5, height_frac=0.15, layout="iclr"
1025
)
1126

1227
assert np.allclose(
1328
fig.get_size_inches(), (fig_width_in, fig_height_in), atol=0.001
1429
)
1530
assert isinstance(axs, np.ndarray)
1631
assert axs.shape == (nrows, ncols)
32+
33+
34+
def test_override_rcparams():
35+
LINE_WIDTH: float = 8.2329232
36+
37+
with get_context(
38+
width_frac=0.5,
39+
height_frac=0.15,
40+
layout="iclr",
41+
override_rc_params={"lines.linewidth": LINE_WIDTH},
42+
) as (fig, ax):
43+
assert isinstance(ax, Axes)
44+
45+
x: np.ndarray = np.linspace(-1, 1, 100)
46+
obj = ax.plot(x, np.tanh(x))
47+
assert np.allclose(obj[0].get_linewidth(), LINE_WIDTH)

tests/test_core_func.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,36 @@
88

99
@pytest.mark.parametrize("layout", LAYOUTS)
1010
def test_correct_func(layout):
11-
rc_params, width_in, height_in = get_mpl_rcParams(1, 0.15, layout)
11+
rc_params, width_in, height_in = get_mpl_rcParams(
12+
width_frac=1, height_frac=0.15, layout=layout
13+
)
1214
plt.rcParams.update(rc_params)
1315

1416

1517
def test_incorrect_width():
1618
with pytest.raises(ValueError):
17-
_, _, _ = get_mpl_rcParams(12, 0.1, "iclr")
19+
_, _, _ = get_mpl_rcParams(width_frac=12, height_frac=0.1, layout="iclr")
1820

1921
with pytest.raises(ValueError):
20-
_, _, _ = get_mpl_rcParams(0, 0.1, "iclr")
22+
_, _, _ = get_mpl_rcParams(width_frac=0, height_frac=0.1, layout="iclr")
2123

2224
with pytest.raises(ValueError):
23-
_, _, _ = get_mpl_rcParams(-3.2, 0.1, "iclr")
25+
_, _, _ = get_mpl_rcParams(width_frac=-3.2, height_frac=0.1, layout="iclr")
2426

2527

2628
def test_incorrect_height():
2729
with pytest.raises(ValueError):
28-
_, _, _ = get_mpl_rcParams(0.15, 12, "iclr")
30+
_, _, _ = get_mpl_rcParams(width_frac=0.15, height_frac=12, layout="iclr")
2931

3032
with pytest.raises(ValueError):
31-
_, _, _ = get_mpl_rcParams(0.15, 0, "iclr")
33+
_, _, _ = get_mpl_rcParams(width_frac=0.15, height_frac=0, layout="iclr")
3234

3335
with pytest.raises(ValueError):
34-
_, _, _ = get_mpl_rcParams(0.15, -1.2, "iclr")
36+
_, _, _ = get_mpl_rcParams(width_frac=0.15, height_frac=-1.2, layout="iclr")
3537

3638

3739
def test_incorrect_layout():
3840
with pytest.raises(ValueError):
39-
_, _, _ = get_mpl_rcParams(1, 1, "predatory_journal")
41+
_, _, _ = get_mpl_rcParams(
42+
width_frac=0.5, height_frac=0.5, layout="predatory_journal"
43+
)

0 commit comments

Comments
 (0)