Skip to content

Commit 5ccad84

Browse files
committed
fix: include resolution in stitched curve if present
1 parent ce78cb1 commit 5ccad84

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

src/ess/reflectometry/tools.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ def combine_curves(
261261
inv_v = 1.0 / v
262262
r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
263263
v_avg = 1 / np.nansum(inv_v, axis=0)
264-
return sc.DataArray(
264+
265+
out = sc.DataArray(
265266
data=sc.array(
266267
dims='Q',
267268
values=r_avg,
@@ -270,3 +271,25 @@ def combine_curves(
270271
),
271272
coords={'Q': qgrid},
272273
)
274+
if any('Q_resolution' in c.coords for c in curves):
275+
# This might need to be revisited. The question about how to combine curves
276+
# with different Q-resolution is not completely resolved.
277+
# However, in practice the difference in Q-resolution between different curves
278+
# is small so it's not likely to make a big difference.
279+
q_res = (
280+
sc.DataArray(
281+
data=c.coords.get(
282+
'Q_resolution', sc.scalar(float('nan')) * sc.values(c.data.copy())
283+
),
284+
coords={'Q': c.coords['Q']},
285+
)
286+
for c in curves
287+
)
288+
qs = _interpolate_on_qgrid(q_res, qgrid).values
289+
qs_avg = np.nansum(qs * inv_v, axis=0) / np.nansum(
290+
~np.isnan(qs) * inv_v, axis=0
291+
)
292+
out.coords['Q_resolution'] = sc.array(
293+
dims='Q', values=qs_avg, unit=next(iter(curves)).coords['Q_resolution'].unit
294+
)
295+
return out

tests/tools_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3+
import pytest
34
import scipp as sc
45
from scipp.testing import assert_allclose
56

@@ -126,3 +127,26 @@ def test_combined_curves():
126127
],
127128
),
128129
)
130+
131+
132+
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide")
133+
def test_combined_curves_resolution():
134+
qgrid = sc.linspace('Q', 0, 1, 26)
135+
data = sc.concat(
136+
(
137+
sc.ones(dims=['Q'], shape=[10], with_variances=True),
138+
0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True),
139+
),
140+
dim='Q',
141+
)
142+
data.variances[:] = 0.1
143+
curves = (
144+
curve(data, 0, 0.3),
145+
curve(0.5 * data, 0.2, 0.7),
146+
curve(0.25 * data, 0.6, 1.0),
147+
)
148+
curves[0].coords['Q_resolution'] = sc.midpoints(curves[0].coords['Q']) / 5
149+
combined = combine_curves(curves, qgrid)
150+
assert 'Q_resolution' in combined.coords
151+
assert combined.coords['Q_resolution'][0] == curves[0].coords['Q_resolution'][1]
152+
assert sc.isnan(combined.coords['Q_resolution'][-1])

0 commit comments

Comments
 (0)