Skip to content

Commit fb2ed72

Browse files
Merge pull request #350 from RawnH/pytaco_negation
Adds negation to pytaco tensor interface
2 parents 81ebb64 + e48e4ec commit fb2ed72

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

python_bindings/pytaco/pytensor/taco_tensor.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,10 @@ def __pow__(self, power, modulo=None):
374374
return tensor_pow(self, power, default_mode)
375375

376376
def __abs__(self):
377-
return tensor_abs(self, default_mode)
377+
return tensor_abs(self, self.format)
378+
379+
def __neg__(self):
380+
return tensor_neg(self, self.format)
378381

379382
def __array__(self):
380383
if not _cm.is_dense(self.format):
@@ -1482,6 +1485,39 @@ def tensor_logical_not(t1, out_format, dtype=None):
14821485
"""
14831486
return _compute_unary_elt_eise_op(_cm.logical_not, t1, out_format, dtype)
14841487

1488+
def tensor_neg(t1, out_format, dtype=None):
1489+
"""
1490+
Negates every value in the tensor.
1491+
1492+
The tensor class implements ``__neg__`` using this method.
1493+
1494+
Parameters
1495+
------------
1496+
t1: tensor, array_like
1497+
input tensor or array_like object
1498+
1499+
out_format: format, mode_format, optional
1500+
* If a :class:`format` is specified, the result tensor is stored in the format out_format.
1501+
* If a :class:`mode_format` is specified, the result the result tensor has a with all of the dimensions
1502+
stored in the :class:`mode_format` passed in.
1503+
1504+
dtype: Datatype
1505+
The datatype of the output tensor.
1506+
1507+
1508+
Examples
1509+
----------
1510+
>>> import pytaco as pt
1511+
>>> pt.tensor_neg([1, -2, 0], out_format=pt.dense).toarray()
1512+
array([-1, 2, 0], dtype=int64)
1513+
1514+
Returns
1515+
--------
1516+
neg: tensor
1517+
The element wise negation of the input tensor.
1518+
1519+
"""
1520+
return _compute_unary_elt_eise_op(_cm.neg, t1, out_format, dtype)
14851521

14861522
def tensor_abs(t1, out_format, dtype=None):
14871523
"""

python_bindings/unit_tests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ def test_mod(self):
251251
t1[i, j] = pt.remainder(t[i, j], 2)
252252
self.assertEqual(t1, arr % 2)
253253

254+
def test_neg(self):
255+
arr = np.arange(1, 5).reshape([2, 2])
256+
t = pt.from_array(arr)
257+
self.assertEqual(-t, -arr)
254258

255259
class testParsers(unittest.TestCase):
256260

0 commit comments

Comments
 (0)