
Commit 9c7e949

Merge pull request #75 from scipp/stitch-curves
feat: add basic stitching procedure
2 parents bf074db + abc1786 commit 9c7e949

6 files changed: +319 -3 lines changed

docs/api-reference/index.md
Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@
 orso
 supermirror
 types
+tools
 ```
 
 ## Amor

docs/user-guide/amor/amor-reduction.ipynb
Lines changed: 27 additions & 1 deletion

@@ -163,6 +163,31 @@
     "sc.plot(results, norm='log', vmin=1e-4)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ess.reflectometry.tools import scale_reflectivity_curves_to_overlap\n",
+    "results_scaled = dict(zip(\n",
+    "    results.keys(),\n",
+    "    scale_reflectivity_curves_to_overlap(results.values()),\n",
+    "    strict=True\n",
+    "))\n",
+    "sc.plot(results_scaled, norm='log', vmin=1e-5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ess.reflectometry.tools import combine_curves\n",
+    "combine_curves(results_scaled.values(), workflow.compute(QBins)).plot(norm='log')"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -354,7 +379,8 @@
     "mimetype": "text/x-python",
     "name": "python",
     "nbconvert_exporter": "python",
-    "pygments_lexer": "ipython3"
+    "pygments_lexer": "ipython3",
+    "version": "3.10.14"
    }
  },
  "nbformat": 4,

requirements/base.txt
Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ jedi==0.19.1
     # via ipython
 jupyterlab-widgets==3.0.13
     # via ipywidgets
-kiwisolver==1.4.6
+kiwisolver==1.4.7
     # via matplotlib
 locket==1.0.0
     # via partd

requirements/nightly.txt
Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ jedi==0.19.1
     # via ipython
 jupyterlab-widgets==3.0.13
     # via ipywidgets
-kiwisolver==1.4.6
+kiwisolver==1.4.7
     # via matplotlib
 locket==1.0.0
     # via partd

src/ess/reflectometry/tools.py
Lines changed: 162 additions & 0 deletions

@@ -1,7 +1,11 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from collections.abc import Sequence
+from itertools import chain
+
 import numpy as np
 import scipp as sc
+import scipy.optimize as opt
 
 _STD_TO_FWHM = sc.scalar(2.0) * sc.sqrt(sc.scalar(2.0) * sc.log(sc.scalar(2.0)))
 
@@ -108,3 +112,161 @@ def linlogspace(
         grids.append(mesh[dim, start:])
 
     return sc.concat(grids, dim)
+
+
+def _sort_by(a, by):
+    return [x for x, _ in sorted(zip(a, by, strict=True), key=lambda x: x[1])]
+
+
+def _find_interval_overlaps(intervals):
+    '''Returns the intervals where at least
+    two or more of the provided intervals
+    are overlapping.'''
+    edges = list(chain.from_iterable(intervals))
+    is_start_edge = list(chain.from_iterable((True, False) for _ in intervals))
+    edges_sorted = sorted(edges)
+    is_start_edge_sorted = _sort_by(is_start_edge, edges)
+
+    number_overlapping = 0
+    overlap_intervals = []
+    for x, is_start in zip(edges_sorted, is_start_edge_sorted, strict=True):
+        if number_overlapping == 1 and is_start:
+            start = x
+        if number_overlapping == 2 and not is_start:
+            overlap_intervals.append((start, x))
+        if is_start:
+            number_overlapping += 1
+        else:
+            number_overlapping -= 1
+    return overlap_intervals
+
+
+def _searchsorted(a, v):
+    for i, e in enumerate(a):
+        if e > v:
+            return i
+    return len(a)
+
+
+def _create_qgrid_where_overlapping(qgrids):
+    '''Given a number of Q-grids, construct a new grid
+    covering the regions where (any two of the) provided grids overlap.'''
+    pieces = []
+    for start, end in _find_interval_overlaps([(q.min(), q.max()) for q in qgrids]):
+        interval_sliced_from_qgrids = [
+            q[max(_searchsorted(q, start) - 1, 0) : _searchsorted(q, end) + 1]
+            for q in qgrids
+        ]
+        densest_grid_in_interval = max(interval_sliced_from_qgrids, key=len)
+        pieces.append(densest_grid_in_interval)
+    return sc.concat(pieces, dim='Q')
+
+
+def _interpolate_on_qgrid(curves, grid):
+    return sc.concat(
+        [sc.lookup(c, grid.dim)[sc.midpoints(grid)] for c in curves], dim='curves'
+    )
+
+
+def scale_reflectivity_curves_to_overlap(
+    curves: Sequence[sc.DataArray],
+    return_scaling_factors=False,
+) -> list[sc.DataArray] | list[sc.Variable]:
+    '''Make the curves overlap by scaling all except the first by a factor.
+    The scaling factors are determined by a maximum likelihood estimate
+    (assuming the errors are normally distributed).
+
+    All curves must have the same unit for data and the Q-coordinate.
+
+    Parameters
+    ----------
+    curves:
+        the reflectivity curves that should be scaled together
+    return_scaling_factors:
+        If True the return value of the function
+        is a list of the scaling factors that should be applied.
+        If False (default) the function returns the scaled curves.
+
+    Returns
+    -------
+    :
+        A list of scaled reflectivity curves or a list of scaling factors.
+    '''
+    if len({c.data.unit for c in curves}) != 1:
+        raise ValueError('The reflectivity curves must have the same unit')
+    if len({c.coords['Q'].unit for c in curves}) != 1:
+        raise ValueError('The Q-coordinates must have the same unit for each curve')
+
+    qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves])
+
+    r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values
+    v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values
+
+    def cost(scaling_factors):
+        scaling_factors = np.concatenate([[1.0], scaling_factors])[:, None]
+        r_scaled = scaling_factors * r
+        v_scaled = scaling_factors**2 * v
+        v_scaled[v_scaled == 0] = np.nan
+        inv_v_scaled = 1 / v_scaled
+        r_avg = np.nansum(r_scaled * inv_v_scaled, axis=0) / np.nansum(
+            inv_v_scaled, axis=0
+        )
+        return np.nansum((r_scaled - r_avg) ** 2 * inv_v_scaled)
+
+    sol = opt.minimize(cost, [1.0] * (len(curves) - 1))
+    scaling_factors = (1.0, *sol.x)
+    if return_scaling_factors:
+        return [sc.scalar(x) for x in scaling_factors]
+    return [
+        scaling_factor * curve
+        for scaling_factor, curve in zip(scaling_factors, curves, strict=True)
+    ]
+
+
+def combine_curves(
+    curves: Sequence[sc.DataArray],
+    qgrid: sc.Variable | None = None,
+) -> sc.DataArray:
+    '''Combines the given curves by interpolating them
+    on a common Q-grid and computing a weighted mean,
+    where the weights are inversely proportional
+    to the variances.
+
+    Unless the curves are already scaled correctly they might
+    need to be scaled using :func:`scale_reflectivity_curves_to_overlap`.
+
+    All curves must have the same unit for data and the Q-coordinate.
+
+    Parameters
+    ----------
+    curves:
+        the reflectivity curves that should be combined
+    qgrid:
+        the Q-grid of the resulting combined reflectivity curve
+
+    Returns
+    -------
+    :
+        A data array representing the combined reflectivity curve
+    '''
+    if len({c.data.unit for c in curves}) != 1:
+        raise ValueError('The reflectivity curves must have the same unit')
+    if len({c.coords['Q'].unit for c in curves}) != 1:
+        raise ValueError('The Q-coordinates must have the same unit for each curve')
+
+    r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values
+    v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values
+
+    v[v == 0] = np.nan
+    inv_v = 1.0 / v
+    r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
+    v_avg = 1 / np.nansum(inv_v, axis=0)
+    return sc.DataArray(
+        data=sc.array(
+            dims='Q',
+            values=r_avg,
+            variances=v_avg,
+            unit=next(iter(curves)).data.unit,
+        ),
+        coords={'Q': qgrid},
+    )
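To build the overlap grid, `_create_qgrid_where_overlapping` first finds the Q-regions covered by at least two curves via the edge sweep in `_find_interval_overlaps`. A tiny illustration of that sweep (the helper is private and the interval values are made up, shown purely for intuition):

```python
from ess.reflectometry.tools import _find_interval_overlaps

# Q-ranges of three hypothetical curves; the sweep over the sorted interval
# edges returns the regions covered by at least two of them.
intervals = [(0.0, 0.3), (0.2, 0.7), (0.6, 1.0)]
print(_find_interval_overlaps(intervals))  # [(0.2, 0.3), (0.6, 0.7)]
```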
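The docstring's maximum likelihood claim corresponds to the chi-square-like objective implemented in the `cost` closure: with the first factor fixed at $s_1 = 1$, curves indexed by $i$, overlap-grid points by $j$, and zero-variance (missing) points skipped, the optimizer minimizes

$$\bar{r}_j(s) = \frac{\sum_i s_i r_{ij} / (s_i^2 v_{ij})}{\sum_i 1 / (s_i^2 v_{ij})},
\qquad
C(s_2, \ldots, s_n) = \sum_{i,j} \frac{\left(s_i r_{ij} - \bar{r}_j(s)\right)^2}{s_i^2 v_{ij}},$$

where $r_{ij}$ and $v_{ij}$ are the values and variances interpolated onto the overlap grid by `_interpolate_on_qgrid`. This is a sketch of the objective as written in the code above, not additional functionality.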

tests/tools_test.py
Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import scipp as sc
+from ess.reflectometry.tools import combine_curves, scale_reflectivity_curves_to_overlap
+from scipp.testing import assert_allclose
+
+
+def curve(d, qmin, qmax):
+    return sc.DataArray(data=d, coords={'Q': sc.linspace('Q', qmin, qmax, len(d) + 1)})
+
+
+def test_reflectivity_curve_scaling():
+    data = sc.concat(
+        (
+            sc.ones(dims=['Q'], shape=[10], with_variances=True),
+            0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True),
+        ),
+        dim='Q',
+    )
+    data.variances[:] = 0.1
+
+    curves = scale_reflectivity_curves_to_overlap(
+        (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
+    )
+
+    assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
+    assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
+    assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))
+
+
+def test_reflectivity_curve_scaling_return_factors():
+    data = sc.concat(
+        (
+            sc.ones(dims=['Q'], shape=[10], with_variances=True),
+            0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True),
+        ),
+        dim='Q',
+    )
+    data.variances[:] = 0.1
+
+    factors = scale_reflectivity_curves_to_overlap(
+        (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
+        return_scaling_factors=True,
+    )
+
+    assert_allclose(factors[0], sc.scalar(1.0), rtol=sc.scalar(1e-5))
+    assert_allclose(factors[1], sc.scalar(0.5 / 0.8), rtol=sc.scalar(1e-5))
+    assert_allclose(factors[2], sc.scalar(0.25 / 0.1), rtol=sc.scalar(1e-5))
+
+
+def test_combined_curves():
+    qgrid = sc.linspace('Q', 0, 1, 26)
+    data = sc.concat(
+        (
+            sc.ones(dims=['Q'], shape=[10], with_variances=True),
+            0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True),
+        ),
+        dim='Q',
+    )
+    data.variances[:] = 0.1
+    curves = (
+        curve(data, 0, 0.3),
+        curve(0.5 * data, 0.2, 0.7),
+        curve(0.25 * data, 0.6, 1.0),
+    )
+
+    combined = combine_curves(curves, qgrid)
+    assert_allclose(
+        combined.data,
+        sc.array(
+            dims='Q',
+            values=[
+                1.0,
+                1,
+                1,
+                0.5,
+                0.5,
+                0.5,
+                0.5,
+                0.5,
+                0.5,
+                0.5,
+                0.25,
+                0.25,
+                0.25,
+                0.25,
+                0.25,
+                0.25,
+                0.25,
+                0.25,
+                0.25,
+                0.125,
+                0.125,
+                0.125,
+                0.125,
+                0.125,
+                0.125,
+            ],
+            variances=[
+                0.1,
+                0.1,
+                0.1,
+                0.1,
+                0.1,
+                0.02,
+                0.02,
+                0.025,
+                0.025,
+                0.025,
+                0.025,
+                0.025,
+                0.025,
+                0.025,
+                0.025,
+                0.005,
+                0.005,
+                0.00625,
+                0.00625,
+                0.00625,
+                0.00625,
+                0.00625,
+                0.00625,
+                0.00625,
+                0.00625,
+            ],
+        ),
+    )
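The expected numbers in these tests can be checked by hand: the scaling factors come from matching the curve levels in the overlap regions, and the combined values and variances come from inverse-variance weighting. A small back-of-the-envelope check (plain Python, not part of the test suite):

```python
# Scaling tests: in Q ∈ [0.2, 0.3] the reference curve sits at 0.5 while the
# second curve (0.8 * data) sits at 0.8, so its factor is 0.5 / 0.8 = 0.625.
# The scaled second curve then sits at 0.25 in Q ∈ [0.6, 0.7], where the third
# curve (0.1 * data) sits at 0.1, giving the factor 0.25 / 0.1 = 2.5.
print(0.5 / 0.8, 0.25 / 0.1)

# Combination test: where two curves overlap with value 0.5 and variances
# 0.1 and 0.025, inverse-variance weighting gives
v_avg = 1 / (1 / 0.1 + 1 / 0.025)          # 0.02, as in the expected variances
r_avg = (0.5 / 0.1 + 0.5 / 0.025) * v_avg  # 0.5, as in the expected values
print(r_avg, v_avg)
```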
