Skip to content

Commit 46dac04

Browse files
committed
get diag
1 parent 105a60b commit 46dac04

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

test/test_diag.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,15 @@ def test_fill_diag(dtype, device):
5252

5353
mat = mat.fill_diag(-8, k=-1)
5454
mat = mat.fill_diag(-8, k=1)
55+
56+
57+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
58+
def test_get_diag(dtype, device):
59+
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
60+
value = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype, device)
61+
mat = SparseTensor(row=row, col=col, value=value)
62+
assert mat.get_diag().tolist() == [[1, 1], [0, 0], [4, 4]]
63+
64+
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
65+
mat = SparseTensor(row=row, col=col)
66+
assert mat.get_diag().tolist() == [1, 0, 1]

torch_sparse/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from .index_select import index_select, index_select_nnz # noqa
4040
from .masked_select import masked_select, masked_select_nnz # noqa
4141
from .permute import permute # noqa
42-
from .diag import remove_diag, set_diag, fill_diag # noqa
42+
from .diag import remove_diag, set_diag, fill_diag, get_diag # noqa
4343
from .add import add, add_, add_nnz, add_nnz_ # noqa
4444
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
4545
from .reduce import sum, mean, min, max # noqa
@@ -75,6 +75,7 @@
7575
'remove_diag',
7676
'set_diag',
7777
'fill_diag',
78+
'get_diag',
7879
'add',
7980
'add_',
8081
'add_nnz',

torch_sparse/diag.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22

33
import torch
4+
from torch import Tensor
45
from torch_sparse.storage import SparseStorage
56
from torch_sparse.tensor import SparseTensor
67

@@ -31,7 +32,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
3132
return src.from_storage(storage)
3233

3334

34-
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
35+
def set_diag(src: SparseTensor, values: Optional[Tensor] = None,
3536
k: int = 0) -> SparseTensor:
3637
src = remove_diag(src, k=k)
3738
row, col, value = src.coo()
@@ -51,7 +52,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
5152
new_col[mask] = col
5253
new_col[inv_mask] = diag.add_(k)
5354

54-
new_value: Optional[torch.Tensor] = None
55+
new_value: Optional[Tensor] = None
5556
if value is not None:
5657
new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
5758
new_value[mask] = value
@@ -92,8 +93,25 @@ def fill_diag(src: SparseTensor, fill_value: float,
9293
return set_diag(src, None, k)
9394

9495

96+
def get_diag(src: SparseTensor) -> Tensor:
97+
row, col, value = src.coo()
98+
99+
if value is None:
100+
value = torch.ones(row.size(0))
101+
102+
sizes = list(value.size())
103+
sizes[0] = min(src.size(0), src.size(1))
104+
105+
out = value.new_zeros(sizes)
106+
107+
mask = row == col
108+
out[row[mask]] = value[mask]
109+
return out
110+
111+
95112
SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k)
96113
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
97114
self, values, k)
98115
SparseTensor.fill_diag = lambda self, fill_value, k=0: fill_diag(
99116
self, fill_value, k)
117+
SparseTensor.get_diag = lambda self: get_diag(self)

0 commit comments

Comments
 (0)