11from textwrap import indent
2- from typing import Optional , List , Tuple , Dict , Union , Any
2+ from typing import Any , Dict , List , Optional , Tuple , Union
33
4- import torch
54import numpy as np
65import scipy .sparse
6+ import torch
77from torch_scatter import segment_csr
88
99from torch_sparse .storage import SparseStorage , get_layout
1313class SparseTensor (object ):
1414 storage : SparseStorage
1515
16- def __init__ (self , row : Optional [torch .Tensor ] = None ,
17- rowptr : Optional [torch .Tensor ] = None ,
18- col : Optional [torch .Tensor ] = None ,
19- value : Optional [torch .Tensor ] = None ,
20- sparse_sizes : Optional [Tuple [Optional [int ],
21- Optional [int ]]] = None ,
22- is_sorted : bool = False ,
23- trust_data : bool = False ):
16+ def __init__ (
17+ self ,
18+ row : Optional [torch .Tensor ] = None ,
19+ rowptr : Optional [torch .Tensor ] = None ,
20+ col : Optional [torch .Tensor ] = None ,
21+ value : Optional [torch .Tensor ] = None ,
22+ sparse_sizes : Optional [Tuple [Optional [int ], Optional [int ]]] = None ,
23+ is_sorted : bool = False ,
24+ trust_data : bool = False ,
25+ ):
2426 self .storage = SparseStorage (
2527 row = row ,
2628 rowptr = rowptr ,
@@ -33,7 +35,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
3335 csr2csc = None ,
3436 csc2csr = None ,
3537 is_sorted = is_sorted ,
36- trust_data = trust_data )
38+ trust_data = trust_data ,
39+ )
3740
3841 @classmethod
3942 def from_storage (self , storage : SparseStorage ):
@@ -44,7 +47,8 @@ def from_storage(self, storage: SparseStorage):
4447 value = storage ._value ,
4548 sparse_sizes = storage ._sparse_sizes ,
4649 is_sorted = True ,
47- trust_data = True )
50+ trust_data = True ,
51+ )
4852 out .storage ._rowcount = storage ._rowcount
4953 out .storage ._colptr = storage ._colptr
5054 out .storage ._colcount = storage ._colcount
@@ -53,12 +57,14 @@ def from_storage(self, storage: SparseStorage):
5357 return out
5458
5559 @classmethod
56- def from_edge_index (self , edge_index : torch .Tensor ,
57- edge_attr : Optional [torch .Tensor ] = None ,
58- sparse_sizes : Optional [Tuple [Optional [int ],
59- Optional [int ]]] = None ,
60- is_sorted : bool = False ,
61- trust_data : bool = False ):
60+ def from_edge_index (
61+ self ,
62+ edge_index : torch .Tensor ,
63+ edge_attr : Optional [torch .Tensor ] = None ,
64+ sparse_sizes : Optional [Tuple [Optional [int ], Optional [int ]]] = None ,
65+ is_sorted : bool = False ,
66+ trust_data : bool = False ,
67+ ):
6268 return SparseTensor (row = edge_index [0 ], rowptr = None , col = edge_index [1 ],
6369 value = edge_attr , sparse_sizes = sparse_sizes ,
6470 is_sorted = is_sorted , trust_data = trust_data )
@@ -97,6 +103,20 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
97103 sparse_sizes = (mat .size (0 ), mat .size (1 )),
98104 is_sorted = True , trust_data = True )
99105
106+ @classmethod
107+ def from_torch_sparse_csr_tensor (self , mat : torch .Tensor ,
108+ has_value : bool = True ):
109+ rowptr = mat .crow_indices ()
110+ col = mat .col_indices ()
111+
112+ value : Optional [torch .Tensor ] = None
113+ if has_value :
114+ value = mat .values ()
115+
116+ return SparseTensor (row = None , rowptr = rowptr , col = col , value = value ,
117+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
118+ is_sorted = True , trust_data = True )
119+
100120 @classmethod
101121 def eye (self , M : int , N : Optional [int ] = None , has_value : bool = True ,
102122 dtype : Optional [int ] = None , device : Optional [torch .device ] = None ,
@@ -140,7 +160,8 @@ def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
140160 value = value ,
141161 sparse_sizes = (M , N ),
142162 is_sorted = True ,
143- trust_data = True )
163+ trust_data = True ,
164+ )
144165 out .storage ._rowcount = rowcount
145166 out .storage ._colptr = colptr
146167 out .storage ._colcount = colcount
@@ -158,17 +179,17 @@ def type(self, dtype: torch.dtype, non_blocking: bool = False):
158179 value = self .storage .value ()
159180 if value is None or dtype == value .dtype :
160181 return self
161- return self .from_storage (self . storage . type (
162- dtype = dtype , non_blocking = non_blocking ))
182+ return self .from_storage (
183+ self . storage . type ( dtype = dtype , non_blocking = non_blocking ))
163184
164185 def type_as (self , tensor : torch .Tensor , non_blocking : bool = False ):
165186 return self .type (dtype = tensor .dtype , non_blocking = non_blocking )
166187
167188 def to_device (self , device : torch .device , non_blocking : bool = False ):
168189 if device == self .device ():
169190 return self
170- return self .from_storage (self . storage . to_device (
171- device = device , non_blocking = non_blocking ))
191+ return self .from_storage (
192+ self . storage . to_device ( device = device , non_blocking = non_blocking ))
172193
173194 def device_as (self , tensor : torch .Tensor , non_blocking : bool = False ):
174195 return self .to_device (device = tensor .device , non_blocking = non_blocking )
@@ -362,7 +383,8 @@ def to_symmetric(self, reduce: str = "sum"):
362383 value = value ,
363384 sparse_sizes = (N , N ),
364385 is_sorted = True ,
365- trust_data = True )
386+ trust_data = True ,
387+ )
366388 return out
367389
368390 def detach_ (self ):
@@ -479,6 +501,15 @@ def to_torch_sparse_coo_tensor(
479501
480502 return torch .sparse_coo_tensor (index , value , self .sizes ())
481503
504+ def to_torch_sparse_csr_tensor (
505+ self , dtype : Optional [int ] = None ) -> torch .Tensor :
506+ rowptr , col , value = self .csr ()
507+
508+ if value is None :
509+ value = torch .ones (self .nnz (), dtype = dtype , device = self .device ())
510+
511+ return torch .sparse_csr_tensor (rowptr , col , value , self .sizes ())
512+
482513
483514# Python Bindings #############################################################
484515
0 commit comments