@@ -254,19 +254,19 @@ def combine_curves(
254254 if len ({c .coords ['Q' ].unit for c in curves }) != 1 :
255255 raise ValueError ('The Q-coordinates must have the same unit for each curve' )
256256
257- r = _interpolate_on_qgrid (map (sc .values , curves ), qgrid ). values
258- v = _interpolate_on_qgrid (map (sc .variances , curves ), qgrid ). values
257+ r = _interpolate_on_qgrid (map (sc .values , curves ), qgrid )
258+ v = _interpolate_on_qgrid (map (sc .variances , curves ), qgrid )
259259
260- v [ v == 0 ] = np .nan
260+ v = sc . where ( v == 0 , sc . scalar ( np .nan , unit = v . unit ), v )
261261 inv_v = 1.0 / v
262- r_avg = np .nansum (r * inv_v , axis = 0 ) / np .nansum (inv_v , axis = 0 )
263- v_avg = 1 / np .nansum (inv_v , axis = 0 )
262+ r_avg = sc .nansum (r * inv_v , dim = 'curves' ) / sc .nansum (inv_v , dim = 'curves' )
263+ v_avg = 1 / sc .nansum (inv_v , dim = 'curves' )
264264
265265 out = sc .DataArray (
266266 data = sc .array (
267267 dims = 'Q' ,
268- values = r_avg ,
269- variances = v_avg ,
268+ values = r_avg . values ,
269+ variances = v_avg . values ,
270270 unit = next (iter (curves )).data .unit ,
271271 ),
272272 coords = {'Q' : qgrid },
@@ -279,17 +279,14 @@ def combine_curves(
279279 q_res = (
280280 sc .DataArray (
281281 data = c .coords .get (
282- 'Q_resolution' , sc .scalar ( float ( 'nan' )) * sc . values ( c . data . copy () )
282+ 'Q_resolution' , sc .full_like ( c . coords [ 'Q' ], value = np . nan )
283283 ),
284284 coords = {'Q' : c .coords ['Q' ]},
285285 )
286286 for c in curves
287287 )
288- qs = _interpolate_on_qgrid (q_res , qgrid ).values
289- qs_avg = np .nansum (qs * inv_v , axis = 0 ) / np .nansum (
290- ~ np .isnan (qs ) * inv_v , axis = 0
291- )
292- out .coords ['Q_resolution' ] = sc .array (
293- dims = 'Q' , values = qs_avg , unit = next (iter (curves )).coords ['Q_resolution' ].unit
288+ qs = _interpolate_on_qgrid (q_res , qgrid )
289+ out .coords ['Q_resolution' ] = sc .nansum (qs * inv_v , dim = 'curves' ) / sc .nansum (
290+ sc .where (sc .isnan (qs ), sc .scalar (0.0 , unit = inv_v .unit ), inv_v ), dim = 'curves'
294291 )
295292 return out
0 commit comments