Skip to content

Commit c9318ba

Browse files
authored
Merge pull request #95 from scipp/scale-critical-edge
feat: scale critical edge to 1
2 parents b4a9968 + 8c05477 commit c9318ba

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

docs/user-guide/amor/amor-reduction.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@
172172
"from ess.reflectometry.tools import scale_reflectivity_curves_to_overlap\n",
173173
"results_scaled = dict(zip(\n",
174174
" results.keys(),\n",
175-
" scale_reflectivity_curves_to_overlap(results.values()),\n",
175+
" scale_reflectivity_curves_to_overlap(results.values())[0],\n",
176176
" strict=True\n",
177177
"))\n",
178178
"sc.plot(results_scaled, norm='log', vmin=1e-5)"

src/ess/reflectometry/tools.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,28 +170,43 @@ def _interpolate_on_qgrid(curves, grid):
170170

171171
def scale_reflectivity_curves_to_overlap(
172172
curves: Sequence[sc.DataArray],
173-
return_scaling_factors=False,
174-
) -> list[sc.DataArray] | list[sc.scalar]:
173+
critical_edge_interval: tuple[sc.Variable, sc.Variable] | None = None,
174+
) -> tuple[list[sc.DataArray], list[sc.Variable]]:
175175
'''Make the curves overlap by scaling all except the first by a factor.
176176
The scaling factors are determined by a maximum likelihood estimate
177177
(assuming the errors are normal distributed).
178178
179+
If :code:`critical_edge_interval` is provided then all curves are scaled.
180+
179181
All curves must be have the same unit for data and the Q-coordinate.
180182
181183
Parameters
182184
---------
183185
curves:
184186
the reflectivity curves that should be scaled together
185-
return_scaling_factor:
186-
If True the return value of the function
187-
is a list of the scaling factors that should be applied.
188-
If False (default) the function returns the scaled curves.
187+
critical_edge_interval:
188+
a tuple denoting an interval that is known to belong
189+
to the critical edge, i.e. where the reflectivity is
190+
known to be 1.
189191
190192
Returns
191193
---------
192194
:
193-
A list of scaled reflectivity curves or a list of scaling factors.
195+
A list of scaled reflectivity curves and a list of the scaling factors.
194196
'''
197+
if critical_edge_interval is not None:
198+
q = next(iter(curves)).coords['Q']
199+
N = (
200+
((q >= critical_edge_interval[0]) & (q < critical_edge_interval[1]))
201+
.sum()
202+
.value
203+
)
204+
edge = sc.DataArray(
205+
data=sc.ones(dims=('Q',), shape=(N,), with_variances=True),
206+
coords={'Q': sc.linspace('Q', *critical_edge_interval, N + 1)},
207+
)
208+
curves, factors = scale_reflectivity_curves_to_overlap([edge, *curves])
209+
return curves[1:], factors[1:]
195210
if len({c.data.unit for c in curves}) != 1:
196211
raise ValueError('The reflectivity curves must have the same unit')
197212
if len({c.coords['Q'].unit for c in curves}) != 1:
@@ -214,13 +229,11 @@ def cost(scaling_factors):
214229
return np.nansum((r_scaled - r_avg) ** 2 * inv_v_scaled)
215230

216231
sol = opt.minimize(cost, [1.0] * (len(curves) - 1))
217-
scaling_factors = (1.0, *sol.x)
218-
if return_scaling_factors:
219-
return [sc.scalar(x) for x in scaling_factors]
232+
scaling_factors = (1.0, *map(float, sol.x))
220233
return [
221234
scaling_factor * curve
222235
for scaling_factor, curve in zip(scaling_factors, curves, strict=True)
223-
]
236+
], scaling_factors
224237

225238

226239
def combine_curves(

tests/tools_test.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
33
import scipp as sc
4+
from numpy.testing import assert_allclose as np_assert_allclose
45
from scipp.testing import assert_allclose
56

67
from ess.reflectometry.tools import combine_curves, scale_reflectivity_curves_to_overlap
@@ -20,16 +21,17 @@ def test_reflectivity_curve_scaling():
2021
)
2122
data.variances[:] = 0.1
2223

23-
curves = scale_reflectivity_curves_to_overlap(
24+
curves, factors = scale_reflectivity_curves_to_overlap(
2425
(curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
2526
)
2627

2728
assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
2829
assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
2930
assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))
31+
np_assert_allclose((1, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4)
3032

3133

32-
def test_reflectivity_curve_scaling_return_factors():
34+
def test_reflectivity_curve_scaling_with_critical_edge():
3335
data = sc.concat(
3436
(
3537
sc.ones(dims=['Q'], shape=[10], with_variances=True),
@@ -39,14 +41,19 @@ def test_reflectivity_curve_scaling_return_factors():
3941
)
4042
data.variances[:] = 0.1
4143

42-
factors = scale_reflectivity_curves_to_overlap(
43-
(curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
44-
return_scaling_factors=True,
44+
curves, factors = scale_reflectivity_curves_to_overlap(
45+
(
46+
2 * curve(data, 0, 0.3),
47+
curve(0.8 * data, 0.2, 0.7),
48+
curve(0.1 * data, 0.6, 1.0),
49+
),
50+
critical_edge_interval=(sc.scalar(0.01), sc.scalar(0.05)),
4551
)
4652

47-
assert_allclose(factors[0], sc.scalar(1.0), rtol=sc.scalar(1e-5))
48-
assert_allclose(factors[1], sc.scalar(0.5 / 0.8), rtol=sc.scalar(1e-5))
49-
assert_allclose(factors[2], sc.scalar(0.25 / 0.1), rtol=sc.scalar(1e-5))
53+
assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
54+
assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
55+
assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))
56+
np_assert_allclose((0.5, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4)
5057

5158

5259
def test_combined_curves():

0 commit comments

Comments
 (0)