@@ -41,7 +41,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
4141 colcount : Optional [torch .Tensor ] = None ,
4242 csr2csc : Optional [torch .Tensor ] = None ,
4343 csc2csr : Optional [torch .Tensor ] = None ,
44- is_sorted : bool = False ):
44+ is_sorted : bool = False ,
45+ trust_data : bool = False ):
4546
4647 assert row is not None or rowptr is not None
4748 assert col is not None
@@ -62,7 +63,7 @@ def __init__(self, row: Optional[torch.Tensor] = None,
6263 if rowptr is not None :
6364 assert rowptr .numel () - 1 == M
6465 elif row is not None and row .numel () > 0 :
65- assert int (row .max ()) < M
66+ assert trust_data or int (row .max ()) < M
6667
6768 N : int = 0
6869 if sparse_sizes is None or sparse_sizes [1 ] is None :
@@ -73,7 +74,7 @@ def __init__(self, row: Optional[torch.Tensor] = None,
7374 assert _N is not None
7475 N = _N
7576 if col .numel () > 0 :
76- assert int (col .max ()) < N
77+ assert trust_data or int (col .max ()) < N
7778
7879 sparse_sizes = (M , N )
7980
@@ -163,7 +164,7 @@ def empty(self):
163164 return SparseStorage (row = row , rowptr = None , col = col , value = None ,
164165 sparse_sizes = (0 , 0 ), rowcount = None , colptr = None ,
165166 colcount = None , csr2csc = None , csc2csr = None ,
166- is_sorted = True )
167+ is_sorted = True , trust_data = True )
167168
168169 def has_row (self ) -> bool :
169170 return self ._row is not None
@@ -227,11 +228,19 @@ def set_value(self, value: Optional[torch.Tensor],
227228 assert value .device == self ._col .device
228229 assert value .size (0 ) == self ._col .numel ()
229230
230- return SparseStorage (row = self ._row , rowptr = self ._rowptr , col = self ._col ,
231- value = value , sparse_sizes = self ._sparse_sizes ,
232- rowcount = self ._rowcount , colptr = self ._colptr ,
233- colcount = self ._colcount , csr2csc = self ._csr2csc ,
234- csc2csr = self ._csc2csr , is_sorted = True )
231+ return SparseStorage (
232+ row = self ._row ,
233+ rowptr = self ._rowptr ,
234+ col = self ._col ,
235+ value = value ,
236+ sparse_sizes = self ._sparse_sizes ,
237+ rowcount = self ._rowcount ,
238+ colptr = self ._colptr ,
239+ colcount = self ._colcount ,
240+ csr2csc = self ._csr2csc ,
241+ csc2csr = self ._csc2csr ,
242+ is_sorted = True ,
243+ trust_data = True )
235244
236245 def sparse_sizes (self ) -> Tuple [int , int ]:
237246 return self ._sparse_sizes
@@ -269,11 +278,19 @@ def sparse_resize(self, sparse_sizes: Tuple[int, int]):
269278 if colcount is not None :
270279 colcount = colcount [:- diff_1 ]
271280
272- return SparseStorage (row = self ._row , rowptr = rowptr , col = self ._col ,
273- value = self ._value , sparse_sizes = sparse_sizes ,
274- rowcount = rowcount , colptr = colptr ,
275- colcount = colcount , csr2csc = self ._csr2csc ,
276- csc2csr = self ._csc2csr , is_sorted = True )
281+ return SparseStorage (
282+ row = self ._row ,
283+ rowptr = rowptr ,
284+ col = self ._col ,
285+ value = self ._value ,
286+ sparse_sizes = sparse_sizes ,
287+ rowcount = rowcount ,
288+ colptr = colptr ,
289+ colcount = colcount ,
290+ csr2csc = self ._csr2csc ,
291+ csc2csr = self ._csc2csr ,
292+ is_sorted = True ,
293+ trust_data = True )
277294
278295 def sparse_reshape (self , num_rows : int , num_cols : int ):
279296 assert num_rows > 0 or num_rows == - 1
@@ -299,7 +316,7 @@ def sparse_reshape(self, num_rows: int, num_cols: int):
299316 return SparseStorage (row = row , rowptr = None , col = col , value = self ._value ,
300317 sparse_sizes = (num_rows , num_cols ), rowcount = None ,
301318 colptr = None , colcount = None , csr2csc = None ,
302- csc2csr = None , is_sorted = True )
319+ csc2csr = None , is_sorted = True , trust_data = True )
303320
304321 def has_rowcount (self ) -> bool :
305322 return self ._rowcount is not None
@@ -399,7 +416,7 @@ def coalesce(self, reduce: str = "add"):
399416 return SparseStorage (row = row , rowptr = None , col = col , value = value ,
400417 sparse_sizes = self ._sparse_sizes , rowcount = None ,
401418 colptr = None , colcount = None , csr2csc = None ,
402- csc2csr = None , is_sorted = True )
419+ csc2csr = None , is_sorted = True , trust_data = True )
403420
404421 def fill_cache_ (self ):
405422 self .row ()
@@ -437,12 +454,19 @@ def num_cached_keys(self) -> int:
437454 return len (self .cached_keys ())
438455
439456 def copy (self ):
440- return SparseStorage (row = self ._row , rowptr = self ._rowptr , col = self ._col ,
441- value = self ._value ,
442- sparse_sizes = self ._sparse_sizes ,
443- rowcount = self ._rowcount , colptr = self ._colptr ,
444- colcount = self ._colcount , csr2csc = self ._csr2csc ,
445- csc2csr = self ._csc2csr , is_sorted = True )
457+ return SparseStorage (
458+ row = self ._row ,
459+ rowptr = self ._rowptr ,
460+ col = self ._col ,
461+ value = self ._value ,
462+ sparse_sizes = self ._sparse_sizes ,
463+ rowcount = self ._rowcount ,
464+ colptr = self ._colptr ,
465+ colcount = self ._colcount ,
466+ csr2csc = self ._csr2csc ,
467+ csc2csr = self ._csc2csr ,
468+ is_sorted = True ,
469+ trust_data = True )
446470
447471 def clone (self ):
448472 row = self ._row
@@ -475,53 +499,63 @@ def clone(self):
475499 sparse_sizes = self ._sparse_sizes ,
476500 rowcount = rowcount , colptr = colptr ,
477501 colcount = colcount , csr2csc = csr2csc ,
478- csc2csr = csc2csr , is_sorted = True )
502+ csc2csr = csc2csr , is_sorted = True , trust_data = True )
479503
480- def type_as (self , tensor : torch .Tensor ):
504+ def type (self , dtype : torch .dtype , non_blocking : bool = False ):
481505 value = self ._value
482506 if value is not None :
483- if tensor . dtype == value .dtype :
507+ if dtype == value .dtype :
484508 return self
485509 else :
486- return self .set_value (value .type_as (tensor ), layout = 'coo' )
510+ return self .set_value (
511+ value .to (
512+ dtype = dtype ,
513+ non_blocking = non_blocking ),
514+ layout = 'coo' )
487515 else :
488516 return self
489517
490- def device_as (self , tensor : torch .Tensor , non_blocking : bool = False ):
491- if tensor .device == self ._col .device :
518+ def type_as (self , tensor : torch .Tensor , non_blocking : bool = False ):
519+ return self .type (dtype = tensor .dtype , non_blocking = non_blocking )
520+
521+ def to_device (self , device : torch .device , non_blocking : bool = False ):
522+ if device == self ._col .device :
492523 return self
493524
494525 row = self ._row
495526 if row is not None :
496- row = row .to (tensor . device , non_blocking = non_blocking )
527+ row = row .to (device , non_blocking = non_blocking )
497528 rowptr = self ._rowptr
498529 if rowptr is not None :
499- rowptr = rowptr .to (tensor . device , non_blocking = non_blocking )
500- col = self ._col .to (tensor . device , non_blocking = non_blocking )
530+ rowptr = rowptr .to (device , non_blocking = non_blocking )
531+ col = self ._col .to (device , non_blocking = non_blocking )
501532 value = self ._value
502533 if value is not None :
503- value = value .to (tensor . device , non_blocking = non_blocking )
534+ value = value .to (device , non_blocking = non_blocking )
504535 rowcount = self ._rowcount
505536 if rowcount is not None :
506- rowcount = rowcount .to (tensor . device , non_blocking = non_blocking )
537+ rowcount = rowcount .to (device , non_blocking = non_blocking )
507538 colptr = self ._colptr
508539 if colptr is not None :
509- colptr = colptr .to (tensor . device , non_blocking = non_blocking )
540+ colptr = colptr .to (device , non_blocking = non_blocking )
510541 colcount = self ._colcount
511542 if colcount is not None :
512- colcount = colcount .to (tensor . device , non_blocking = non_blocking )
543+ colcount = colcount .to (device , non_blocking = non_blocking )
513544 csr2csc = self ._csr2csc
514545 if csr2csc is not None :
515- csr2csc = csr2csc .to (tensor . device , non_blocking = non_blocking )
546+ csr2csc = csr2csc .to (device , non_blocking = non_blocking )
516547 csc2csr = self ._csc2csr
517548 if csc2csr is not None :
518- csc2csr = csc2csr .to (tensor . device , non_blocking = non_blocking )
549+ csc2csr = csc2csr .to (device , non_blocking = non_blocking )
519550
520551 return SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
521552 sparse_sizes = self ._sparse_sizes ,
522553 rowcount = rowcount , colptr = colptr ,
523554 colcount = colcount , csr2csc = csr2csc ,
524- csc2csr = csc2csr , is_sorted = True )
555+ csc2csr = csc2csr , is_sorted = True , trust_data = True )
556+
557+ def device_as (self , tensor : torch .Tensor , non_blocking : bool = False ):
558+ return self .to_device (device = tensor .device , non_blocking = non_blocking )
525559
526560 def cuda (self ):
527561 new_col = self ._col .cuda ()
@@ -557,7 +591,7 @@ def cuda(self):
557591 sparse_sizes = self ._sparse_sizes ,
558592 rowcount = rowcount , colptr = colptr ,
559593 colcount = colcount , csr2csc = csr2csc ,
560- csc2csr = csc2csr , is_sorted = True )
594+ csc2csr = csc2csr , is_sorted = True , trust_data = True )
561595
562596 def pin_memory (self ):
563597 row = self ._row
@@ -590,7 +624,7 @@ def pin_memory(self):
590624 sparse_sizes = self ._sparse_sizes ,
591625 rowcount = rowcount , colptr = colptr ,
592626 colcount = colcount , csr2csc = csr2csc ,
593- csc2csr = csc2csr , is_sorted = True )
627+ csc2csr = csc2csr , is_sorted = True , trust_data = True )
594628
595629 def is_pinned (self ) -> bool :
596630 is_pinned = True
0 commit comments