Skip to content

Commit 8fb5428

Browse files
committed
fixes
1 parent 78d9af4 commit 8fb5428

File tree

2 files changed

+8
-26
lines changed

2 files changed

+8
-26
lines changed

torch_sparse/sample.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@ def sample(src: SparseTensor, num_neighbors: int,
2525
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
2626
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
2727

28-
rowptr, col, _ = src.csr()
28+
rowptr, col, value = src.csr()
2929
rowcount = src.storage.rowcount()
3030

3131
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
3232
rowptr, col, rowcount, subset, num_neighbors, replace)
3333

34-
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=e_id,
34+
if value is not None:
35+
value = value[e_id]
36+
37+
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=value,
3538
sparse_sizes=(subset.size(0), n_id.size(0)),
3639
is_sorted=True)
3740

torch_sparse/tensor.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def long(self):
409409

410410
# Conversions #############################################################
411411

412-
def to_dense(self, options: Optional[torch.Tensor] = None):
412+
def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor:
413413
row, col, value = self.coo()
414414

415415
if value is not None:
@@ -541,8 +541,8 @@ def __repr__(self: SparseTensor) -> str:
541541

542542
# Scipy Conversions ###########################################################
543543

544-
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
545-
scipy.sparse.csc_matrix]
544+
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
545+
csr_matrix, scipy.sparse.csc_matrix]
546546

547547

548548
@torch.jit.ignore
@@ -600,24 +600,3 @@ def to_scipy(self: SparseTensor, layout: Optional[str] = None,
600600

601601
SparseTensor.from_scipy = from_scipy
602602
SparseTensor.to_scipy = to_scipy
603-
604-
# Hacky fixes #################################################################
605-
606-
# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
607-
# https://github.com/pytorch/pytorch/pull/31769
608-
TORCH_MAJOR = int(torch.__version__.split('.')[0])
609-
TORCH_MINOR = int(torch.__version__.split('.')[1])
610-
if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR <= 3):
611-
612-
def add(self, other):
613-
if torch.is_tensor(other) or is_scalar(other):
614-
return self.add(other)
615-
return NotImplemented
616-
617-
def mul(self, other):
618-
if torch.is_tensor(other) or is_scalar(other):
619-
return self.mul(other)
620-
return NotImplemented
621-
622-
torch.Tensor.__add__ = add
623-
torch.Tensor.__mul__ = mul

0 commit comments

Comments
 (0)