Skip to content

Commit fe8c3ce

Browse files
authored
Skip unnecessary assertions and enable non-blocking data transfers (#195)
* Uses the `trust_data` invariant to skip blocking assertions, when unnecessary, during construction of `SparseStorage` objects.
* Refactors the dtype and device transfer APIs to align with `torch.Tensor` while maintaining backward compatibility.
* No longer constructs dummy tensors when changing dtype or device.
1 parent 88c6ceb commit fe8c3ce

File tree

3 files changed

+148
-87
lines changed

3 files changed

+148
-87
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
WITH_METIS = True if os.getenv('WITH_METIS', '0') == '1' else False
2525
WITH_MTMETIS = True if os.getenv('WITH_MTMETIS', '0') == '1' else False
2626

27+
WITH_SYMBOLS = True if os.getenv('WITH_SYMBOLS', '0') == '1' else False
28+
2729

2830
def get_extensions():
2931
extensions = []
@@ -47,7 +49,7 @@ def get_extensions():
4749
extra_compile_args = {'cxx': ['-O2']}
4850
if not os.name == 'nt': # Not on Windows:
4951
extra_compile_args['cxx'] += ['-Wno-sign-compare']
50-
extra_link_args = ['-s']
52+
extra_link_args = [] if WITH_SYMBOLS else ['-s']
5153

5254
info = parallel_info()
5355
if ('backend: OpenMP' in info and 'OpenMP not found' not in info

torch_sparse/storage.py

Lines changed: 74 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
4141
colcount: Optional[torch.Tensor] = None,
4242
csr2csc: Optional[torch.Tensor] = None,
4343
csc2csr: Optional[torch.Tensor] = None,
44-
is_sorted: bool = False):
44+
is_sorted: bool = False,
45+
trust_data: bool = False):
4546

4647
assert row is not None or rowptr is not None
4748
assert col is not None
@@ -62,7 +63,7 @@ def __init__(self, row: Optional[torch.Tensor] = None,
6263
if rowptr is not None:
6364
assert rowptr.numel() - 1 == M
6465
elif row is not None and row.numel() > 0:
65-
assert int(row.max()) < M
66+
assert trust_data or int(row.max()) < M
6667

6768
N: int = 0
6869
if sparse_sizes is None or sparse_sizes[1] is None:
@@ -73,7 +74,7 @@ def __init__(self, row: Optional[torch.Tensor] = None,
7374
assert _N is not None
7475
N = _N
7576
if col.numel() > 0:
76-
assert int(col.max()) < N
77+
assert trust_data or int(col.max()) < N
7778

7879
sparse_sizes = (M, N)
7980

@@ -163,7 +164,7 @@ def empty(self):
163164
return SparseStorage(row=row, rowptr=None, col=col, value=None,
164165
sparse_sizes=(0, 0), rowcount=None, colptr=None,
165166
colcount=None, csr2csc=None, csc2csr=None,
166-
is_sorted=True)
167+
is_sorted=True, trust_data=True)
167168

168169
def has_row(self) -> bool:
169170
return self._row is not None
@@ -227,11 +228,19 @@ def set_value(self, value: Optional[torch.Tensor],
227228
assert value.device == self._col.device
228229
assert value.size(0) == self._col.numel()
229230

230-
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
231-
value=value, sparse_sizes=self._sparse_sizes,
232-
rowcount=self._rowcount, colptr=self._colptr,
233-
colcount=self._colcount, csr2csc=self._csr2csc,
234-
csc2csr=self._csc2csr, is_sorted=True)
231+
return SparseStorage(
232+
row=self._row,
233+
rowptr=self._rowptr,
234+
col=self._col,
235+
value=value,
236+
sparse_sizes=self._sparse_sizes,
237+
rowcount=self._rowcount,
238+
colptr=self._colptr,
239+
colcount=self._colcount,
240+
csr2csc=self._csr2csc,
241+
csc2csr=self._csc2csr,
242+
is_sorted=True,
243+
trust_data=True)
235244

236245
def sparse_sizes(self) -> Tuple[int, int]:
237246
return self._sparse_sizes
@@ -269,11 +278,19 @@ def sparse_resize(self, sparse_sizes: Tuple[int, int]):
269278
if colcount is not None:
270279
colcount = colcount[:-diff_1]
271280

272-
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
273-
value=self._value, sparse_sizes=sparse_sizes,
274-
rowcount=rowcount, colptr=colptr,
275-
colcount=colcount, csr2csc=self._csr2csc,
276-
csc2csr=self._csc2csr, is_sorted=True)
281+
return SparseStorage(
282+
row=self._row,
283+
rowptr=rowptr,
284+
col=self._col,
285+
value=self._value,
286+
sparse_sizes=sparse_sizes,
287+
rowcount=rowcount,
288+
colptr=colptr,
289+
colcount=colcount,
290+
csr2csc=self._csr2csc,
291+
csc2csr=self._csc2csr,
292+
is_sorted=True,
293+
trust_data=True)
277294

278295
def sparse_reshape(self, num_rows: int, num_cols: int):
279296
assert num_rows > 0 or num_rows == -1
@@ -299,7 +316,7 @@ def sparse_reshape(self, num_rows: int, num_cols: int):
299316
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
300317
sparse_sizes=(num_rows, num_cols), rowcount=None,
301318
colptr=None, colcount=None, csr2csc=None,
302-
csc2csr=None, is_sorted=True)
319+
csc2csr=None, is_sorted=True, trust_data=True)
303320

