|
| 1 | +# Authors: Scott Huberty <[email protected]> |
| 2 | +# |
| 3 | +# License: BSD-3-Clause |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from scipy.ndimage import gaussian_filter |
| 7 | + |
| 8 | +from ...utils import _ensure_int, _validate_type, fill_doc, logger |
| 9 | +from ..utils import plt_show |
| 10 | + |
| 11 | + |
| 12 | +@fill_doc |
| 13 | +def plot_gaze( |
| 14 | + epochs, |
| 15 | + width, |
| 16 | + height, |
| 17 | + *, |
| 18 | + sigma=25, |
| 19 | + cmap=None, |
| 20 | + alpha=1.0, |
| 21 | + vlim=(None, None), |
| 22 | + axes=None, |
| 23 | + show=True, |
| 24 | +): |
| 25 | + """Plot a heatmap of eyetracking gaze data. |
| 26 | +
|
| 27 | + Parameters |
| 28 | + ---------- |
| 29 | + epochs : instance of Epochs |
| 30 | + The :class:`~mne.Epochs` object containing eyegaze channels. |
| 31 | + width : int |
| 32 | + The width dimension of the plot canvas. For example, if the eyegaze data units |
| 33 | + are pixels, and the participant screen resolution was 1920x1080, then the width |
| 34 | + should be 1920. |
| 35 | + height : int |
| 36 | + The height dimension of the plot canvas. For example, if the eyegaze data units |
| 37 | + are pixels, and the participant screen resolution was 1920x1080, then the height |
| 38 | + should be 1080. |
| 39 | + sigma : float | None |
| 40 | + The amount of Gaussian smoothing applied to the heatmap data (standard |
| 41 | + deviation in pixels). If ``None``, no smoothing is applied. Default is 25. |
| 42 | + %(cmap)s |
| 43 | + alpha : float |
| 44 | + The opacity of the heatmap (default is 1). |
| 45 | + %(vlim_plot_topomap)s |
| 46 | + %(axes_plot_topomap)s |
| 47 | + %(show)s |
| 48 | +
|
| 49 | + Returns |
| 50 | + ------- |
| 51 | + fig : instance of Figure |
| 52 | + The resulting figure object for the heatmap plot. |
| 53 | +
|
| 54 | + Notes |
| 55 | + ----- |
| 56 | + .. versionadded:: 1.6 |
| 57 | + """ |
| 58 | + from mne import BaseEpochs |
| 59 | + from mne._fiff.pick import _picks_to_idx |
| 60 | + |
| 61 | + _validate_type(epochs, BaseEpochs, "epochs") |
| 62 | + _validate_type(alpha, "numeric", "alpha") |
| 63 | + _validate_type(sigma, ("numeric", None), "sigma") |
| 64 | + width = _ensure_int(width, "width") |
| 65 | + height = _ensure_int(height, "height") |
| 66 | + |
| 67 | + pos_picks = _picks_to_idx(epochs.info, "eyegaze") |
| 68 | + gaze_data = epochs.get_data(picks=pos_picks) |
| 69 | + gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks]) |
| 70 | + x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :] |
| 71 | + y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :] |
| 72 | + |
| 73 | + if x_data.shape[1] > 1: # binocular recording. Average across eyes |
| 74 | + logger.info("Detected binocular recording. Averaging positions across eyes.") |
| 75 | + x_data = np.nanmean(x_data, axis=1) # shape (n_epochs, n_samples) |
| 76 | + y_data = np.nanmean(y_data, axis=1) |
| 77 | + canvas = np.vstack((x_data.flatten(), y_data.flatten())) # shape (2, n_samples) |
| 78 | + |
| 79 | + # Create 2D histogram |
| 80 | + # Bin into image-like format |
| 81 | + hist, _, _ = np.histogram2d( |
| 82 | + canvas[1, :], |
| 83 | + canvas[0, :], |
| 84 | + bins=(height, width), |
| 85 | + range=[[0, height], [0, width]], |
| 86 | + ) |
| 87 | + # Convert density from samples to seconds |
| 88 | + hist /= epochs.info["sfreq"] |
| 89 | + # Smooth the heatmap |
| 90 | + if sigma: |
| 91 | + hist = gaussian_filter(hist, sigma=sigma) |
| 92 | + |
| 93 | + return _plot_heatmap_array( |
| 94 | + hist, |
| 95 | + width=width, |
| 96 | + height=height, |
| 97 | + cmap=cmap, |
| 98 | + alpha=alpha, |
| 99 | + vmin=vlim[0], |
| 100 | + vmax=vlim[1], |
| 101 | + axes=axes, |
| 102 | + show=show, |
| 103 | + ) |
| 104 | + |
| 105 | + |
| 106 | +def _plot_heatmap_array( |
| 107 | + data, |
| 108 | + width, |
| 109 | + height, |
| 110 | + cmap=None, |
| 111 | + alpha=None, |
| 112 | + vmin=None, |
| 113 | + vmax=None, |
| 114 | + axes=None, |
| 115 | + show=True, |
| 116 | +): |
| 117 | + """Plot a heatmap of eyetracking gaze data from a numpy array.""" |
| 118 | + import matplotlib.pyplot as plt |
| 119 | + |
| 120 | + # Prepare axes |
| 121 | + if axes is not None: |
| 122 | + from matplotlib.axes import Axes |
| 123 | + |
| 124 | + _validate_type(axes, Axes, "axes") |
| 125 | + ax = axes |
| 126 | + fig = ax.get_figure() |
| 127 | + else: |
| 128 | + fig, ax = plt.subplots(constrained_layout=True) |
| 129 | + |
| 130 | + ax.set_title("Gaze heatmap") |
| 131 | + ax.set_xlabel("X position") |
| 132 | + ax.set_ylabel("Y position") |
| 133 | + |
| 134 | + # Prepare the heatmap |
| 135 | + alphas = 1 if alpha is None else alpha |
| 136 | + vmin = np.nanmin(data) if vmin is None else vmin |
| 137 | + vmax = np.nanmax(data) if vmax is None else vmax |
| 138 | + extent = [0, width, height, 0] # origin is the top left of the screen |
| 139 | + |
| 140 | + # Plot heatmap |
| 141 | + im = ax.imshow( |
| 142 | + data, |
| 143 | + aspect="equal", |
| 144 | + cmap=cmap, |
| 145 | + alpha=alphas, |
| 146 | + extent=extent, |
| 147 | + origin="upper", |
| 148 | + vmin=vmin, |
| 149 | + vmax=vmax, |
| 150 | + ) |
| 151 | + |
| 152 | + # Prepare the colorbar |
| 153 | + fig.colorbar(im, ax=ax, shrink=0.6, label="Dwell time (seconds)") |
| 154 | + plt_show(show) |
| 155 | + return fig |
0 commit comments