Skip to content

Commit c7df6c2

Browse files
committed
fix: include resolution in stitched curve if present
1 parent 7239cfe commit c7df6c2

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

src/ess/reflectometry/tools.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def cost(scaling_factors):
229229

230230
def combine_curves(
231231
curves: Sequence[sc.DataArray],
232-
q_bin_edges: sc.Variable | None = None,
232+
q_bin_edges: sc.Variable,
233233
) -> sc.DataArray:
234234
'''Combines the given curves by interpolating them
235235
on a 1d grid defined by :code:`q_bin_edges` and averaging
@@ -268,7 +268,8 @@ def combine_curves(
268268
inv_v = 1.0 / v
269269
r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
270270
v_avg = 1 / np.nansum(inv_v, axis=0)
271-
return sc.DataArray(
271+
272+
out = sc.DataArray(
272273
data=sc.array(
273274
dims='Q',
274275
values=r_avg,
@@ -277,6 +278,28 @@ def combine_curves(
277278
),
278279
coords={'Q': q_bin_edges},
279280
)
281+
if any('Q_resolution' in c.coords for c in curves):
282+
# This might need to be revisited. The question about how to combine curves
283+
# with different Q-resolution is not completely resolved.
284+
# However, in practice the difference in Q-resolution between different curves
285+
# is small so it's not likely to make a big difference.
286+
q_res = (
287+
sc.DataArray(
288+
data=c.coords.get(
289+
'Q_resolution', sc.scalar(float('nan')) * sc.values(c.data.copy())
290+
),
291+
coords={'Q': c.coords['Q']},
292+
)
293+
for c in curves
294+
)
295+
qs = _interpolate_on_qgrid(q_res, q_bin_edges).values
296+
qs_avg = np.nansum(qs * inv_v, axis=0) / np.nansum(
297+
~np.isnan(qs) * inv_v, axis=0
298+
)
299+
out.coords['Q_resolution'] = sc.array(
300+
dims='Q', values=qs_avg, unit=next(iter(curves)).coords['Q_resolution'].unit
301+
)
302+
return out
280303

281304

282305
def orso_datasets_from_measurements(

tests/tools_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,29 @@ def test_combined_curves():
146146
)
147147

148148

149+
@pytest.mark.filterwarnings("ignore:invalid value encountered in divide")
150+
def test_combined_curves_resolution():
151+
qgrid = sc.linspace('Q', 0, 1, 26)
152+
data = sc.concat(
153+
(
154+
sc.ones(dims=['Q'], shape=[10], with_variances=True),
155+
0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True),
156+
),
157+
dim='Q',
158+
)
159+
data.variances[:] = 0.1
160+
curves = (
161+
curve(data, 0, 0.3),
162+
curve(0.5 * data, 0.2, 0.7),
163+
curve(0.25 * data, 0.6, 1.0),
164+
)
165+
curves[0].coords['Q_resolution'] = sc.midpoints(curves[0].coords['Q']) / 5
166+
combined = combine_curves(curves, qgrid)
167+
assert 'Q_resolution' in combined.coords
168+
assert combined.coords['Q_resolution'][0] == curves[0].coords['Q_resolution'][1]
169+
assert sc.isnan(combined.coords['Q_resolution'][-1])
170+
171+
149172
def test_linlogspace_linear():
150173
q_lin = linlogspace(
151174
dim='qz', edges=[0.008, 0.08], scale='linear', num=50, unit='1/angstrom'

0 commit comments

Comments
 (0)