Skip to content

Commit cc6f886

Browse files
committed
cuda bugfixes
1 parent ba26dfb commit cc6f886

File tree

5 files changed

+9
-8
lines changed

5 files changed

+9
-8
lines changed

test/forward.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
{
2727
"name": "sub",
2828
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
29-
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
29+
"input": [[5, 2], [2, 2], [4, 2], [1, 3]],
3030
"dim": 0,
3131
"fill_value": 9,
32-
"expected": [[3, 4], [3, 1]]
32+
"expected": [[3, 4], [3, 5]]
3333
},
3434
{
3535
"name": "mul",

test/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_backward_cpu(tensor, i):
3535
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
3636
def test_backward_gpu(tensor, i): # pragma: no cover
3737
name = data[i]['name']
38-
index = V(torch.LongTensor(data[i]['index']).cuda())
38+
index = V(torch.cuda.LongTensor(data[i]['index']))
3939
input = V(Tensor(tensor, data[i]['input']).cuda(), requires_grad=True)
4040
dim = data[i]['dim']
4141
fill_value = data[i]['fill_value']

test/test_forward.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_forward_cpu(tensor, i):
4444
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
4545
def test_forward_gpu(tensor, i): # pragma: no cover
4646
name = data[i]['name']
47-
index = torch.LongTensor(data[i]['index']).cuda()
47+
index = torch.cuda.LongTensor(data[i]['index'])
4848
input = Tensor(tensor, data[i]['input']).cuda()
4949
dim = data[i]['dim']
5050
fill_value = data[i]['fill_value']
@@ -57,7 +57,6 @@ def test_forward_gpu(tensor, i): # pragma: no cover
5757
if 'expected_arg' in data[i]:
5858
expected_arg = torch.LongTensor(data[i]['expected_arg'])
5959
assert result[1].cpu().tolist() == expected_arg.tolist()
60-
6160
func = getattr(torch_scatter, 'scatter_{}'.format(name))
6261
result = func(index, input, dim, fill_value=fill_value)
6362
if 'expected_arg' not in data[i]:

torch_scatter/functions/sub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def scatter_sub_(output, index, input, dim=0):
5151
-2 -4 -4 0 0 0
5252
[torch.FloatTensor of size 2x6]
5353
"""
54-
return output.scatter_add_(dim, index, -1 * input)
54+
return output.scatter_add_(dim, index, -input)
5555

5656

5757
def scatter_sub(index, input, dim=0, size=None, fill_value=0):

torch_scatter/kernel/kernel.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
6464
KERNEL_LOOP(i, n) {
6565
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
6666
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
67-
if (input.data[inputOffset] == output.data[outputOffset]) arg.data[argOffset] = inputOffset % input.size[dim];
67+
if (input.data[inputOffset] == output.data[outputOffset]) {
68+
arg.data[argOffset] = (inputOffset / input.stride[dim]) % input.size[dim];
69+
}
6870
}
6971
}
7072

@@ -73,7 +75,7 @@ __global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t>
7375
KERNEL_LOOP(i, n) {
7476
int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0;
7577
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, output, &outputOffset, grad, &gradOffset, arg, &argOffset);
76-
if (arg.data[argOffset] == outputOffset % output.size[dim]) output.data[outputOffset] = grad.data[gradOffset];
78+
if (arg.data[argOffset] == (outputOffset / output.stride[dim]) % output.size[dim]) output.data[outputOffset] = grad.data[gradOffset];
7779
}
7880
}
7981

0 commit comments

Comments
 (0)