@@ -163,17 +163,17 @@ from torch_sparse import spmm
163163
164164index = torch.tensor([[0 , 0 , 1 , 2 , 2 ],
165165 [0 , 2 , 1 , 0 , 1 ]])
166- value = torch.tensor([1 , 2 , 4 , 1 , 3 ])
167- matrix = torch.tensor([[1 , 4 ], [2 , 5 ], [3 , 6 ]])
166+ value = torch.tensor([1 , 2 , 4 , 1 , 3 ], dtype = torch.float )
167+ matrix = torch.tensor([[1 , 4 ], [2 , 5 ], [3 , 6 ]], dtype = torch.float )
168168
169169out = spmm(index, value, 3 , matrix)
170170```
171171
172172```
173173print(out)
174- tensor([[7, 16],
175- [8, 20],
176- [7, 19]])
174+ tensor([[7.0 , 16.0 ],
175+ [8.0 , 20.0 ],
176+ [7.0 , 19.0 ]])
177177```
178178
179179## Sparse Sparse Matrix Multiplication
@@ -206,10 +206,10 @@ Both input sparse matrices need to be **coalesced**.
206206from torch_sparse import spspmm
207207
208208indexA = torch.tensor([[0 , 0 , 1 , 2 , 2 ], [1 , 2 , 0 , 0 , 1 ]])
209- valueA = torch.tensor([1 , 2 , 3 , 4 , 5 ])
209+ valueA = torch.tensor([1 , 2 , 3 , 4 , 5 ], dtype = torch.float )
210210
211211indexB = torch.tensor([[0 , 2 ], [1 , 0 ]])
212- valueB = torch.tensor([2 , 4 ])
212+ valueB = torch.tensor([2 , 4 ], dtype = torch.float )
213213
214214indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3 , 3 , 2 )
215215```
@@ -219,7 +219,7 @@ print(index)
219219tensor([[0, 1, 2],
220220 [0, 1, 1]])
221221print(value)
222- tensor([8, 6, 8])
222+ tensor([8.0 , 6.0 , 8.0 ])
223223```
224224
225225## Running tests
0 commit comments