1- from typing import Optional , List
1+ from typing import Optional , List , Tuple
22
33import torch
44from torch_sparse .storage import SparseStorage
55from torch_sparse .tensor import SparseTensor
66
77
8- def cat (tensors : List [SparseTensor ], dim : int ) -> SparseTensor :
8+ @torch .jit ._overload # noqa: F811
9+ def cat (tensors , dim ): # noqa: F811
10+ # type: (List[SparseTensor], int) -> SparseTensor
11+ pass
12+
13+
14+ @torch .jit ._overload # noqa: F811
15+ def cat (tensors , dim ): # noqa: F811
16+ # type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
17+ pass
18+
19+
20+ @torch .jit ._overload # noqa: F811
21+ def cat (tensors , dim ): # noqa: F811
22+ # type: (List[SparseTensor], List[int]) -> SparseTensor
23+ pass
24+
25+
26+ def cat (tensors , dim ): # noqa: F811
927 assert len (tensors ) > 0
10- if dim < 0 :
11- dim = tensors [0 ].dim () + dim
12-
13- if dim == 0 :
14- rows : List [torch .Tensor ] = []
15- rowptrs : List [torch .Tensor ] = []
16- cols : List [torch .Tensor ] = []
17- values : List [torch .Tensor ] = []
18- sparse_sizes : List [int ] = [0 , 0 ]
19- rowcounts : List [torch .Tensor ] = []
20-
21- nnz : int = 0
22- for tensor in tensors :
23- row = tensor .storage ._row
24- if row is not None :
25- rows .append (row + sparse_sizes [0 ])
26-
27- rowptr = tensor .storage ._rowptr
28- if rowptr is not None :
29- if len (rowptrs ) > 0 :
30- rowptr = rowptr [1 :]
31- rowptrs .append (rowptr + nnz )
32-
33- cols .append (tensor .storage ._col )
34-
35- value = tensor .storage ._value
36- if value is not None :
28+
29+ if isinstance (dim , int ):
30+ dim = tensors [0 ].dim () + dim if dim < 0 else dim
31+
32+ if dim == 0 :
33+ return cat_first (tensors )
34+
35+ elif dim == 1 :
36+ return cat_second (tensors )
37+ pass
38+
39+ elif dim > 1 and dim < tensors [0 ].dim ():
40+ values = []
41+ for tensor in tensors :
42+ value = tensor .storage .value ()
43+ assert value is not None
3744 values .append (value )
45+ value = torch .cat (values , dim = dim - 1 )
46+ return tensors [0 ].set_value (value , layout = 'coo' )
3847
39- rowcount = tensor .storage ._rowcount
40- if rowcount is not None :
41- rowcounts .append (rowcount )
48+ else :
49+ raise IndexError (
50+ (f'Dimension out of range: Expected to be in range of '
51+ f'[{ - tensors [0 ].dim ()} , { tensors [0 ].dim () - 1 } ], but got '
52+ f'{ dim } .' ))
53+ else :
54+ assert isinstance (dim , (tuple , list ))
55+ assert len (dim ) == 2
56+ assert sorted (dim ) == [0 , 1 ]
57+ return cat_diag (tensors )
4258
43- sparse_sizes [0 ] += tensor .sparse_size (0 )
44- sparse_sizes [1 ] = max (sparse_sizes [1 ], tensor .sparse_size (1 ))
45- nnz += tensor .nnz ()
4659
47- row : Optional [torch .Tensor ] = None
48- if len (rows ) == len (tensors ):
49- row = torch .cat (rows , dim = 0 )
60+ def cat_first (tensors : List [SparseTensor ]) -> SparseTensor :
61+ rows : List [torch .Tensor ] = []
62+ rowptrs : List [torch .Tensor ] = []
63+ cols : List [torch .Tensor ] = []
64+ values : List [torch .Tensor ] = []
65+ sparse_sizes : List [int ] = [0 , 0 ]
66+ rowcounts : List [torch .Tensor ] = []
5067
51- rowptr : Optional [torch .Tensor ] = None
52- if len (rowptrs ) == len (tensors ):
53- rowptr = torch .cat (rowptrs , dim = 0 )
68+ nnz : int = 0
69+ for tensor in tensors :
70+ row = tensor .storage ._row
71+ if row is not None :
72+ rows .append (row + sparse_sizes [0 ])
5473
55- col = torch .cat (cols , dim = 0 )
74+ rowptr = tensor .storage ._rowptr
75+ if rowptr is not None :
76+ rowptrs .append (rowptr [1 :] + nnz if len (rowptrs ) > 0 else rowptr )
5677
57- value : Optional [torch .Tensor ] = None
58- if len (values ) == len (tensors ):
59- value = torch .cat (values , dim = 0 )
78+ cols .append (tensor .storage ._col )
6079
61- rowcount : Optional [ torch . Tensor ] = None
62- if len ( rowcounts ) == len ( tensors ) :
63- rowcount = torch . cat ( rowcounts , dim = 0 )
80+ value = tensor . storage . _value
81+ if value is not None :
82+ values . append ( value )
6483
65- storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
66- sparse_sizes = sparse_sizes , rowcount = rowcount ,
67- colptr = None , colcount = None , csr2csc = None ,
68- csc2csr = None , is_sorted = True )
69- return tensors [0 ].from_storage (storage )
84+ rowcount = tensor .storage ._rowcount
85+ if rowcount is not None :
86+ rowcounts .append (rowcount )
7087
71- elif dim == 1 :
72- rows : List [torch .Tensor ] = []
73- cols : List [torch .Tensor ] = []
74- values : List [torch .Tensor ] = []
75- sparse_sizes : List [int ] = [0 , 0 ]
76- colptrs : List [torch .Tensor ] = []
77- colcounts : List [torch .Tensor ] = []
88+ sparse_sizes [0 ] += tensor .sparse_size (0 )
89+ sparse_sizes [1 ] = max (sparse_sizes [1 ], tensor .sparse_size (1 ))
90+ nnz += tensor .nnz ()
7891
79- nnz : int = 0
80- for tensor in tensors :
81- row , col , value = tensor . coo ( )
92+ row : Optional [ torch . Tensor ] = None
93+ if len ( rows ) == len ( tensors ) :
94+ row = torch . cat ( rows , dim = 0 )
8295
83- rows .append (row )
96+ rowptr : Optional [torch .Tensor ] = None
97+ if len (rowptrs ) == len (tensors ):
98+ rowptr = torch .cat (rowptrs , dim = 0 )
8499
85- cols . append ( tensor . storage . _col + sparse_sizes [ 1 ] )
100+ col = torch . cat ( cols , dim = 0 )
86101
87- if value is not None :
88- values .append (value )
102+ value : Optional [torch .Tensor ] = None
103+ if len (values ) == len (tensors ):
104+ value = torch .cat (values , dim = 0 )
89105
90- colptr = tensor .storage ._colptr
91- if colptr is not None :
92- if len (colptrs ) > 0 :
93- colptr = colptr [1 :]
94- colptrs .append (colptr + nnz )
106+ rowcount : Optional [torch .Tensor ] = None
107+ if len (rowcounts ) == len (tensors ):
108+ rowcount = torch .cat (rowcounts , dim = 0 )
95109
96- colcount = tensor .storage ._colcount
97- if colcount is not None :
98- colcounts .append (colcount )
110+ storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
111+ sparse_sizes = (sparse_sizes [0 ], sparse_sizes [1 ]),
112+ rowcount = rowcount , colptr = None , colcount = None ,
113+ csr2csc = None , csc2csr = None , is_sorted = True )
114+ return tensors [0 ].from_storage (storage )
99115
100- sparse_sizes [0 ] = max (sparse_sizes [0 ], tensor .sparse_size (0 ))
101- sparse_sizes [1 ] += tensor .sparse_size (1 )
102- nnz += tensor .nnz ()
103116
104- row = torch .cat (rows , dim = 0 )
117+ def cat_second (tensors : List [SparseTensor ]) -> SparseTensor :
118+ rows : List [torch .Tensor ] = []
119+ cols : List [torch .Tensor ] = []
120+ values : List [torch .Tensor ] = []
121+ sparse_sizes : List [int ] = [0 , 0 ]
122+ colptrs : List [torch .Tensor ] = []
123+ colcounts : List [torch .Tensor ] = []
124+
125+ nnz : int = 0
126+ for tensor in tensors :
127+ row , col , value = tensor .coo ()
128+ rows .append (row )
129+ cols .append (tensor .storage ._col + sparse_sizes [1 ])
130+
131+ if value is not None :
132+ values .append (value )
105133
106- col = torch .cat (cols , dim = 0 )
134+ colptr = tensor .storage ._colptr
135+ if colptr is not None :
136+ colptrs .append (colptr [1 :] + nnz if len (colptrs ) > 0 else colptr )
107137
108- value : Optional [ torch . Tensor ] = None
109- if len ( values ) == len ( tensors ) :
110- value = torch . cat ( values , dim = 0 )
138+ colcount = tensor . storage . _colcount
139+ if colcount is not None :
140+ colcounts . append ( colcount )
111141
112- colptr : Optional [ torch . Tensor ] = None
113- if len ( colptrs ) == len ( tensors ):
114- colptr = torch . cat ( colptrs , dim = 0 )
142+ sparse_sizes [ 0 ] = max ( sparse_sizes [ 0 ], tensor . sparse_size ( 0 ))
143+ sparse_sizes [ 1 ] += tensor . sparse_size ( 1 )
144+ nnz += tensor . nnz ( )
115145
116- colcount : Optional [torch .Tensor ] = None
117- if len (colcounts ) == len (tensors ):
118- colcount = torch .cat (colcounts , dim = 0 )
146+ row = torch .cat (rows , dim = 0 )
147+ col = torch .cat (cols , dim = 0 )
119148
120- storage = SparseStorage (row = row , rowptr = None , col = col , value = value ,
121- sparse_sizes = sparse_sizes , rowcount = None ,
122- colptr = colptr , colcount = colcount , csr2csc = None ,
123- csc2csr = None , is_sorted = False )
124- return tensors [0 ].from_storage (storage )
149+ value : Optional [torch .Tensor ] = None
150+ if len (values ) == len (tensors ):
151+ value = torch .cat (values , dim = 0 )
125152
126- elif dim > 1 and dim < tensors [0 ].dim ():
127- values : List [torch .Tensor ] = []
128- for tensor in tensors :
129- value = tensor .storage .value ()
130- if value is not None :
131- values .append (value )
153+ colptr : Optional [torch .Tensor ] = None
154+ if len (colptrs ) == len (tensors ):
155+ colptr = torch .cat (colptrs , dim = 0 )
132156
133- value : Optional [torch .Tensor ] = None
134- if len (values ) == len (tensors ):
135- value = torch .cat (values , dim = dim - 1 )
157+ colcount : Optional [torch .Tensor ] = None
158+ if len (colcounts ) == len (tensors ):
159+ colcount = torch .cat (colcounts , dim = 0 )
136160
137- return tensors [ 0 ]. set_value ( value , layout = 'coo' )
138- else :
139- raise IndexError (
140- ( f'Dimension out of range: Expected to be in range of '
141- f'[ { - tensors [0 ].dim () } , { tensors [ 0 ]. dim () - 1 } ], but got { dim } .' ) )
161+ storage = SparseStorage ( row = row , rowptr = None , col = col , value = value ,
162+ sparse_sizes = ( sparse_sizes [ 0 ], sparse_sizes [ 1 ]),
163+ rowcount = None , colptr = colptr , colcount = colcount ,
164+ csr2csc = None , csc2csr = None , is_sorted = False )
165+ return tensors [0 ].from_storage ( storage )
142166
143167
144168def cat_diag (tensors : List [SparseTensor ]) -> SparseTensor :
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
163187
164188 rowptr = tensor .storage ._rowptr
165189 if rowptr is not None :
166- if len (rowptrs ) > 0 :
167- rowptr = rowptr [1 :]
168- rowptrs .append (rowptr + nnz )
190+ rowptrs .append (rowptr [1 :] + nnz if len (rowptrs ) > 0 else rowptr )
169191
170192 cols .append (tensor .storage ._col + sparse_sizes [1 ])
171193
@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
179201
180202 colptr = tensor .storage ._colptr
181203 if colptr is not None :
182- if len (colptrs ) > 0 :
183- colptr = colptr [1 :]
184- colptrs .append (colptr + nnz )
204+ colptrs .append (colptr [1 :] + nnz if len (colptrs ) > 0 else colptr )
185205
186206 colcount = tensor .storage ._colcount
187207 if colcount is not None :
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
234254 csc2csr = torch .cat (csc2csrs , dim = 0 )
235255
236256 storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
237- sparse_sizes = sparse_sizes , rowcount = rowcount ,
238- colptr = colptr , colcount = colcount , csr2csc = csr2csc ,
257+ sparse_sizes = (sparse_sizes [0 ], sparse_sizes [1 ]),
258+ rowcount = rowcount , colptr = colptr ,
259+ colcount = colcount , csr2csc = csr2csc ,
239260 csc2csr = csc2csr , is_sorted = True )
240261 return tensors [0 ].from_storage (storage )
0 commit comments