@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
122122 assert storage .row ().tolist () == [0 , 0 , 1 , 1 ]
123123 assert storage .col ().tolist () == [0 , 1 , 0 , 1 ]
124124 assert storage .value ().tolist () == [1 , 2 , 3 , 4 ]
125+
126+
127+ @pytest .mark .parametrize ('dtype,device' , product (dtypes , devices ))
128+ def test_sparse_reshape (dtype , device ):
129+ row , col = tensor ([[0 , 1 , 2 , 3 ], [0 , 1 , 2 , 3 ]], torch .long , device )
130+ storage = SparseStorage (row = row , col = col )
131+
132+ storage = storage .sparse_reshape (2 , 8 )
133+ assert storage .sparse_sizes () == (2 , 8 )
134+ assert storage .row ().tolist () == [0 , 0 , 1 , 1 ]
135+ assert storage .col ().tolist () == [0 , 5 , 2 , 7 ]
136+
137+ storage = storage .sparse_reshape (- 1 , 4 )
138+ assert storage .sparse_sizes () == (4 , 4 )
139+ assert storage .row ().tolist () == [0 , 1 , 2 , 3 ]
140+ assert storage .col ().tolist () == [0 , 1 , 2 , 3 ]
141+
142+ storage = storage .sparse_reshape (2 , - 1 )
143+ assert storage .sparse_sizes () == (2 , 8 )
144+ assert storage .row ().tolist () == [0 , 0 , 1 , 1 ]
145+ assert storage .col ().tolist () == [0 , 5 , 2 , 7 ]
0 commit comments