Skip to content

Commit 9c0419d

Browse files
committed
eye func
1 parent 8daf945 commit 9c0419d

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

test/test_eye.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from torch_sparse import eye
2+
3+
4+
def test_eye():
5+
index, value = eye(3)
6+
assert index.tolist() == [[0, 1, 2], [0, 1, 2]]
7+
assert value.tolist() == [1, 1, 1]

torch_sparse/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .coalesce import coalesce
22
from .transpose import transpose
3+
from .eye import eye
34
from .spmm import spmm
45
from .spspmm import spspmm
56

@@ -9,6 +10,7 @@
910
'__version__',
1011
'coalesce',
1112
'transpose',
13+
'eye',
1214
'spmm',
1315
'spspmm',
1416
]

torch_sparse/eye.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
3+
4+
def eye(m, dtype=None, device=None):
5+
"""Returns a sparse matrix with ones on the diagonal and zeros elsewhere.
6+
7+
Args:
8+
m (int): The first dimension of sparse matrix.
9+
dtype (`torch.dtype`, optional): The desired data type of returned
10+
value vector. (default is set by `torch.set_default_tensor_type()`)
11+
device (`torch.device`, optional): The desired device of returned
12+
tensors. (default is set by `torch.set_default_tensor_type()`)
13+
14+
:rtype: (:class:`LongTensor`, :class:`Tensor`)
15+
"""
16+
17+
row = torch.arange(m, dtype=torch.long, device=device)
18+
index = torch.stack([row, row], dim=0)
19+
20+
value = torch.ones(m, dtype=dtype, device=device)
21+
22+
return index, value

0 commit comments

Comments
 (0)