Skip to content

Commit a5068f4

Browse files
committed
Allow to change axes from ArrayImagePlot
1 parent 731158f commit a5068f4

File tree

3 files changed

+84
-73
lines changed

3 files changed

+84
-73
lines changed

src/silx/gui/data/DataViews.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import numpy
2929
import os
3030

31+
from silx.gui.data.NXdataWidgets import ArrayImagePlot
3132
import silx.io
3233
from silx.gui import qt, icons
3334
from silx.gui.data.TextFormatter import TextFormatter
@@ -1763,27 +1764,15 @@ def setData(self, data):
17631764

17641765
self._updateColormap(nxd)
17651766

1766-
# last two axes are Y & X
1767-
img_slicing = slice(-2, None) if not isRgba else slice(-3, -1)
1768-
y_axis, x_axis = nxd.axes[img_slicing]
1769-
y_label, x_label = nxd.axes_names[img_slicing]
1770-
y_scale, x_scale = nxd.plot_style.axes_scale_types[img_slicing]
1771-
x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
1772-
y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None
1773-
1774-
self.getWidget().setImageData(
1767+
widget: ArrayImagePlot = self.getWidget()
1768+
widget.setImageData(
17751769
[nxd.signal] + nxd.auxiliary_signals,
1776-
x_axis=x_axis,
1777-
y_axis=y_axis,
1770+
axes=nxd.axes,
17781771
signals_names=[nxd.signal_name] + nxd.auxiliary_signals_names,
17791772
axes_names=nxd.axes_names,
1780-
xlabel=x_label,
1781-
ylabel=y_label,
1773+
axes_scales=nxd.plot_style.axes_scale_types,
17821774
title=nxd.title,
17831775
isRgba=isRgba,
1784-
xscale=x_scale,
1785-
yscale=y_scale,
1786-
keep_ratio=(x_units == y_units),
17871776
)
17881777

17891778
def getDataPriority(self, data, info: DataInfo):

src/silx/gui/data/NXdataWidgets.py

Lines changed: 71 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
__date__ = "12/11/2018"
2828

2929
import logging
30+
from typing import Literal
3031
import numpy
32+
import h5py
3133

3234
from silx.gui import qt
3335
from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
@@ -38,6 +40,7 @@
3840
from silx.gui.colors import Colormap
3941
from silx.gui.data._SignalSelector import SignalSelector
4042

43+
from silx.io.nxdata._utils import get_attr_as_unicode
4144
from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration
4245

4346

@@ -388,10 +391,6 @@ def __init__(self, parent=None):
388391

389392
self.__signals = None
390393
self.__signals_names = None
391-
self.__x_axis = None
392-
self.__x_axis_name = None
393-
self.__y_axis = None
394-
self.__y_axis_name = None
395394

396395
self._plot = Plot2D(self)
397396
self._plot.setDefaultColormap(
@@ -404,10 +403,9 @@ def __init__(self, parent=None):
404403
maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
405404
maskToolWidget.setItemMaskUpdated(True)
406405

407-
# not closable
408406
self._axesSelector = NumpyAxesSelector(self)
409-
self._axesSelector.setNamedAxesSelectorVisibility(False)
410407
self._axesSelector.selectionChanged.connect(self._updateImage)
408+
self._axesSelector.selectedAxisChanged.connect(self._updateImageAxes)
411409

412410
self._signalSelector = SignalSelector(parent=self)
413411
self._signalSelector.selectionChanged.connect(self._signalChanges)
@@ -442,7 +440,7 @@ def _aggregationModeChanged(self):
442440
)
443441

444442
def _signalChanges(self, value):
445-
self._updateImage()
443+
self._updateImageAxes()
446444

447445
def getPlot(self):
448446
"""Returns the plot used for the display
@@ -453,48 +451,33 @@ def getPlot(self):
453451

454452
def setImageData(
455453
self,
456-
signals,
457-
x_axis=None,
458-
y_axis=None,
459-
signals_names=None,
460-
axes_names=None,
461-
xlabel=None,
462-
ylabel=None,
463-
title=None,
464-
isRgba=False,
465-
xscale=None,
466-
yscale=None,
467-
keep_ratio: bool = True,
454+
signals: list[h5py.Dataset],
455+
axes: list[h5py.Dataset] | None = None,
456+
signals_names: list[str] | None = None,
457+
axes_names: list[str] | None = None,
458+
axes_scales: list[Literal["linear", "log"] | None] | None = None,
459+
title: str | None = None,
460+
isRgba: bool = False,
468461
):
469462
"""
470463
471-
:param signals: list of n-D datasets, whose last 2 dimensions are used as the
472-
image's values, or list of 3D datasets interpreted as RGBA image.
473-
:param x_axis: 1-D dataset used as the image's x coordinates. If
474-
provided, its lengths must be equal to the length of the last
475-
dimension of ``signal``.
476-
:param y_axis: 1-D dataset used as the image's y. If provided,
477-
its lengths must be equal to the length of the 2nd to last
478-
dimension of ``signal``.
464+
:param signals: list of n-D datasets or list of 3D datasets interpreted as RGBA image.
465+
:param axes: list of 1D datasets to be used as axes
479466
:param signals_names: Names for each image, used as subtitle and legend.
480-
:param xlabel: Label for X axis
481-
:param ylabel: Label for Y axis
467+
:param axes_names: Names for each axis, used as graph label.
468+
:param axes_scales: Scale of axes in (None, 'linear', 'log')
482469
:param title: Graph title
483470
:param isRgba: True if data is a 3D RGBA image
484-
:param str xscale: Scale of X axis in (None, 'linear', 'log')
485-
:param str yscale: Scale of Y axis in (None, 'linear', 'log')
486-
:param keep_ratio: Toggle plot keep aspect ratio
487471
"""
488472
self._axesSelector.selectionChanged.disconnect(self._updateImage)
473+
self._axesSelector.selectedAxisChanged.disconnect(self._updateImageAxes)
489474
self._signalSelector.selectionChanged.disconnect(self._signalChanges)
490475

491476
self.__signals = signals
492477
self.__signals_names = signals_names
493-
self.__axis_names = axes_names
494-
self.__x_axis = x_axis
495-
self.__x_axis_name = xlabel
496-
self.__y_axis = y_axis
497-
self.__y_axis_name = ylabel
478+
self.__axes = axes
479+
self.__axes_names = axes_names
480+
self.__axes_scales = axes_scales
498481
self.__title = title
499482

500483
self._axesSelector.clear()
@@ -511,8 +494,8 @@ def setImageData(
511494
else:
512495
self._axesSelector.show()
513496

514-
if self.__axis_names:
515-
self._axesSelector.setLabels(self.__axis_names)
497+
if self.__axes_names:
498+
self._axesSelector.setLabels(self.__axes_names)
516499

517500
self._signalSelector.setSignalNames(signals_names)
518501
if len(signals) > 1:
@@ -521,16 +504,14 @@ def setImageData(
521504
self._signalSelector.hide()
522505
self._signalSelector.setSignalIndex(0)
523506

524-
self._axis_scales = xscale, yscale
525-
526507
self._axesSelector.selectionChanged.connect(self._updateImage)
508+
self._axesSelector.selectedAxisChanged.connect(self._updateImageAxes)
527509
self._signalSelector.selectionChanged.connect(self._signalChanges)
528510

529-
self._updateImage()
530-
self._plot.setKeepDataAspectRatio(keep_ratio)
511+
self._updateImageAxes()
531512
self._plot.resetZoom()
532513

533-
def _updateImage(self):
514+
def _updateImageAxes(self):
534515
axes_selection = self._axesSelector.selection()
535516
signal_index = self._signalSelector.getSignalIndex()
536517

@@ -539,8 +520,21 @@ def _updateImage(self):
539520
images = [img[axes_selection] for img in self.__signals]
540521
image = images[signal_index]
541522

542-
x_axis = self.__x_axis
543-
y_axis = self.__y_axis
523+
axis_indices = self._axesSelector.getIndicesOfNamedAxes()
524+
x_axis_index = axis_indices["X"]
525+
y_axis_index = axis_indices["Y"]
526+
527+
if self.__axes:
528+
x_axis = self.__axes[x_axis_index]
529+
y_axis = self.__axes[y_axis_index]
530+
x_units = get_attr_as_unicode(x_axis, "units") if x_axis else None
531+
y_units = get_attr_as_unicode(y_axis, "units") if y_axis else None
532+
else:
533+
x_axis = None
534+
y_axis = None
535+
x_units = None
536+
y_units = None
537+
self._plot.setKeepDataAspectRatio(x_units == y_units)
544538

545539
if x_axis is None and y_axis is None:
546540
xcalib = NoCalibration()
@@ -607,7 +601,12 @@ def _updateImage(self):
607601
self._plot.addItem(imageItem)
608602
self._plot.setActiveImage(imageItem)
609603
else:
610-
xaxisscale, yaxisscale = self._axis_scales
604+
if self.__axes_scales:
605+
xaxisscale = self.__axes_scales[x_axis_index]
606+
yaxisscale = self.__axes_scales[y_axis_index]
607+
else:
608+
xaxisscale = None
609+
yaxisscale = None
611610

612611
if xaxisscale is not None:
613612
self._plot.getXAxis().setScale(
@@ -627,23 +626,38 @@ def _updateImage(self):
627626
legend=legend,
628627
)
629628

630-
if self.__title:
631-
title = self.__title
632-
if len(self.__signals_names) > 1:
633-
# Append dataset name only when there is many datasets
634-
title += "\n" + self.__signals_names[signal_index]
635-
else:
636-
title = self.__signals_names[signal_index]
637-
self._plot.setGraphTitle(title)
638-
self._plot.getXAxis().setLabel(self.__x_axis_name)
639-
self._plot.getYAxis().setLabel(self.__y_axis_name)
629+
self._plot.setGraphTitle(self._graphTitle())
630+
self._plot.getXAxis().setLabel(self.__axes_names[x_axis_index])
631+
self._plot.getYAxis().setLabel(self.__axes_names[y_axis_index])
632+
self._plot.resetZoom()
640633

641634
def clear(self):
642635
old = self._axesSelector.blockSignals(True)
643636
self._axesSelector.clear()
644637
self._axesSelector.blockSignals(old)
645638
self._plot.clear()
646639

640+
def _updateImage(self):
641+
axes_selection = self._axesSelector.selection()
642+
signal_index = self._signalSelector.getSignalIndex()
643+
images = [img[axes_selection] for img in self.__signals]
644+
image = images[signal_index]
645+
646+
self._plot.getActiveImage().setData(image)
647+
648+
def _graphTitle(self):
649+
signal_index = self._signalSelector.getSignalIndex()
650+
if not self.__title:
651+
if not self.__signals_names:
652+
return ""
653+
return self.__signals_names[signal_index]
654+
655+
title = self.__title
656+
if self.__signals_names and len(self.__signals_names) > 1:
657+
# Append dataset name only when there are many datasets
658+
title += "\n" + self.__signals_names[signal_index]
659+
return title
660+
647661

648662
class ArrayComplexImagePlot(qt.QWidget):
649663
"""

src/silx/gui/data/NumpyAxesSelector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,11 @@ def setNamedAxesSelectorVisibility(self, visible):
595595
self.__namedAxesVisibility = visible
596596
for axis in self.__axis:
597597
axis.setNamedAxisSelectorVisibility(visible)
598+
599+
def getIndicesOfNamedAxes(self) -> dict[str, int]:
600+
result: dict[str, int] = {}
601+
for i, axis in enumerate(self.__axis):
602+
name = axis.axisName()
603+
if name:
604+
result[name] = i
605+
return result

0 commit comments

Comments
 (0)