Commit 3c415d2

torch.device / torch.dtype args
1 parent 698be79 commit 3c415d2

File tree: 4 files changed, +77 / -72 lines

  test/test_eye.py
  torch_sparse/diag.py
  torch_sparse/storage.py
  torch_sparse/tensor.py
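
In short: constructors, fill helpers, and converters now take explicit dtype / device keyword arguments instead of a dummy `options` tensor that only carried those attributes. A minimal before/after sketch of the call-site change (names follow the diffs below):

import torch
from torch_sparse import SparseTensor

# Before this commit: dtype/device rode along on a throwaway "options" tensor.
#   options = torch.tensor(0, dtype=torch.float, device='cpu')
#   mat = SparseTensor.eye(3, options=options)

# After this commit: pass them directly.
mat = SparseTensor.eye(3, dtype=torch.float, device=torch.device('cpu'))
assert mat.storage.value().dtype == torch.float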

test/test_eye.py

Lines changed: 9 additions & 6 deletions

@@ -9,25 +9,27 @@

 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_eye(dtype, device):
-    options = torch.tensor(0, dtype=dtype, device=device)
-
-    mat = SparseTensor.eye(3, options=options)
+    mat = SparseTensor.eye(3, dtype=dtype, device=device)
+    assert mat.storage.col().device == device
     assert mat.storage.sparse_sizes() == (3, 3)
     assert mat.storage.row().tolist() == [0, 1, 2]
     assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
     assert mat.storage.col().tolist() == [0, 1, 2]
     assert mat.storage.value().tolist() == [1, 1, 1]
+    assert mat.storage.value().dtype == dtype
     assert mat.storage.num_cached_keys() == 0

-    mat = SparseTensor.eye(3, options=options, has_value=False)
+    mat = SparseTensor.eye(3, has_value=False)
+    assert mat.storage.col().device == device
     assert mat.storage.sparse_sizes() == (3, 3)
     assert mat.storage.row().tolist() == [0, 1, 2]
     assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
     assert mat.storage.col().tolist() == [0, 1, 2]
     assert mat.storage.value() is None
     assert mat.storage.num_cached_keys() == 0

-    mat = SparseTensor.eye(3, 4, options=options, fill_cache=True)
+    mat = SparseTensor.eye(3, 4, fill_cache=True)
+    assert mat.storage.col().device == device
     assert mat.storage.sparse_sizes() == (3, 4)
     assert mat.storage.row().tolist() == [0, 1, 2]
     assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]

@@ -39,7 +41,8 @@ def test_eye(dtype, device):
     assert mat.storage.csr2csc().tolist() == [0, 1, 2]
     assert mat.storage.csc2csr().tolist() == [0, 1, 2]

-    mat = SparseTensor.eye(4, 3, options=options, fill_cache=True)
+    mat = SparseTensor.eye(4, 3, fill_cache=True)
+    assert mat.storage.col().device == device
     assert mat.storage.sparse_sizes() == (4, 3)
     assert mat.storage.row().tolist() == [0, 1, 2]
     assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
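
For context, the test relies on module-level dtypes and devices grids combined via itertools.product; a self-contained sketch of that pattern (the grids here are placeholders, not the ones the test suite actually defines):

from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor

dtypes = [torch.float, torch.double]    # placeholder grid
devices = [torch.device('cpu')]         # placeholder grid


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye_dtype_device(dtype, device):
    mat = SparseTensor.eye(3, dtype=dtype, device=device)
    assert mat.storage.value().dtype == dtype
    assert mat.storage.col().device == device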

torch_sparse/diag.py

Lines changed: 2 additions & 1 deletion

@@ -78,7 +78,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
     return src.from_storage(storage)


-def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
+def fill_diag(src: SparseTensor, fill_value: float,
+              k: int = 0) -> SparseTensor:
     num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
     if k < 0:
         num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))
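
The only functional change here is widening the fill_value annotation from int to float (the signature is reflowed to fit the line limit). A hedged usage sketch; SparseTensor.from_dense and the standalone fill_diag import are assumed from the rest of the library:

import torch
from torch_sparse import SparseTensor
from torch_sparse.diag import fill_diag

src = SparseTensor.from_dense(torch.ones(3, 3))
out = fill_diag(src, 0.5)  # fractional fill values now match the annotation
assert out.to_dense()[0, 0] == 0.5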

torch_sparse/storage.py

Lines changed: 40 additions & 1 deletion

@@ -459,13 +459,14 @@ def clone(self):
         csc2csr = self._csc2csr
         if csc2csr is not None:
             csc2csr = csc2csr.clone()
+
         return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
                              csc2csr=csc2csr, is_sorted=True)

-    def type_as(self, tensor=torch.Tensor):
+    def type_as(self, tensor: torch.Tensor):
         value = self._value
         if value is not None:
             if tensor.dtype == value.dtype:

@@ -504,12 +505,49 @@ def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
         csc2csr = self._csc2csr
         if csc2csr is not None:
             csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
+
         return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
                              colcount=colcount, csr2csc=csr2csc,
                              csc2csr=csc2csr, is_sorted=True)

+    def cuda(self):
+        new_col = self._col.cuda()
+        if new_col.device == self._col.device:
+            return self
+
+        row = self._row
+        if row is not None:
+            row = row.cuda()
+        rowptr = self._rowptr
+        if rowptr is not None:
+            rowptr = rowptr.cuda()
+        value = self._value
+        if value is not None:
+            value = value.cuda()
+        rowcount = self._rowcount
+        if rowcount is not None:
+            rowcount = rowcount.cuda()
+        colptr = self._colptr
+        if colptr is not None:
+            colptr = colptr.cuda()
+        colcount = self._colcount
+        if colcount is not None:
+            colcount = colcount.cuda()
+        csr2csc = self._csr2csc
+        if csr2csc is not None:
+            csr2csc = csr2csc.cuda()
+        csc2csr = self._csc2csr
+        if csc2csr is not None:
+            csc2csr = csc2csr.cuda()
+
+        return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value,
+                             sparse_sizes=self._sparse_sizes,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=csr2csc,
+                             csc2csr=csc2csr, is_sorted=True)
+
     def pin_memory(self):
         row = self._row
         if row is not None:

@@ -536,6 +574,7 @@ def pin_memory(self):
         csc2csr = self._csc2csr
         if csc2csr is not None:
             csc2csr = csc2csr.pin_memory()
+
         return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                              sparse_sizes=self._sparse_sizes,
                              rowcount=rowcount, colptr=colptr,
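
The new SparseStorage.cuda() moves col first and returns self unchanged when the device did not change, then migrates every other cached tensor; SparseTensor.cuda() (see tensor.py below) simply wraps it. A small sketch, guarded so it is a no-op on CPU-only machines:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.eye(3)
if torch.cuda.is_available():
    mat_cuda = mat.cuda()  # delegates to SparseStorage.cuda()
    assert mat_cuda.is_cuda()
    assert mat_cuda.storage.value().is_cuda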

torch_sparse/tensor.py

Lines changed: 26 additions & 64 deletions

@@ -73,29 +73,22 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
                             is_sorted=True)

     @classmethod
-    def eye(self, M: int, N: Optional[int] = None,
-            options: Optional[torch.Tensor] = None, has_value: bool = True,
+    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
+            dtype: Optional[int] = None, device: Optional[torch.device] = None,
             fill_cache: bool = False):

         N = M if N is None else N

-        if options is not None:
-            row = torch.arange(min(M, N), device=options.device)
-        else:
-            row = torch.arange(min(M, N))
+        row = torch.arange(min(M, N), device=device)
         col = row

-        rowptr = torch.arange(M + 1, dtype=torch.long, device=row.device)
+        rowptr = torch.arange(M + 1, device=row.device)
         if M > N:
             rowptr[N + 1:] = N

         value: Optional[torch.Tensor] = None
         if has_value:
-            if options is not None:
-                value = torch.ones(row.numel(), dtype=options.dtype,
-                                   device=row.device)
-            else:
-                value = torch.ones(row.numel(), device=row.device)
+            value = torch.ones(row.numel(), dtype=dtype, device=row.device)

         rowcount: Optional[torch.Tensor] = None
         colptr: Optional[torch.Tensor] = None
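
The M > N branch above clamps the trailing rowptr entries to N, so extra rows are empty; a small sketch matching the expectations in test/test_eye.py:

from torch_sparse import SparseTensor

mat = SparseTensor.eye(4, 3)  # more rows than columns
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]  # row 3 is empty
assert mat.storage.col().tolist() == [0, 1, 2]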

@@ -131,7 +124,7 @@ def clone(self):
         return self.from_storage(self.storage.clone())

     def type_as(self, tensor=torch.Tensor):
-        value = self.storage._value
+        value = self.storage.value()
         if value is None or tensor.dtype == value.dtype:
             return self
         return self.from_storage(self.storage.type_as(tensor))

@@ -199,24 +192,14 @@ def clear_cache_(self):

     # Utility functions #######################################################

-    def fill_value_(self, fill_value: float,
-                    options: Optional[torch.Tensor] = None):
-        if options is not None:
-            value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
-                               device=self.device())
-        else:
-            value = torch.full((self.nnz(), ), fill_value,
-                               device=self.device())
+    def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
+        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
+                           device=self.device())
         return self.set_value_(value, layout='coo')

-    def fill_value(self, fill_value: float,
-                   options: Optional[torch.Tensor] = None):
-        if options is not None:
-            value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
-                               device=self.device())
-        else:
-            value = torch.full((self.nnz(), ), fill_value,
-                               device=self.device())
+    def fill_value(self, fill_value: float, dtype: Optional[int] = None):
+        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
+                           device=self.device())
         return self.set_value(value, layout='coo')

     def sizes(self) -> List[int]:
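
fill_value_ / fill_value now accept the target dtype directly; a hedged sketch of the new call:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.eye(3, has_value=False)
out = mat.fill_value(2., dtype=torch.double)  # was fill_value(2., options=...)
assert out.storage.value().tolist() == [2., 2., 2.]
assert out.storage.value().dtype == torch.double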

@@ -320,9 +303,9 @@ def requires_grad(self) -> bool:
         return False

     def requires_grad_(self, requires_grad: bool = True,
-                       options: Optional[torch.Tensor] = None):
+                       dtype: Optional[int] = None):
         if requires_grad and not self.has_value():
-            self.fill_value_(1., options=options)
+            self.fill_value_(1., dtype)

         value = self.storage.value()
         if value is not None:

@@ -335,36 +318,25 @@ def pin_memory(self):
     def is_pinned(self) -> bool:
         return self.storage.is_pinned()

-    def options(self) -> torch.Tensor:
-        value = self.storage.value()
-        if value is not None:
-            return value
-        else:
-            return torch.tensor(0., dtype=torch.float,
-                                device=self.storage.col().device)
-
     def device(self):
         return self.storage.col().device

     def cpu(self):
-        return self.device_as(torch.tensor(0.), non_blocking=False)
+        return self.device_as(torch.tensor(0), non_blocking=False)

-    def cuda(self, options: Optional[torch.Tensor] = None,
-             non_blocking: bool = False):
-        if options is not None:
-            return self.device_as(options, non_blocking)
-        else:
-            options = torch.tensor(0.).cuda()
-            return self.device_as(options, non_blocking)
+    def cuda(self):
+        return self.from_storage(self.storage.cuda())

     def is_cuda(self) -> bool:
         return self.storage.col().is_cuda

     def dtype(self):
-        return self.options().dtype
+        value = self.storage.value()
+        return value.dtype if value is not None else torch.float

     def is_floating_point(self) -> bool:
-        return torch.is_floating_point(self.options())
+        value = self.storage.value()
+        return torch.is_floating_point(value) if value is not None else True

     def bfloat16(self):
         return self.type_as(
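
With options() gone, dtype() and is_floating_point() read the stored value tensor and fall back to torch.float / True when no values are stored; a short sketch:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.eye(3, has_value=False)
assert mat.dtype() == torch.float  # fallback when no value tensor exists
assert mat.is_floating_point()

mat = SparseTensor.eye(3, dtype=torch.long)
assert mat.dtype() == torch.long
assert not mat.is_floating_point()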

@@ -408,17 +380,14 @@ def long(self):

     # Conversions #############################################################

-    def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
         row, col, value = self.coo()

         if value is not None:
             mat = torch.zeros(self.sizes(), dtype=value.dtype,
                               device=self.device())
-        elif options is not None:
-            mat = torch.zeros(self.sizes(), dtype=options.dtype,
-                              device=self.device())
         else:
-            mat = torch.zeros(self.sizes(), device=self.device())
+            mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

         if value is not None:
             mat[row, col] = value

@@ -428,25 +397,18 @@ def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor:

         return mat

-    def to_torch_sparse_coo_tensor(self,
-                                   options: Optional[torch.Tensor] = None):
+    def to_torch_sparse_coo_tensor(self, dtype: Optional[int] = None):
         row, col, value = self.coo()
         index = torch.stack([row, col], dim=0)
+
         if value is None:
-            if options is not None:
-                value = torch.ones(self.nnz(), dtype=options.dtype,
-                                   device=self.device())
-            else:
-                value = torch.ones(self.nnz(), device=self.device())
+            value = torch.ones(self.nnz(), dtype=dtype, device=self.device())

         return torch.sparse_coo_tensor(index, value, self.sizes())


 # Python Bindings #############################################################

-Dtype = Optional[torch.dtype]
-Device = Optional[Union[torch.device, str]]
-

 def share_memory_(self: SparseTensor) -> SparseTensor:
     self.storage.share_memory_()