Skip to content

Commit 3726df1

Browse files
committed
xp-tests passing!
1 parent 2aed51b commit 3726df1

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

src/pint_array/__init__.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,24 @@ def _call_super_method(self, method_name, *args, **kwargs):
7777
args = [getattr(arg, "magnitude", arg) for arg in args]
7878
return method(*args, **kwargs)
7979

80+
def _validate_key(self, key):
81+
if isinstance(key, tuple):
82+
return tuple(self._validate_key(key_i) for key_i in key)
83+
if hasattr(key, "units"):
84+
key = key.magnitude
85+
return key
86+
8087
## Indexing ##
81-
# def __getitem__(self, key):
82-
# if hasattr(key, 'mask') and xp.any(key.mask):
83-
# message = ("Correct behavior for indexing with a masked array is "
84-
# "ambiguous, and no convention is supported at this time.")
85-
# raise NotImplementedError(message)
86-
# elif hasattr(key, 'mask'):
87-
# key = key.data
88-
# return MArray(self.data[key], self.mask[key])
89-
90-
# def __setitem__(self, key, other):
91-
# if hasattr(key, 'mask') and xp.any(key.mask):
92-
# message = ("Correct behavior for indexing with a masked array is "
93-
# "ambiguous, and no convention is supported at this time.")
94-
# raise NotImplementedError(message)
95-
# elif hasattr(key, 'mask'):
96-
# key = key.data
97-
# self.mask[key] = getattr(other, 'mask', False)
98-
# return self.data.__setitem__(key, getattr(other, 'data', other))
88+
def __getitem__(self, key):
89+
key = self._validate_key(key)
90+
return ArrayUnitQuantity(self.magnitude[key], self.units)
91+
92+
def __setitem__(self, key, other):
93+
key = self._validate_key(key)
94+
magnitude_other = (
95+
other.m_as(self.units) if hasattr(other, "units") else other
96+
)
97+
return self.magnitude.__setitem__(key, magnitude_other)
9998

10099
def __iter__(self):
101100
return iter(self.magnitude)

0 commit comments

Comments
 (0)