@@ -97,6 +97,9 @@ def _call_super_method(self, method_name, *args, **kwargs):
97
97
# self.mask[key] = getattr(other, 'mask', False)
98
98
# return self.data.__setitem__(key, getattr(other, 'data', other))
99
99
100
+ def __iter__ (self ):
101
+ return iter (self .magnitude )
102
+
100
103
## Visualization ##
101
104
def __repr__ (self ):
102
105
return (
@@ -293,11 +296,15 @@ def get_manip_fun(func_str):
293
296
def manip_fun (x , * args , ** kwargs ):
294
297
xp_func = getattr (xp , func_str )
295
298
299
+ one_array = False
296
300
if func_str not in first_arg_arrays :
297
301
x = asarray (x )
298
302
magnitude = xp .asarray (x .magnitude , copy = True )
299
303
units = x .units
300
-
304
+ elif hasattr (x , "__array_namespace__" ):
305
+ magnitude = x
306
+ units = None
307
+ one_array = True
301
308
else :
302
309
x = [asarray (x_i ) for x_i in x ]
303
310
if len (x ) == 0 :
@@ -314,7 +321,7 @@ def manip_fun(x, *args, **kwargs):
314
321
):
315
322
args [0 ] = repeats .magnitude
316
323
317
- if func_str in arbitrary_num_arrays :
324
+ if func_str in arbitrary_num_arrays and not one_array :
318
325
magnitude = xp_func (* magnitude , * args , ** kwargs )
319
326
else :
320
327
magnitude = xp_func (magnitude , * args , ** kwargs )
0 commit comments