Skip to content

Commit 7134d45

Browse files
committed
fix index_bug half bug
1 parent 56ec830 commit 7134d45

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

.github/workflows/building-conda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
python-version: ${{ matrix.python-version }}
3535

3636
- name: Free up disk space
37-
if: ${{ runner.os == 'Linux' && matrix.cuda-version == 'cu111' }}
37+
if: ${{ runner.os == 'Linux' }}
3838
run: |
3939
sudo rm -rf /usr/share/dotnet
4040

conda/pytorch-sparse/build_conda.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ echo "PyTorch $TORCH_VERSION+$CUDA_VERSION"
3030
echo "- $CONDA_PYTORCH_CONSTRAINT"
3131
echo "- $CONDA_CUDATOOLKIT_CONSTRAINT"
3232

33-
conda build . -c defaults -c nvidia -c pytorch -c conda-forge -c rusty1s --output-folder "$HOME/conda-bld"
33+
conda build . -c pytorch -c nvidia -c rusty1s -c defaults -c conda-forge --output-folder "$HOME/conda-bld"

csrc/spmm.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class SPMMSum : public torch::autograd::Function<SPMMSum> {
9797
if (torch::autograd::any_variable_requires_grad({mat})) {
9898
torch::optional<torch::Tensor> opt_value = torch::nullopt;
9999
if (has_value)
100-
opt_value = value.index_select(0, csr2csc);
100+
opt_value = value.view({-1, 1}).index_select(0, csr2csc).view(-1);
101101

102102
grad_mat = std::get<0>(spmm_fw(colptr, row.index_select(0, csr2csc),
103103
opt_value, grad_out, "sum"));
@@ -161,11 +161,12 @@ class SPMMMean : public torch::autograd::Function<SPMMMean> {
161161
auto grad_mat = Variable();
162162
if (torch::autograd::any_variable_requires_grad({mat})) {
163163
row = row.index_select(0, csr2csc);
164-
rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
164+
rowcount = rowcount.index_select(0, row).toType(mat.scalar_type());
165165
rowcount.masked_fill_(rowcount < 1, 1);
166166

167167
if (has_value > 0)
168-
rowcount = value.index_select(0, csr2csc).div(rowcount);
168+
rowcount =
169+
value.view({-1, 1}).index_select(0, csr2csc).view(-1).div(rowcount);
169170
else
170171
rowcount.pow_(-1);
171172

@@ -219,8 +220,10 @@ class SPMMMin : public torch::autograd::Function<SPMMMin> {
219220
auto grad_mat = Variable();
220221
if (torch::autograd::any_variable_requires_grad({mat})) {
221222
if (has_value > 0) {
222-
value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
223-
value.mul_(grad_out);
223+
value = value.view({-1, 1})
224+
.index_select(0, arg_out.flatten())
225+
.view_as(arg_out)
226+
.mul_(grad_out);
224227
} else
225228
value = grad_out;
226229

@@ -277,8 +280,10 @@ class SPMMMax : public torch::autograd::Function<SPMMMax> {
277280
auto grad_mat = Variable();
278281
if (torch::autograd::any_variable_requires_grad({mat})) {
279282
if (has_value > 0) {
280-
value = value.index_select(0, arg_out.flatten()).view_as(arg_out);
281-
value.mul_(grad_out);
283+
value = value.view({-1, 1})
284+
.index_select(0, arg_out.flatten())
285+
.view_as(arg_out)
286+
.mul_(grad_out);
282287
} else
283288
value = grad_out;
284289

0 commit comments

Comments
 (0)