@@ -231,50 +231,54 @@ def fun(*args, func_str=func_str, units=None, **kwargs):
231
231
setattr (mod , func_str , fun )
232
232
233
233
## Manipulation Functions ##
234
- # first_arg_arrays = {'broadcast_arrays', 'concat', 'stack', 'meshgrid'}
235
- # output_arrays = {'broadcast_arrays', 'unstack', 'meshgrid'}
236
-
237
- # def get_manip_fun(name):
238
- # def manip_fun(x, *args, **kwargs):
239
- # x = (asarray(x) if name not in first_arg_arrays
240
- # else [asarray(xi) for xi in x])
241
- # mask = (x.mask if name not in first_arg_arrays
242
- # else [xi.mask for xi in x])
243
- # data = (x.data if name not in first_arg_arrays
244
- # else [xi.data for xi in x])
245
-
246
- # fun = getattr(xp, name)
247
-
248
- # if name == "repeat":
249
- # args = list(args)
250
- # repeats = args[0]
251
- # if hasattr(repeats, 'mask') and xp.any(repeats.mask):
252
- # message = (
253
- # "Correct behavior when `repeats` is a masked array is "
254
- # "ambiguous, and no convention is supported at this time.")
255
- # raise NotImplementedError(message)
256
- # elif hasattr(repeats, 'mask'):
257
- # repeats = repeats.data
258
- # args[0] = repeats
259
-
260
- # if name in {'broadcast_arrays', 'meshgrid'}:
261
- # res = fun(*data, *args, **kwargs)
262
- # mask = fun(*mask, *args, **kwargs)
263
- # else:
264
- # res = fun(data, *args, **kwargs)
265
- # mask = fun(mask, *args, **kwargs)
266
-
267
- # out = (MArray(res, mask) if name not in output_arrays
268
- # else tuple(MArray(resi, maski) for resi, maski in zip(res, mask)))
269
- # return out
270
- # return manip_fun
271
-
272
- # creation_manip_functions = ['tril', 'triu', 'meshgrid']
273
- # manip_names = ['broadcast_arrays', 'broadcast_to', 'concat', 'expand_dims',
274
- # 'flip', 'moveaxis', 'permute_dims', 'repeat', 'reshape',
275
- # 'roll', 'squeeze', 'stack', 'tile', 'unstack']
276
- # for name in manip_names + creation_manip_functions:
277
- # setattr(mod, name, get_manip_fun(name))
234
+ first_arg_arrays = {"broadcast_arrays" , "concat" , "stack" , "meshgrid" }
235
+ output_arrays = {"broadcast_arrays" , "unstack" , "meshgrid" }
236
+
237
+ def get_manip_fun (func_str ):
238
+ def manip_fun (x , * args , ** kwargs ):
239
+ xp_func = getattr (xp , func_str )
240
+
241
+ if func_str not in first_arg_arrays :
242
+ x = asarray (x )
243
+ magnitude = xp .asarray (x .magnitude , copy = True )
244
+ units = x .units
245
+ magnitude = xp_func (magnitude , * args , ** kwargs )
246
+
247
+ else :
248
+ x = [asarray (x_i ) for x_i in x ]
249
+ units = x [0 ].units
250
+ magnitude = [xp .asarray (x [0 ].magnitude , copy = True )]
251
+ for x_i in x [1 :]:
252
+ magnitude .append (x_i .m_as (units ))
253
+ magnitude = xp_func (* magnitude , * args , ** kwargs )
254
+
255
+ if name in output_arrays :
256
+ return tuple (
257
+ ArrayUnitQuantity (magnitude_i , units ) for magnitude_i in magnitude
258
+ )
259
+ return ArrayUnitQuantity (magnitude , units )
260
+
261
+ return manip_fun
262
+
263
+ creation_manip_functions = ["tril" , "triu" , "meshgrid" ]
264
+ manip_names = [
265
+ "broadcast_arrays" ,
266
+ "broadcast_to" ,
267
+ "concat" ,
268
+ "expand_dims" ,
269
+ "flip" ,
270
+ "moveaxis" ,
271
+ "permute_dims" ,
272
+ "repeat" ,
273
+ "reshape" ,
274
+ "roll" ,
275
+ "squeeze" ,
276
+ "stack" ,
277
+ "tile" ,
278
+ "unstack" ,
279
+ ]
280
+ for name in manip_names + creation_manip_functions :
281
+ setattr (mod , name , get_manip_fun (name ))
278
282
279
283
## Data Type Functions and Data Types ##
280
284
dtype_fun_names = ["can_cast" , "finfo" , "iinfo" , "isdtype" ]
0 commit comments