304321
def has_rowcount(self) -> bool:
305322
return self._rowcount is not None
@@ -399,7 +416,7 @@ def coalesce(self, reduce: str = "add"):
399416
return SparseStorage(row=row, rowptr=None, col=col, value=value,
400417
sparse_sizes=self._sparse_sizes, rowcount=None,
401418
colptr=None, colcount=None, csr2csc=None,
402-
csc2csr=None, is_sorted=True)
419+
csc2csr=None, is_sorted=True, trust_data=True)
403420

404421
def fill_cache_(self):
405422
self.row()
@@ -437,12 +454,19 @@ def num_cached_keys(self) -> int:
437454
return len(self.cached_keys())
438455

439456
def copy(self):
440-
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
441-
value=self._value,
442-
sparse_sizes=self._sparse_sizes,
443-
rowcount=self._rowcount, colptr=self._colptr,
444-
colcount=self._colcount, csr2csc=self._csr2csc,
445-
csc2csr=self._csc2csr, is_sorted=True)
457+
return SparseStorage(
458+
row=self._row,
459+
rowptr=self._rowptr,
460+
col=self._col,
461+
value=self._value,
462+
sparse_sizes=self._sparse_sizes,
463+
rowcount=self._rowcount,
464+
colptr=self._colptr,
465+
colcount=self._colcount,
466+
csr2csc=self._csr2csc,
467+
csc2csr=self._csc2csr,
468+
is_sorted=True,
469+
trust_data=True)
446470

447471
def clone(self):
448472
row = self._row
@@ -475,53 +499,63 @@ def clone(self):
475499
sparse_sizes=self._sparse_sizes,
476500
rowcount=rowcount, colptr=colptr,
477501
colcount=colcount, csr2csc=csr2csc,
478-
csc2csr=csc2csr, is_sorted=True)
502+
csc2csr=csc2csr, is_sorted=True, trust_data=True)
479503

480-
def type_as(self, tensor: torch.Tensor):
504+
def type(self, dtype: torch.dtype, non_blocking: bool = False):
481505
value = self._value
482506
if value is not None:
483-
if tensor.dtype == value.dtype:
507+
if dtype == value.dtype:
484508
return self
485509
else:
486-
return self.set_value(value.type_as(tensor), layout='coo')
510+
return self.set_value(
511+
value.to(
512+
dtype=dtype,
513+
non_blocking=non_blocking),
514+
layout='coo')
487515
else:
488516
return self
489517

490-
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
491-
if tensor.device == self._col.device:
518+
def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
519+
return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
520+
521+
def to_device(self, device: torch.device, non_blocking: bool = False):
522+
if device == self._col.device:
492523
return self
493524

494525
row = self._row
495526
if row is not None:
496-
row = row.to(tensor.device, non_blocking=non_blocking)
527+
row = row.to(device, non_blocking=non_blocking)
497528
rowptr = self._rowptr
498529
if rowptr is not None:
499-
rowptr = rowptr.to(tensor.device, non_blocking=non_blocking)
500-
col = self._col.to(tensor.device, non_blocking=non_blocking)
530+
rowptr = rowptr.to(device, non_blocking=non_blocking)
531+
col = self._col.to(device, non_blocking=non_blocking)
501532
value = self._value
502533
if value is not None:
503-
value = value.to(tensor.device, non_blocking=non_blocking)
534+
value = value.to(device, non_blocking=non_blocking)
504535
rowcount = self._rowcount
505536
if rowcount is not None:
506-
rowcount = rowcount.to(tensor.device, non_blocking=non_blocking)
537+
rowcount = rowcount.to(device, non_blocking=non_blocking)
507538
colptr = self._colptr
508539
if colptr is not None:
509-
colptr = colptr.to(tensor.device, non_blocking=non_blocking)
540+
colptr = colptr.to(device, non_blocking=non_blocking)
510541
colcount = self._colcount
511542
if colcount is not None:
512-
colcount = colcount.to(tensor.device, non_blocking=non_blocking)
543+
colcount = colcount.to(device, non_blocking=non_blocking)
513544
csr2csc = self._csr2csc
514545
if csr2csc is not None:
515-
csr2csc = csr2csc.to(tensor.device, non_blocking=non_blocking)
546+
csr2csc = csr2csc.to(device, non_blocking=non_blocking)
516547
csc2csr = self._csc2csr
517548
if csc2csr is not None:
518-
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
549+
csc2csr = csc2csr.to(device, non_blocking=non_blocking)
519550

520551
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
521552
sparse_sizes=self._sparse_sizes,
522553
rowcount=rowcount, colptr=colptr,
523554
colcount=colcount, csr2csc=csr2csc,
524-
csc2csr=csc2csr, is_sorted=True)
555+
csc2csr=csc2csr, is_sorted=True, trust_data=True)
556+
557+
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
558+
return self.to_device(device=tensor.device, non_blocking=non_blocking)
525559

526560
def cuda(self):
527561
new_col = self._col.cuda()
@@ -557,7 +591,7 @@ def cuda(self):
557591
sparse_sizes=self._sparse_sizes,
558592
rowcount=rowcount, colptr=colptr,
559593
colcount=colcount, csr2csc=csr2csc,
560-
csc2csr=csc2csr, is_sorted=True)
594+
csc2csr=csc2csr, is_sorted=True, trust_data=True)
561595

562596
def pin_memory(self):
563597
row = self._row
@@ -590,7 +624,7 @@ def pin_memory(self):
590624
sparse_sizes=self._sparse_sizes,
591625
rowcount=rowcount, colptr=colptr,
592626
colcount=colcount, csr2csc=csr2csc,
593-
csc2csr=csc2csr, is_sorted=True)
627+
csc2csr=csc2csr, is_sorted=True, trust_data=True)
594628

595629
def is_pinned(self) -> bool:
596630
is_pinned = True

0 commit comments

Comments (0)