@@ -77,25 +77,24 @@ def _call_super_method(self, method_name, *args, **kwargs):
77
77
args = [getattr (arg , "magnitude" , arg ) for arg in args ]
78
78
return method (* args , ** kwargs )
79
79
80
+ def _validate_key (self , key ):
81
+ if isinstance (key , tuple ):
82
+ return tuple (self ._validate_key (key_i ) for key_i in key )
83
+ if hasattr (key , "units" ):
84
+ key = key .magnitude
85
+ return key
86
+
80
87
## Indexing ##
81
- # def __getitem__(self, key):
82
- # if hasattr(key, 'mask') and xp.any(key.mask):
83
- # message = ("Correct behavior for indexing with a masked array is "
84
- # "ambiguous, and no convention is supported at this time.")
85
- # raise NotImplementedError(message)
86
- # elif hasattr(key, 'mask'):
87
- # key = key.data
88
- # return MArray(self.data[key], self.mask[key])
89
-
90
- # def __setitem__(self, key, other):
91
- # if hasattr(key, 'mask') and xp.any(key.mask):
92
- # message = ("Correct behavior for indexing with a masked array is "
93
- # "ambiguous, and no convention is supported at this time.")
94
- # raise NotImplementedError(message)
95
- # elif hasattr(key, 'mask'):
96
- # key = key.data
97
- # self.mask[key] = getattr(other, 'mask', False)
98
- # return self.data.__setitem__(key, getattr(other, 'data', other))
88
+ def __getitem__ (self , key ):
89
+ key = self ._validate_key (key )
90
+ return ArrayUnitQuantity (self .magnitude [key ], self .units )
91
+
92
+ def __setitem__ (self , key , other ):
93
+ key = self ._validate_key (key )
94
+ magnitude_other = (
95
+ other .m_as (self .units ) if hasattr (other , "units" ) else other
96
+ )
97
+ return self .magnitude .__setitem__ (key , magnitude_other )
99
98
100
99
def __iter__ (self ):
101
100
return iter (self .magnitude )
0 commit comments