Skip to content

Commit 8346518

Browse files
committed
magic methods
1 parent b9adfc7 commit 8346518

File tree

1 file changed

+88
-41
lines changed

1 file changed

+88
-41
lines changed

src/pint_array/__init__.py

Lines changed: 88 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,17 @@ def __repr__(self):
105105
f" '{self.units}'\n)>"
106106
)
107107

108-
# ## Linear Algebra Methods ##
109-
# def __matmul__(self, other):
110-
# return mod.matmul(self, other)
108+
## Linear Algebra Methods ##
109+
def __matmul__(self, other):
110+
return mod.matmul(self, other)
111111

112-
# def __imatmul__(self, other):
113-
# res = mod.matmul(self, other)
114-
# self.data[...] = res.data[...]
115-
# self.mask[...] = res.mask[...]
116-
# return
112+
def __imatmul__(self, other):
113+
res = mod.matmul(self, other)
114+
self.magnitude[...] = res.magnitude[...]
115+
self.units = res.units
117116

118-
# def __rmatmul__(self, other):
119-
# other = MArray(other)
120-
# return mod.matmul(other, self)
117+
def __rmatmul__(self, other):
118+
return mod.matmul(other, self)
121119

122120
## Attributes ##
123121

@@ -133,9 +131,11 @@ def mT(self):
133131
def __dlpack_device__(self):
134132
return self.magnitude.__dlpack_device__()
135133

136-
def __dlpack__(self):
134+
def __dlpack__(self, **kwargs):
137135
# really not sure how to define this
138-
return self.magnitude.__dlpack__()
136+
return self.magnitude.__dlpack__(**kwargs)
137+
138+
__dlpack__.__signature__ = inspect.signature(xp.empty(0).__dlpack__)
139139

140140
def to_device(self, device, /, *, stream=None):
141141
_magnitude = self._magnitude.to_device(device, stream=stream)
@@ -171,34 +171,81 @@ def fun(self, name=name):
171171

172172
setattr(ArrayQuantity, name, fun)
173173

174-
# # Methods that return the result of an elementwise binary operation
175-
# binary_names = ['__add__', '__sub__', '__and__', '__eq__', '__ge__', '__gt__',
176-
# '__le__', '__lshift__', '__lt__', '__mod__', '__mul__', '__ne__',
177-
# '__or__', '__pow__', '__rshift__', '__sub__', '__truediv__',
178-
# '__xor__'] + ['__divmod__', '__floordiv__']
179-
# # Methods that return the result of an elementwise binary operation (reflected)
180-
# rbinary_names = ['__radd__', '__rand__', '__rdivmod__', '__rfloordiv__',
181-
# '__rlshift__', '__rmod__', '__rmul__', '__ror__', '__rpow__',
182-
# '__rrshift__', '__rsub__', '__rtruediv__', '__rxor__']
183-
# for name in binary_names + rbinary_names:
184-
# def fun(self, other, name=name):
185-
# mask = (self.mask | other.mask) if hasattr(other, 'mask') else self.mask
186-
# data = self._call_super_method(name, other)
187-
# return ArrayUnitQuantity(data, mask)
188-
# setattr(ArrayQuantity, name, fun)
189-
190-
# # In-place methods
191-
# desired_names = ['__iadd__', '__iand__', '__ifloordiv__', '__ilshift__',
192-
# '__imod__', '__imul__', '__ior__', '__ipow__', '__irshift__',
193-
# '__isub__', '__itruediv__', '__ixor__']
194-
# for name in desired_names:
195-
# def fun(self, other, name=name, **kwargs):
196-
# if hasattr(other, 'mask'):
197-
# # self.mask |= other.mask doesn't work because mask has no setter
198-
# self.mask.__ior__(other.mask)
199-
# self._call_super_method(name, other)
200-
# return self
201-
# setattr(ArrayQuantity, name, fun)
174+
# Methods that return the result of an elementwise binary operation
175+
binary_names = [
176+
"__add__",
177+
"__sub__",
178+
"__and__",
179+
"__eq__",
180+
"__ge__",
181+
"__gt__",
182+
"__le__",
183+
"__lshift__",
184+
"__lt__",
185+
"__mod__",
186+
"__mul__",
187+
"__ne__",
188+
"__or__",
189+
"__pow__",
190+
"__rshift__",
191+
"__sub__",
192+
"__truediv__",
193+
"__xor__",
194+
"__divmod__",
195+
"__floordiv__",
196+
]
197+
# Methods that return the result of an elementwise binary operation (reflected)
198+
rbinary_names = [
199+
"__radd__",
200+
"__rand__",
201+
"__rdivmod__",
202+
"__rfloordiv__",
203+
"__rlshift__",
204+
"__rmod__",
205+
"__rmul__",
206+
"__ror__",
207+
"__rpow__",
208+
"__rrshift__",
209+
"__rsub__",
210+
"__rtruediv__",
211+
"__rxor__",
212+
]
213+
for name in binary_names + rbinary_names:
214+
215+
def method(self, other, name=name):
216+
units = self.units
217+
magnitude_other = other.m_as(units) if hasattr(other, "units") else other
218+
magnitude = self._call_super_method(name, magnitude_other)
219+
# FIXME: correct units for op
220+
return ArrayUnitQuantity(magnitude, units)
221+
222+
setattr(ArrayQuantity, name, method)
223+
224+
# In-place methods
225+
desired_names = [
226+
"__iadd__",
227+
"__iand__",
228+
"__ifloordiv__",
229+
"__ilshift__",
230+
"__imod__",
231+
"__imul__",
232+
"__ior__",
233+
"__ipow__",
234+
"__irshift__",
235+
"__isub__",
236+
"__itruediv__",
237+
"__ixor__",
238+
]
239+
for name in desired_names:
240+
241+
def method(self, other, name=name):
242+
units = self.units
243+
magnitude_other = other.m_as(units) if hasattr(other, "units") else other
244+
magnitude = self._call_super_method(name, magnitude_other)
245+
# FIXME: correct units for op
246+
return ArrayUnitQuantity(magnitude, units)
247+
248+
setattr(ArrayQuantity, name, method)
202249

203250
## Constants ##
204251
constant_names = ["e", "inf", "nan", "newaxis", "pi"]

0 commit comments

Comments
 (0)