Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13494.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :meth:`mne.channels.DigMontage.plot` would error when ``axes`` was passed by `Christian O'Reilly`_.
14 changes: 8 additions & 6 deletions mne/viz/montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,19 @@ def plot_montage(
)

if scale != 1.0:
axes = axes if axes else fig.axes[0]

# scale points
collection = fig.axes[0].collections[0]
collection = axes.collections[0]
collection.set_sizes([scale * 10])

# scale labels
labels = fig.findobj(match=plt.Text)
x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
z_label = fig.axes[0].zaxis.label if kind == "3d" else None
tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
labels = axes.findobj(match=plt.Text)
x_label, y_label = axes.xaxis.label, axes.yaxis.label
z_label = axes.zaxis.label if kind == "3d" else None
tick_labels = axes.get_xticklabels() + axes.get_yticklabels()
if kind == "3d":
tick_labels += fig.axes[0].get_zticklabels()
tick_labels += axes.get_zticklabels()
for label in labels:
if label not in [x_label, y_label, z_label] + tick_labels:
label.set_fontsize(label.get_fontsize() * scale)
Expand Down
15 changes: 15 additions & 0 deletions mne/viz/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import numpy as np
import pytest

from mne import create_info
from mne.channels import make_dig_montage, make_standard_montage, read_dig_fif
from mne.io import RawArray

p_dir = Path(__file__).parents[2] / "io" / "kit" / "tests" / "data"
elp = p_dir / "test_elp.txt"
Expand Down Expand Up @@ -86,3 +88,16 @@ def test_plot_digmontage():
)
montage.plot()
plt.close("all")


def test_plot_montage_scale():
"""Test montage.plot with non-default scale using subplot axes."""
montage = make_standard_montage("GSN-HydroCel-129")
ax = plt.subplots(2, 1)[1][1]
picks = montage.ch_names
info = create_info(montage.ch_names, sfreq=256, ch_types="eeg")
raw = RawArray(
np.zeros((len(montage.ch_names), 1)), info, copy=None, verbose=False
).set_montage(montage)
# test for gh-13438
raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)
Loading