@@ -105,19 +105,17 @@ def __repr__(self):
105
105
f" '{ self .units } '\n )>"
106
106
)
107
107
108
- # # # Linear Algebra Methods ##
109
- # def __matmul__(self, other):
110
- # return mod.matmul(self, other)
108
+ ## Linear Algebra Methods ##
109
+ def __matmul__ (self , other ):
110
+ return mod .matmul (self , other )
111
111
112
- # def __imatmul__(self, other):
113
- # res = mod.matmul(self, other)
114
- # self.data[...] = res.data[...]
115
- # self.mask[...] = res.mask[...]
116
- # return
112
+ def __imatmul__ (self , other ):
113
+ res = mod .matmul (self , other )
114
+ self .magnitude [...] = res .magnitude [...]
115
+ self .units = res .units
117
116
118
- # def __rmatmul__(self, other):
119
- # other = MArray(other)
120
- # return mod.matmul(other, self)
117
+ def __rmatmul__ (self , other ):
118
+ return mod .matmul (other , self )
121
119
122
120
## Attributes ##
123
121
@@ -133,9 +131,11 @@ def mT(self):
133
131
def __dlpack_device__ (self ):
134
132
return self .magnitude .__dlpack_device__ ()
135
133
136
- def __dlpack__ (self ):
134
+ def __dlpack__ (self , ** kwargs ):
137
135
# really not sure how to define this
138
- return self .magnitude .__dlpack__ ()
136
+ return self .magnitude .__dlpack__ (** kwargs )
137
+
138
+ __dlpack__ .__signature__ = inspect .signature (xp .empty (0 ).__dlpack__ )
139
139
140
140
def to_device (self , device , / , * , stream = None ):
141
141
_magnitude = self ._magnitude .to_device (device , stream = stream )
@@ -171,34 +171,81 @@ def fun(self, name=name):
171
171
172
172
setattr (ArrayQuantity , name , fun )
173
173
174
- # # Methods that return the result of an elementwise binary operation
175
- # binary_names = ['__add__', '__sub__', '__and__', '__eq__', '__ge__', '__gt__',
176
- # '__le__', '__lshift__', '__lt__', '__mod__', '__mul__', '__ne__',
177
- # '__or__', '__pow__', '__rshift__', '__sub__', '__truediv__',
178
- # '__xor__'] + ['__divmod__', '__floordiv__']
179
- # # Methods that return the result of an elementwise binary operation (reflected)
180
- # rbinary_names = ['__radd__', '__rand__', '__rdivmod__', '__rfloordiv__',
181
- # '__rlshift__', '__rmod__', '__rmul__', '__ror__', '__rpow__',
182
- # '__rrshift__', '__rsub__', '__rtruediv__', '__rxor__']
183
- # for name in binary_names + rbinary_names:
184
- # def fun(self, other, name=name):
185
- # mask = (self.mask | other.mask) if hasattr(other, 'mask') else self.mask
186
- # data = self._call_super_method(name, other)
187
- # return ArrayUnitQuantity(data, mask)
188
- # setattr(ArrayQuantity, name, fun)
189
-
190
- # # In-place methods
191
- # desired_names = ['__iadd__', '__iand__', '__ifloordiv__', '__ilshift__',
192
- # '__imod__', '__imul__', '__ior__', '__ipow__', '__irshift__',
193
- # '__isub__', '__itruediv__', '__ixor__']
194
- # for name in desired_names:
195
- # def fun(self, other, name=name, **kwargs):
196
- # if hasattr(other, 'mask'):
197
- # # self.mask |= other.mask doesn't work because mask has no setter
198
- # self.mask.__ior__(other.mask)
199
- # self._call_super_method(name, other)
200
- # return self
201
- # setattr(ArrayQuantity, name, fun)
174
+ # Methods that return the result of an elementwise binary operation
175
+ binary_names = [
176
+ "__add__" ,
177
+ "__sub__" ,
178
+ "__and__" ,
179
+ "__eq__" ,
180
+ "__ge__" ,
181
+ "__gt__" ,
182
+ "__le__" ,
183
+ "__lshift__" ,
184
+ "__lt__" ,
185
+ "__mod__" ,
186
+ "__mul__" ,
187
+ "__ne__" ,
188
+ "__or__" ,
189
+ "__pow__" ,
190
+ "__rshift__" ,
191
+ "__sub__" ,
192
+ "__truediv__" ,
193
+ "__xor__" ,
194
+ "__divmod__" ,
195
+ "__floordiv__" ,
196
+ ]
197
+ # Methods that return the result of an elementwise binary operation (reflected)
198
+ rbinary_names = [
199
+ "__radd__" ,
200
+ "__rand__" ,
201
+ "__rdivmod__" ,
202
+ "__rfloordiv__" ,
203
+ "__rlshift__" ,
204
+ "__rmod__" ,
205
+ "__rmul__" ,
206
+ "__ror__" ,
207
+ "__rpow__" ,
208
+ "__rrshift__" ,
209
+ "__rsub__" ,
210
+ "__rtruediv__" ,
211
+ "__rxor__" ,
212
+ ]
213
+ for name in binary_names + rbinary_names :
214
+
215
+ def method (self , other , name = name ):
216
+ units = self .units
217
+ magnitude_other = other .m_as (units ) if hasattr (other , "units" ) else other
218
+ magnitude = self ._call_super_method (name , magnitude_other )
219
+ # FIXME: correct units for op
220
+ return ArrayUnitQuantity (magnitude , units )
221
+
222
+ setattr (ArrayQuantity , name , method )
223
+
224
+ # In-place methods
225
+ desired_names = [
226
+ "__iadd__" ,
227
+ "__iand__" ,
228
+ "__ifloordiv__" ,
229
+ "__ilshift__" ,
230
+ "__imod__" ,
231
+ "__imul__" ,
232
+ "__ior__" ,
233
+ "__ipow__" ,
234
+ "__irshift__" ,
235
+ "__isub__" ,
236
+ "__itruediv__" ,
237
+ "__ixor__" ,
238
+ ]
239
+ for name in desired_names :
240
+
241
+ def method (self , other , name = name ):
242
+ units = self .units
243
+ magnitude_other = other .m_as (units ) if hasattr (other , "units" ) else other
244
+ magnitude = self ._call_super_method (name , magnitude_other )
245
+ # FIXME: correct units for op
246
+ return ArrayUnitQuantity (magnitude , units )
247
+
248
+ setattr (ArrayQuantity , name , method )
202
249
203
250
## Constants ##
204
251
constant_names = ["e" , "inf" , "nan" , "newaxis" , "pi" ]
0 commit comments