11from typing import Optional
22
33import torch
4+ from torch import Tensor
45from torch_sparse .storage import SparseStorage
56from torch_sparse .tensor import SparseTensor
67
@@ -31,7 +32,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
3132 return src .from_storage (storage )
3233
3334
34- def set_diag (src : SparseTensor , values : Optional [torch . Tensor ] = None ,
35+ def set_diag (src : SparseTensor , values : Optional [Tensor ] = None ,
3536 k : int = 0 ) -> SparseTensor :
3637 src = remove_diag (src , k = k )
3738 row , col , value = src .coo ()
@@ -51,7 +52,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
5152 new_col [mask ] = col
5253 new_col [inv_mask ] = diag .add_ (k )
5354
54- new_value : Optional [torch . Tensor ] = None
55+ new_value : Optional [Tensor ] = None
5556 if value is not None :
5657 new_value = value .new_empty ((mask .size (0 ), ) + value .size ()[1 :])
5758 new_value [mask ] = value
@@ -92,8 +93,25 @@ def fill_diag(src: SparseTensor, fill_value: float,
9293 return set_diag (src , None , k )
9394
9495
96+ def get_diag (src : SparseTensor ) -> Tensor :
97+ row , col , value = src .coo ()
98+
99+ if value is None :
100+ value = torch .ones (row .size (0 ))
101+
102+ sizes = list (value .size ())
103+ sizes [0 ] = min (src .size (0 ), src .size (1 ))
104+
105+ out = value .new_zeros (sizes )
106+
107+ mask = row == col
108+ out [row [mask ]] = value [mask ]
109+ return out
110+
111+
95112SparseTensor .remove_diag = lambda self , k = 0 : remove_diag (self , k )
96113SparseTensor .set_diag = lambda self , values = None , k = 0 : set_diag (
97114 self , values , k )
98115SparseTensor .fill_diag = lambda self , fill_value , k = 0 : fill_diag (
99116 self , fill_value , k )
117+ SparseTensor .get_diag = lambda self : get_diag (self )
0 commit comments