diff --git a/radiospectra/spectrum.py b/radiospectra/spectrum.py index c72fe4f..b8a33b4 100644 --- a/radiospectra/spectrum.py +++ b/radiospectra/spectrum.py @@ -26,6 +26,27 @@ class Spectrum(np.ndarray): def __new__(cls, data, *args, **kwargs): return np.asarray(data).view(cls) + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + new_inputs = [] + for x in inputs: + if isinstance(x, Spectrum): + new_inputs.append(x.view(np.ndarray)) + else: + new_inputs.append(x) + + result = getattr(ufunc, method)(*new_inputs, **kwargs) + + if isinstance(result, tuple): + return tuple(self._wrap_result(r) for r in result) + + return self._wrap_result(result) + + def _wrap_result(self, result): + if isinstance(result, np.ndarray): + result = result.view(Spectrum) + if hasattr(self, "freq_axis"): + result.freq_axis = self.freq_axis.copy() + return result def __init__(self, data, freq_axis): if np.shape(data)[0] != np.shape(freq_axis)[0]: diff --git a/radiospectra/tests/test_spectrum.py b/radiospectra/tests/test_spectrum.py index cc3ab40..358dfef 100644 --- a/radiospectra/tests/test_spectrum.py +++ b/radiospectra/tests/test_spectrum.py @@ -7,3 +7,27 @@ def test_spectrum(): spec = Spectrum(np.arange(10), np.arange(10)) np.testing.assert_equal(spec.data, np.arange(10)) np.testing.assert_equal(spec.freq_axis, np.arange(10)) +def test_freq_axis_preserved_after_ufunc(): + data = np.arange(10) + freq = np.linspace(100, 200, 10) + + spec = Spectrum(data, freq) + + new = np.sqrt(spec) + + assert isinstance(new, Spectrum) + assert hasattr(new, "freq_axis") + np.testing.assert_allclose(new.freq_axis, spec.freq_axis) + + +def test_freq_axis_preserved_after_binary_operation(): + data = np.arange(10) + freq = np.linspace(100, 200, 10) + + spec = Spectrum(data, freq) + + new = spec + 5 + + assert isinstance(new, Spectrum) + assert hasattr(new, "freq_axis") + np.testing.assert_allclose(new.freq_axis, spec.freq_axis)