11from typing import Optional
22
33import torch
4+ from torch import Tensor
45from torch_scatter import gather_csr
6+
57from torch_sparse .tensor import SparseTensor
68
79
8- def mul (src : SparseTensor , other : torch .Tensor ) -> SparseTensor :
9- rowptr , col , value = src .csr ()
10- if other .size (0 ) == src .size (0 ) and other .size (1 ) == 1 : # Row-wise...
11- other = gather_csr (other .squeeze (1 ), rowptr )
12- pass
13- elif other .size (0 ) == 1 and other .size (1 ) == src .size (1 ): # Col-wise...
14- other = other .squeeze (0 )[col ]
15- else :
16- raise ValueError (
17- f'Size mismatch: Expected size ({ src .size (0 )} , 1, ...) or '
18- f'(1, { src .size (1 )} , ...), but got size { other .size ()} .' )
10+ @torch .jit ._overload # noqa: F811
11+ def mul (src , other ): # noqa: F811
12+ # type: (SparseTensor, Tensor) -> SparseTensor
13+ pass
1914
20- if value is not None :
21- value = other .to (value .dtype ).mul_ (value )
15+
16+ @torch .jit ._overload # noqa: F811
17+ def mul (src , other ): # noqa: F811
18+ # type: (SparseTensor, SparseTensor) -> SparseTensor
19+ pass
20+
21+
22+ def mul (src , other ): # noqa: F811
23+ if isinstance (other , Tensor ):
24+ rowptr , col , value = src .csr ()
25+ if other .size (0 ) == src .size (0 ) and other .size (1 ) == 1 : # Row-wise...
26+ other = gather_csr (other .squeeze (1 ), rowptr )
27+ pass
28+ # Col-wise...
29+ elif other .size (0 ) == 1 and other .size (1 ) == src .size (1 ):
30+ other = other .squeeze (0 )[col ]
31+ else :
32+ raise ValueError (
33+ f'Size mismatch: Expected size ({ src .size (0 )} , 1, ...) or '
34+ f'(1, { src .size (1 )} , ...), but got size { other .size ()} .' )
35+
36+ if value is not None :
37+ value = other .to (value .dtype ).mul_ (value )
38+ else :
39+ value = other
40+ return src .set_value (value , layout = 'coo' )
41+
42+ assert isinstance (other , SparseTensor )
43+
44+ if not src .is_coalesced ():
45+ raise ValueError ("The `src` tensor is not coalesced" )
46+ if not other .is_coalesced ():
47+ raise ValueError ("The `other` tensor is not coalesced" )
48+
49+ rowA , colA , valueA = src .coo ()
50+ rowB , colB , valueB = other .coo ()
51+
52+ row = torch .cat ([rowA , rowB ], dim = 0 )
53+ col = torch .cat ([colA , colB ], dim = 0 )
54+
55+ if valueA is not None and valueB is not None :
56+ value = torch .cat ([valueA , valueB ], dim = 0 )
2257 else :
23- value = other
24- return src .set_value (value , layout = 'coo' )
58+ raise ValueError ('Both sparse tensors must contain values' )
59+
60+ M = max (src .size (0 ), other .size (0 ))
61+ N = max (src .size (1 ), other .size (1 ))
62+ sparse_sizes = (M , N )
63+
64+ # Sort indices:
65+ idx = col .new_full ((col .numel () + 1 , ), - 1 )
66+ idx [1 :] = row * sparse_sizes [1 ] + col
67+ perm = idx [1 :].argsort ()
68+ idx [1 :] = idx [1 :][perm ]
69+
70+ row , col , value = row [perm ], col [perm ], value [perm ]
71+
72+ valid_mask = idx [1 :] == idx [:- 1 ]
73+ valid_idx = valid_mask .nonzero ().view (- 1 )
74+
75+ return SparseTensor (
76+ row = row [valid_mask ],
77+ col = col [valid_mask ],
78+ value = value [valid_idx - 1 ] * value [valid_idx ],
79+ sparse_sizes = sparse_sizes ,
80+ )
2581
2682
2783def mul_ (src : SparseTensor , other : torch .Tensor ) -> SparseTensor :
@@ -43,8 +99,11 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
4399 return src .set_value_ (value , layout = 'coo' )
44100
45101
46- def mul_nnz (src : SparseTensor , other : torch .Tensor ,
47- layout : Optional [str ] = None ) -> SparseTensor :
102+ def mul_nnz (
103+ src : SparseTensor ,
104+ other : torch .Tensor ,
105+ layout : Optional [str ] = None ,
106+ ) -> SparseTensor :
48107 value = src .storage .value ()
49108 if value is not None :
50109 value = value .mul (other .to (value .dtype ))
@@ -53,8 +112,11 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
53112 return src .set_value (value , layout = layout )
54113
55114
56- def mul_nnz_ (src : SparseTensor , other : torch .Tensor ,
57- layout : Optional [str ] = None ) -> SparseTensor :
115+ def mul_nnz_ (
116+ src : SparseTensor ,
117+ other : torch .Tensor ,
118+ layout : Optional [str ] = None ,
119+ ) -> SparseTensor :
58120 value = src .storage .value ()
59121 if value is not None :
60122 value = value .mul_ (other .to (value .dtype ))
0 commit comments