Skip to content

Commit b0ff709

Browse files
committed
torch_csr_tensor
1 parent fcf1565 commit b0ff709

File tree

3 files changed

+58
-27
lines changed

3 files changed

+58
-27
lines changed

test/test_matmul.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import pytest
44
import torch
5-
5+
import torch_scatter
66
from torch_sparse.matmul import matmul
77
from torch_sparse.tensor import SparseTensor
8-
import torch_scatter
98

10-
from .utils import reductions, devices, grad_dtypes
9+
from .utils import devices, grad_dtypes, reductions
1110

1211

1312
@pytest.mark.parametrize('dtype,device,reduce',

torch_sparse/matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Tuple
22

33
import torch
4+
45
from torch_sparse.tensor import SparseTensor
56

67

torch_sparse/tensor.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from textwrap import indent
2-
from typing import Optional, List, Tuple, Dict, Union, Any
2+
from typing import Any, Dict, List, Optional, Tuple, Union
33

4-
import torch
54
import numpy as np
65
import scipy.sparse
6+
import torch
77
from torch_scatter import segment_csr
88

99
from torch_sparse.storage import SparseStorage, get_layout
@@ -13,14 +13,16 @@
1313
class SparseTensor(object):
1414
storage: SparseStorage
1515

16-
def __init__(self, row: Optional[torch.Tensor] = None,
17-
rowptr: Optional[torch.Tensor] = None,
18-
col: Optional[torch.Tensor] = None,
19-
value: Optional[torch.Tensor] = None,
20-
sparse_sizes: Optional[Tuple[Optional[int],
21-
Optional[int]]] = None,
22-
is_sorted: bool = False,
23-
trust_data: bool = False):
16+
def __init__(
17+
self,
18+
row: Optional[torch.Tensor] = None,
19+
rowptr: Optional[torch.Tensor] = None,
20+
col: Optional[torch.Tensor] = None,
21+
value: Optional[torch.Tensor] = None,
22+
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
23+
is_sorted: bool = False,
24+
trust_data: bool = False,
25+
):
2426
self.storage = SparseStorage(
2527
row=row,
2628
rowptr=rowptr,
@@ -33,7 +35,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
3335
csr2csc=None,
3436
csc2csr=None,
3537
is_sorted=is_sorted,
36-
trust_data=trust_data)
38+
trust_data=trust_data,
39+
)
3740

3841
@classmethod
3942
def from_storage(self, storage: SparseStorage):
@@ -44,7 +47,8 @@ def from_storage(self, storage: SparseStorage):
4447
value=storage._value,
4548
sparse_sizes=storage._sparse_sizes,
4649
is_sorted=True,
47-
trust_data=True)
50+
trust_data=True,
51+
)
4852
out.storage._rowcount = storage._rowcount
4953
out.storage._colptr = storage._colptr
5054
out.storage._colcount = storage._colcount
@@ -53,12 +57,14 @@ def from_storage(self, storage: SparseStorage):
5357
return out
5458

5559
@classmethod
56-
def from_edge_index(self, edge_index: torch.Tensor,
57-
edge_attr: Optional[torch.Tensor] = None,
58-
sparse_sizes: Optional[Tuple[Optional[int],
59-
Optional[int]]] = None,
60-
is_sorted: bool = False,
61-
trust_data: bool = False):
60+
def from_edge_index(
61+
self,
62+
edge_index: torch.Tensor,
63+
edge_attr: Optional[torch.Tensor] = None,
64+
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
65+
is_sorted: bool = False,
66+
trust_data: bool = False,
67+
):
6268
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
6369
value=edge_attr, sparse_sizes=sparse_sizes,
6470
is_sorted=is_sorted, trust_data=trust_data)
@@ -97,6 +103,20 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
97103
sparse_sizes=(mat.size(0), mat.size(1)),
98104
is_sorted=True, trust_data=True)
99105

106+
@classmethod
107+
def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
108+
has_value: bool = True):
109+
rowptr = mat.crow_indices()
110+
col = mat.col_indices()
111+
112+
value: Optional[torch.Tensor] = None
113+
if has_value:
114+
value = mat.values()
115+
116+
return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
117+
sparse_sizes=(mat.size(0), mat.size(1)),
118+
is_sorted=True, trust_data=True)
119+
100120
@classmethod
101121
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
102122
dtype: Optional[int] = None, device: Optional[torch.device] = None,
@@ -140,7 +160,8 @@ def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
140160
value=value,
141161
sparse_sizes=(M, N),
142162
is_sorted=True,
143-
trust_data=True)
163+
trust_data=True,
164+
)
144165
out.storage._rowcount = rowcount
145166
out.storage._colptr = colptr
146167
out.storage._colcount = colcount
@@ -158,17 +179,17 @@ def type(self, dtype: torch.dtype, non_blocking: bool = False):
158179
value = self.storage.value()
159180
if value is None or dtype == value.dtype:
160181
return self
161-
return self.from_storage(self.storage.type(
162-
dtype=dtype, non_blocking=non_blocking))
182+
return self.from_storage(
183+
self.storage.type(dtype=dtype, non_blocking=non_blocking))
163184

164185
def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
165186
return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
166187

167188
def to_device(self, device: torch.device, non_blocking: bool = False):
168189
if device == self.device():
169190
return self
170-
return self.from_storage(self.storage.to_device(
171-
device=device, non_blocking=non_blocking))
191+
return self.from_storage(
192+
self.storage.to_device(device=device, non_blocking=non_blocking))
172193

173194
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
174195
return self.to_device(device=tensor.device, non_blocking=non_blocking)
@@ -362,7 +383,8 @@ def to_symmetric(self, reduce: str = "sum"):
362383
value=value,
363384
sparse_sizes=(N, N),
364385
is_sorted=True,
365-
trust_data=True)
386+
trust_data=True,
387+
)
366388
return out
367389

368390
def detach_(self):
@@ -479,6 +501,15 @@ def to_torch_sparse_coo_tensor(
479501

480502
return torch.sparse_coo_tensor(index, value, self.sizes())
481503

504+
def to_torch_sparse_csr_tensor(
505+
self, dtype: Optional[int] = None) -> torch.Tensor:
506+
rowptr, col, value = self.csr()
507+
508+
if value is None:
509+
value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
510+
511+
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())
512+
482513

483514
# Python Bindings #############################################################
484515

0 commit comments

Comments
 (0)