Skip to content

Commit 1a4bdd3

Browse files
committed
sparse matrix ops request floating point data types
1 parent 9fb0794 commit 1a4bdd3

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,17 +163,17 @@ from torch_sparse import spmm
163163

164164
index = torch.tensor([[0, 0, 1, 2, 2],
165165
[0, 2, 1, 0, 1]])
166-
value = torch.tensor([1, 2, 4, 1, 3])
167-
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
166+
value = torch.tensor([1, 2, 4, 1, 3], dtype=torch.float)
167+
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float)
168168

169169
out = spmm(index, value, 3, matrix)
170170
```
171171

172172
```
173173
print(out)
174-
tensor([[7, 16],
175-
[8, 20],
176-
[7, 19]])
174+
tensor([[7.0, 16.0],
175+
[8.0, 20.0],
176+
[7.0, 19.0]])
177177
```
178178

179179
## Sparse Sparse Matrix Multiplication
@@ -206,10 +206,10 @@ Both input sparse matrices need to be **coalesced**.
206206
from torch_sparse import spspmm
207207

208208
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
209-
valueA = torch.tensor([1, 2, 3, 4, 5])
209+
valueA = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
210210

211211
indexB = torch.tensor([[0, 2], [1, 0]])
212-
valueB = torch.tensor([2, 4])
212+
valueB = torch.tensor([2, 4], dtype=torch.float)
213213

214214
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
215215
```
@@ -219,7 +219,7 @@ print(index)
219219
tensor([[0, 1, 2],
220220
[0, 1, 1]])
221221
print(value)
222-
tensor([8, 6, 8])
222+
tensor([8.0, 6.0, 8.0])
223223
```
224224

225225
## Running tests

0 commit comments

Comments
 (0)