|
1 | 1 | # SPDX-License-Identifier: BSD-3-Clause |
2 | 2 | # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) |
| 3 | +from collections.abc import Sequence |
| 4 | +from itertools import chain |
| 5 | + |
3 | 6 | import numpy as np |
4 | 7 | import scipp as sc |
| 8 | +import scipy.optimize as opt |
5 | 9 |
|
# Factor converting a standard deviation into a full width at half maximum
# (FWHM): 2 * sqrt(2 * ln(2)), about 2.355, which is the ratio for a Gaussian.
_STD_TO_FWHM = sc.scalar(2.0) * sc.sqrt(sc.scalar(2.0) * sc.log(sc.scalar(2.0)))
7 | 11 |
|
@@ -108,3 +112,161 @@ def linlogspace( |
108 | 112 | grids.append(mesh[dim, start:]) |
109 | 113 |
|
110 | 114 | return sc.concat(grids, dim) |
| 115 | + |
| 116 | + |
| 117 | +def _sort_by(a, by): |
| 118 | + return [x for x, _ in sorted(zip(a, by, strict=True), key=lambda x: x[1])] |
| 119 | + |
| 120 | + |
| 121 | +def _find_interval_overlaps(intervals): |
| 122 | + '''Returns the intervals where at least |
| 123 | + two or more of the provided intervals |
| 124 | + are overlapping.''' |
| 125 | + edges = list(chain.from_iterable(intervals)) |
| 126 | + is_start_edge = list(chain.from_iterable((True, False) for _ in intervals)) |
| 127 | + edges_sorted = sorted(edges) |
| 128 | + is_start_edge_sorted = _sort_by(is_start_edge, edges) |
| 129 | + |
| 130 | + number_overlapping = 0 |
| 131 | + overlap_intervals = [] |
| 132 | + for x, is_start in zip(edges_sorted, is_start_edge_sorted, strict=True): |
| 133 | + if number_overlapping == 1 and is_start: |
| 134 | + start = x |
| 135 | + if number_overlapping == 2 and not is_start: |
| 136 | + overlap_intervals.append((start, x)) |
| 137 | + if is_start: |
| 138 | + number_overlapping += 1 |
| 139 | + else: |
| 140 | + number_overlapping -= 1 |
| 141 | + return overlap_intervals |
| 142 | + |
| 143 | + |
| 144 | +def _searchsorted(a, v): |
| 145 | + for i, e in enumerate(a): |
| 146 | + if e > v: |
| 147 | + return i |
| 148 | + return len(a) |
| 149 | + |
| 150 | + |
def _create_qgrid_where_overlapping(qgrids):
    '''Build a Q-grid covering the regions where the given Q-grids overlap.

    For every region in which at least two grids overlap, each grid is cut
    down to that region and the cut-out with the most points is kept.
    The kept pieces are concatenated along the ``'Q'`` dimension.
    '''
    extents = [(grid.min(), grid.max()) for grid in qgrids]

    def densest_window(lower, upper):
        windows = []
        for grid in qgrids:
            # Keep one extra point on either side of the region
            # (clamped at the start of the grid).
            first = max(_searchsorted(grid, lower) - 1, 0)
            last = _searchsorted(grid, upper) + 1
            windows.append(grid[first:last])
        # On equal lengths the earliest grid wins.
        return max(windows, key=len)

    overlaps = _find_interval_overlaps(extents)
    return sc.concat(
        [densest_window(lower, upper) for lower, upper in overlaps],
        dim='Q',
    )
| 163 | + |
| 164 | + |
def _interpolate_on_qgrid(curves, grid):
    '''Evaluate every curve at the midpoints of ``grid``.

    Each curve is turned into a lookup along the dimension of ``grid`` and
    sampled at the grid midpoints. The sampled curves are stacked along a
    new ``'curves'`` dimension.
    '''
    sampled = []
    for curve in curves:
        lookup = sc.lookup(curve, grid.dim)
        sampled.append(lookup[sc.midpoints(grid)])
    return sc.concat(sampled, dim='curves')
| 169 | + |
| 170 | + |
def scale_reflectivity_curves_to_overlap(
    curves: Sequence[sc.DataArray],
    return_scaling_factors: bool = False,
) -> list[sc.DataArray] | list[sc.Variable]:
    '''Make the curves overlap by scaling all except the first by a factor.
    The scaling factors are determined by a maximum likelihood estimate
    (assuming the errors are normal distributed).

    All curves must have the same unit for data and the Q-coordinate.

    Parameters
    ---------
    curves:
        the reflectivity curves that should be scaled together
    return_scaling_factors:
        If True the return value of the function
        is a list of the scaling factors that should be applied.
        If False (default) the function returns the scaled curves.

    Returns
    ---------
    :
        A list of scaled reflectivity curves or a list of scaling factors.
        The first entry always corresponds to a scaling factor of 1.

    Raises
    ------
    ValueError
        If the curves do not all share the same data unit
        or the same unit of the Q-coordinate.
    '''
    if len({c.data.unit for c in curves}) != 1:
        raise ValueError('The reflectivity curves must have the same unit')
    if len({c.coords['Q'].unit for c in curves}) != 1:
        raise ValueError('The Q-coordinates must have the same unit for each curve')

    if len(curves) == 1:
        # A single curve has nothing to be matched against:
        # there is no overlap region and no free parameter to fit.
        scaling_factors = (1.0,)
    else:
        qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves])

        # Rows are curves, columns are points of the common grid.
        r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values
        v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values

        def cost(scaling_factors):
            # The first curve is the reference and is never rescaled.
            scaling_factors = np.concatenate([[1.0], scaling_factors])[:, None]
            r_scaled = scaling_factors * r
            v_scaled = scaling_factors**2 * v
            # Entries with zero variance carry no usable weight; mark them
            # as NaN so that the nansum calls below ignore them.
            v_scaled[v_scaled == 0] = np.nan
            inv_v_scaled = 1 / v_scaled
            # Inverse-variance weighted mean of all curves at each grid point.
            r_avg = np.nansum(r_scaled * inv_v_scaled, axis=0) / np.nansum(
                inv_v_scaled, axis=0
            )
            # Sum of squared, variance-normalized deviations from that mean.
            return np.nansum((r_scaled - r_avg) ** 2 * inv_v_scaled)

        sol = opt.minimize(cost, [1.0] * (len(curves) - 1))
        scaling_factors = (1.0, *sol.x)

    if return_scaling_factors:
        return [sc.scalar(x) for x in scaling_factors]
    return [
        scaling_factor * curve
        for scaling_factor, curve in zip(scaling_factors, curves, strict=True)
    ]
