Skip to content
This repository was archived by the owner on Feb 12, 2022. It is now read-only.

Commit e29f8ae

Browse files
vlasenkovjekbradbury
authored andcommitted
add_type_method
1 parent d1f9eb3 commit e29f8ae

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

matchbox/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def shape(self):
8383
def new(self, *sizes):
8484
return self.data.new(*sizes)
8585

86+
def type(self, dtype=None, non_blocking=False, **kwargs):
87+
if dtype:
88+
data = self.data.type(dtype, non_blocking, **kwargs)
89+
mask = self.mask.type(dtype, non_blocking, **kwargs)
90+
return self.__class__(data, mask, self.dims)
91+
else:
92+
return self.data.type()
93+
8694
def __bool__(self):
8795
if self.data.nelement() > 1:
8896
raise ValueError("bool value of MaskedBatch with more than one "

0 commit comments

Comments
 (0)