Skip to content

Commit d1ae49b

Browse files
implement matmul
1 parent dc918fe commit d1ae49b

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

test/test_value_array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,10 @@ def test_unique() -> None:
293293
unit = tu.MHz
294294
v_arr = xs * unit
295295
assert np.array_equal(v_arr.unique(), np.unique(xs) * unit)
296+
297+
298+
def test_matmul() -> None:
299+
a = np.random.random((3, 4))
300+
b = np.random.random((4, 3)) * tu.ns
301+
assert (a @ b).allclose((a @ b[tu.us]) * tu.us)
302+
assert (b @ a).allclose((b[tu.s] @ a) * tu.s)

tunits/core/cython/with_unit_value_array.pyx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ class ValueArray(WithUnit):
116116
return self ** 2
117117
if ufunc == np.reciprocal:
118118
return self.__rtruediv__(1)
119+
if ufunc == np.matmul:
120+
if isinstance(inputs[0], ValueArray):
121+
return inputs[0].__matmul__(inputs[1])
122+
return inputs[1].__rmatmul__(inputs[0])
119123

120124
if ufunc in [
121125
np.greater,
@@ -161,3 +165,10 @@ class ValueArray(WithUnit):
161165
ret = _ndarray_to_proto(self.value, msg)
162166
ret.units.extend(_units_to_proto(self.display_units))
163167
return ret
168+
169+
def __matmul__(WithUnit self, other: np.ndarray):
170+
return self.__with_value(self.value @ other)
171+
172+
173+
def __rmatmul__(WithUnit self, other: np.ndarray):
174+
return self.__with_value(other @ self.value)

0 commit comments

Comments
 (0)