Skip to content

Commit 5f4f9c5

Browse files
committed
cpu and cuda methods
1 parent 3dbd228 commit 5f4f9c5

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torch_sparse/tensor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,15 @@ def to(self, *args: Optional[List[Any]],
455455
return self
456456

457457

458+
def cpu(self) -> SparseTensor:
459+
return self.device_as(torch.tensor(0., device='cpu'))
460+
461+
462+
def cuda(self, device: Optional[Union[int, str]] = None,
463+
non_blocking: bool = False):
464+
return self.device_as(torch.tensor(0., device=device or 'cuda'))
465+
466+
458467
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
459468
index = list(index) if isinstance(index, tuple) else [index]
460469
# More than one `Ellipsis` is not allowed...
@@ -523,6 +532,8 @@ def __repr__(self: SparseTensor) -> str:
523532
SparseTensor.share_memory_ = share_memory_
524533
SparseTensor.is_shared = is_shared
525534
SparseTensor.to = to
535+
SparseTensor.cpu = cpu
536+
SparseTensor.cuda = cuda
526537
SparseTensor.__getitem__ = __getitem__
527538
SparseTensor.__repr__ = __repr__
528539

0 commit comments

Comments
 (0)