Skip to content
This repository was archived by the owner on Feb 12, 2022. It is now read-only.

Commit d1f9eb3

Browse files
vlasenkovjekbradbury
authored andcommitted
fix_neg
1 parent d8ec789 commit d1f9eb3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

matchbox/functional/elementwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def inner(batch, *args, **kwargs):
4040
MaskedBatch.tanh = tanh = _elementwise_unary(F.tanh)
4141
MaskedBatch.sigmoid = sigmoid = _elementwise_unary(F.sigmoid)
4242

43+
MaskedBatch.__neg__ = _elementwise_unary(TENSOR_TYPE.__neg__)
44+
4345
def _elementwise_binary(fn):
4446
def inner(batch1, batch2, **kwargs):
4547
if not isinstance(batch1, MaskedBatch) and not isinstance(batch2, MaskedBatch):
@@ -55,7 +57,6 @@ def inner(batch1, batch2, **kwargs):
5557
return MaskedBatch(data, mask, dims)
5658
return inner
5759

58-
MaskedBatch.__neg__ = _elementwise_binary(TENSOR_TYPE.__neg__)
5960
MaskedBatch.__add__ = _elementwise_binary(TENSOR_TYPE.__add__)
6061
MaskedBatch.__sub__ = _elementwise_binary(TENSOR_TYPE.__sub__)
6162
MaskedBatch.__mul__ = _elementwise_binary(TENSOR_TYPE.__mul__)

0 commit comments

Comments
 (0)