Skip to content

Commit e645f2b

Browse files
committed
fix: use scipp nansum
1 parent 5ccad84 commit e645f2b

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

src/ess/reflectometry/tools.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -254,19 +254,19 @@ def combine_curves(
254254
if len({c.coords['Q'].unit for c in curves}) != 1:
255255
raise ValueError('The Q-coordinates must have the same unit for each curve')
256256

257-
r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values
258-
v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values
257+
r = _interpolate_on_qgrid(map(sc.values, curves), qgrid)
258+
v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid)
259259

260-
v[v == 0] = np.nan
260+
v = sc.where(v == 0, sc.scalar(np.nan, unit=v.unit), v)
261261
inv_v = 1.0 / v
262-
r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
263-
v_avg = 1 / np.nansum(inv_v, axis=0)
262+
r_avg = sc.nansum(r * inv_v, dim='curves') / sc.nansum(inv_v, dim='curves')
263+
v_avg = 1 / sc.nansum(inv_v, dim='curves')
264264

265265
out = sc.DataArray(
266266
data=sc.array(
267267
dims='Q',
268-
values=r_avg,
269-
variances=v_avg,
268+
values=r_avg.values,
269+
variances=v_avg.values,
270270
unit=next(iter(curves)).data.unit,
271271
),
272272
coords={'Q': qgrid},
@@ -279,17 +279,14 @@ def combine_curves(
279279
q_res = (
280280
sc.DataArray(
281281
data=c.coords.get(
282-
'Q_resolution', sc.scalar(float('nan')) * sc.values(c.data.copy())
282+
'Q_resolution', sc.full_like(c.coords['Q'], value=np.nan)
283283
),
284284
coords={'Q': c.coords['Q']},
285285
)
286286
for c in curves
287287
)
288-
qs = _interpolate_on_qgrid(q_res, qgrid).values
289-
qs_avg = np.nansum(qs * inv_v, axis=0) / np.nansum(
290-
~np.isnan(qs) * inv_v, axis=0
291-
)
292-
out.coords['Q_resolution'] = sc.array(
293-
dims='Q', values=qs_avg, unit=next(iter(curves)).coords['Q_resolution'].unit
288+
qs = _interpolate_on_qgrid(q_res, qgrid)
289+
out.coords['Q_resolution'] = sc.nansum(qs * inv_v, dim='curves') / sc.nansum(
290+
sc.where(sc.isnan(qs), sc.scalar(0.0, unit=inv_v.unit), inv_v), dim='curves'
294291
)
295292
return out

0 commit comments

Comments
 (0)