| 224 | + |
| 225 | + |
def combine_curves(
    curves: Sequence[sc.DataArray],
    qgrid: sc.Variable | None = None,
) -> sc.DataArray:
    '''Combines the given curves by interpolating them
    on a grid and merging them by a weighted mean.
    The weights are inversely proportional to the variances.

    Unless the curves are already scaled correctly they might
    need to be scaled using :func:`scale_reflectivity_curves_to_overlap`.

    All curves must have the same unit for data and the Q-coordinate.

    Parameters
    ----------
    curves:
        the reflectivity curves that should be combined
    qgrid:
        the Q-grid of the resulting combined reflectivity curve.
        The curves are evaluated at the midpoints of this grid and
        the grid itself becomes the ``'Q'`` coordinate of the result.
        Must be provided.

    Returns
    ---------
    :
        A data array representing the combined reflectivity curve

    Raises
    ------
    ValueError
        If the curves do not all share the same data unit
        or the same unit of the Q-coordinate, or if no ``qgrid`` is given.
    '''
    if len({c.data.unit for c in curves}) != 1:
        raise ValueError('The reflectivity curves must have the same unit')
    if len({c.coords['Q'].unit for c in curves}) != 1:
        raise ValueError('The Q-coordinates must have the same unit for each curve')
    if qgrid is None:
        # Fail early with a clear message instead of an AttributeError
        # raised from deep inside the interpolation.
        raise ValueError('A qgrid must be provided to combine the curves')

    # Rows are curves, columns are points of the grid.
    r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values
    v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values

    # Entries with zero variance carry no usable weight; mark them as NaN
    # so that the nansum calls below ignore them.
    v[v == 0] = np.nan
    inv_v = 1.0 / v
    r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
    v_avg = 1 / np.nansum(inv_v, axis=0)
    return sc.DataArray(
        data=sc.array(
            dims=['Q'],
            values=r_avg,
            variances=v_avg,
            unit=next(iter(curves)).data.unit,
        ),
        coords={'Q': qgrid},
    )
0 commit comments