@@ -65,9 +65,15 @@ def from_edge_index(
6565 is_sorted : bool = False ,
6666 trust_data : bool = False ,
6767 ):
68- return SparseTensor (row = edge_index [0 ], rowptr = None , col = edge_index [1 ],
69- value = edge_attr , sparse_sizes = sparse_sizes ,
70- is_sorted = is_sorted , trust_data = trust_data )
68+ return SparseTensor (
69+ row = edge_index [0 ],
70+ rowptr = None ,
71+ col = edge_index [1 ],
72+ value = edge_attr ,
73+ sparse_sizes = sparse_sizes ,
74+ is_sorted = is_sorted ,
75+ trust_data = trust_data ,
76+ )
7177
7278 @classmethod
7379 def from_dense (self , mat : torch .Tensor , has_value : bool = True ):
@@ -84,13 +90,22 @@ def from_dense(self, mat: torch.Tensor, has_value: bool = True):
8490 if has_value :
8591 value = mat [row , col ]
8692
87- return SparseTensor (row = row , rowptr = None , col = col , value = value ,
88- sparse_sizes = (mat .size (0 ), mat .size (1 )),
89- is_sorted = True , trust_data = True )
93+ return SparseTensor (
94+ row = row ,
95+ rowptr = None ,
96+ col = col ,
97+ value = value ,
98+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
99+ is_sorted = True ,
100+ trust_data = True ,
101+ )
90102
91103 @classmethod
92- def from_torch_sparse_coo_tensor (self , mat : torch .Tensor ,
93- has_value : bool = True ):
104+ def from_torch_sparse_coo_tensor (
105+ self ,
106+ mat : torch .Tensor ,
107+ has_value : bool = True ,
108+ ):
94109 mat = mat .coalesce ()
95110 index = mat ._indices ()
96111 row , col = index [0 ], index [1 ]
@@ -99,27 +114,46 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
99114 if has_value :
100115 value = mat .values ()
101116
102- return SparseTensor (row = row , rowptr = None , col = col , value = value ,
103- sparse_sizes = (mat .size (0 ), mat .size (1 )),
104- is_sorted = True , trust_data = True )
117+ return SparseTensor (
118+ row = row ,
119+ rowptr = None ,
120+ col = col ,
121+ value = value ,
122+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
123+ is_sorted = True ,
124+ trust_data = True ,
125+ )
105126
106127 @classmethod
107- def from_torch_sparse_csr_tensor (self , mat : torch .Tensor ,
108- has_value : bool = True ):
128+ def from_torch_sparse_csr_tensor (
129+ self ,
130+ mat : torch .Tensor ,
131+ has_value : bool = True ,
132+ ):
109133 rowptr = mat .crow_indices ()
110134 col = mat .col_indices ()
111135
112136 value : Optional [torch .Tensor ] = None
113137 if has_value :
114138 value = mat .values ()
115139
116- return SparseTensor (row = None , rowptr = rowptr , col = col , value = value ,
117- sparse_sizes = (mat .size (0 ), mat .size (1 )),
118- is_sorted = True , trust_data = True )
140+ return SparseTensor (
141+ row = None ,
142+ rowptr = rowptr ,
143+ col = col ,
144+ value = value ,
145+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
146+ is_sorted = True ,
147+ trust_data = True ,
148+ )
119149
120150 @classmethod
121- def eye (self , M : int , N : Optional [int ] = None , has_value : bool = True ,
122- dtype : Optional [int ] = None , device : Optional [torch .device ] = None ,
151+ def eye (self ,
152+ M : int ,
153+ N : Optional [int ] = None ,
154+ has_value : bool = True ,
155+ dtype : Optional [int ] = None ,
156+ device : Optional [torch .device ] = None ,
123157 fill_cache : bool = False ):
124158
125159 N = M if N is None else N
@@ -214,13 +248,19 @@ def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
214248 def has_value (self ) -> bool :
215249 return self .storage .has_value ()
216250
217- def set_value_ (self , value : Optional [torch .Tensor ],
218- layout : Optional [str ] = None ):
251+ def set_value_ (
252+ self ,
253+ value : Optional [torch .Tensor ],
254+ layout : Optional [str ] = None ,
255+ ):
219256 self .storage .set_value_ (value , layout )
220257 return self
221258
222- def set_value (self , value : Optional [torch .Tensor ],
223- layout : Optional [str ] = None ):
259+ def set_value (
260+ self ,
261+ value : Optional [torch .Tensor ],
262+ layout : Optional [str ] = None ,
263+ ):
224264 return self .from_storage (self .storage .set_value (value , layout ))
225265
226266 def sparse_sizes (self ) -> Tuple [int , int ]:
@@ -275,13 +315,21 @@ def __eq__(self, other) -> bool:
275315 # Utility functions #######################################################
276316
277317 def fill_value_ (self , fill_value : float , dtype : Optional [int ] = None ):
278- value = torch .full ((self .nnz (), ), fill_value , dtype = dtype ,
279- device = self .device ())
318+ value = torch .full (
319+ (self .nnz (), ),
320+ fill_value ,
321+ dtype = dtype ,
322+ device = self .device (),
323+ )
280324 return self .set_value_ (value , layout = 'coo' )
281325
282326 def fill_value (self , fill_value : float , dtype : Optional [int ] = None ):
283- value = torch .full ((self .nnz (), ), fill_value , dtype = dtype ,
284- device = self .device ())
327+ value = torch .full (
328+ (self .nnz (), ),
329+ fill_value ,
330+ dtype = dtype ,
331+ device = self .device (),
332+ )
285333 return self .set_value (value , layout = 'coo' )
286334
287335 def sizes (self ) -> List [int ]:
@@ -373,8 +421,8 @@ def to_symmetric(self, reduce: str = "sum"):
373421 value = torch .cat ([value , value ])[perm ]
374422 value = segment_csr (value , ptr , reduce = reduce )
375423
376- new_row = torch .cat ([row , col ], dim = 0 , out = perm )[idx ]
377- new_col = torch .cat ([col , row ], dim = 0 , out = perm )[idx ]
424+ new_row = torch .cat ([row , col ], dim = 0 )[idx ]
425+ new_col = torch .cat ([col , row ], dim = 0 )[idx ]
378426
379427 out = SparseTensor (
380428 row = new_row ,
@@ -406,8 +454,11 @@ def requires_grad(self) -> bool:
406454 else :
407455 return False
408456
409- def requires_grad_ (self , requires_grad : bool = True ,
410- dtype : Optional [int ] = None ):
457+ def requires_grad_ (
458+ self ,
459+ requires_grad : bool = True ,
460+ dtype : Optional [int ] = None ,
461+ ):
411462 if requires_grad and not self .has_value ():
412463 self .fill_value_ (1. , dtype )
413464
@@ -478,21 +529,29 @@ def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
478529 row , col , value = self .coo ()
479530
480531 if value is not None :
481- mat = torch .zeros (self .sizes (), dtype = value .dtype ,
482- device = self .device ())
532+ mat = torch .zeros (
533+ self .sizes (),
534+ dtype = value .dtype ,
535+ device = self .device (),
536+ )
483537 else :
484538 mat = torch .zeros (self .sizes (), dtype = dtype , device = self .device ())
485539
486540 if value is not None :
487541 mat [row , col ] = value
488542 else :
489- mat [row , col ] = torch .ones (self .nnz (), dtype = mat .dtype ,
490- device = mat .device )
543+ mat [row , col ] = torch .ones (
544+ self .nnz (),
545+ dtype = mat .dtype ,
546+ device = mat .device ,
547+ )
491548
492549 return mat
493550
494551 def to_torch_sparse_coo_tensor (
495- self , dtype : Optional [int ] = None ) -> torch .Tensor :
552+ self ,
553+ dtype : Optional [int ] = None ,
554+ ) -> torch .Tensor :
496555 row , col , value = self .coo ()
497556 index = torch .stack ([row , col ], dim = 0 )
498557
@@ -502,7 +561,9 @@ def to_torch_sparse_coo_tensor(
502561 return torch .sparse_coo_tensor (index , value , self .sizes ())
503562
504563 def to_torch_sparse_csr_tensor (
505- self , dtype : Optional [int ] = None ) -> torch .Tensor :
564+ self ,
565+ dtype : Optional [int ] = None ,
566+ ) -> torch .Tensor :
506567 rowptr , col , value = self .csr ()
507568
508569 if value is None :
@@ -511,7 +572,9 @@ def to_torch_sparse_csr_tensor(
511572 return torch .sparse_csr_tensor (rowptr , col , value , self .sizes ())
512573
513574 def to_torch_sparse_csc_tensor (
514- self , dtype : Optional [int ] = None ) -> torch .Tensor :
575+ self ,
576+ dtype : Optional [int ] = None ,
577+ ) -> torch .Tensor :
515578 colptr , row , value = self .csc ()
516579
517580 if value is None :
@@ -548,8 +611,11 @@ def cpu(self) -> SparseTensor:
548611 return self .device_as (torch .tensor (0. , device = 'cpu' ))
549612
550613
551- def cuda (self , device : Optional [Union [int , str ]] = None ,
552- non_blocking : bool = False ):
614+ def cuda (
615+ self ,
616+ device : Optional [Union [int , str ]] = None ,
617+ non_blocking : bool = False ,
618+ ):
553619 return self .device_as (torch .tensor (0. , device = device or 'cuda' ))
554620
555621
@@ -654,17 +720,29 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
654720 value = torch .from_numpy (mat .data )
655721 sparse_sizes = mat .shape [:2 ]
656722
657- storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
658- sparse_sizes = sparse_sizes , rowcount = None ,
659- colptr = colptr , colcount = None , csr2csc = None ,
660- csc2csr = None , is_sorted = True )
723+ storage = SparseStorage (
724+ row = row ,
725+ rowptr = rowptr ,
726+ col = col ,
727+ value = value ,
728+ sparse_sizes = sparse_sizes ,
729+ rowcount = None ,
730+ colptr = colptr ,
731+ colcount = None ,
732+ csr2csc = None ,
733+ csc2csr = None ,
734+ is_sorted = True ,
735+ )
661736
662737 return SparseTensor .from_storage (storage )
663738
664739
665740@torch .jit .ignore
666- def to_scipy (self : SparseTensor , layout : Optional [str ] = None ,
667- dtype : Optional [torch .dtype ] = None ) -> ScipySparseMatrix :
741+ def to_scipy (
742+ self : SparseTensor ,
743+ layout : Optional [str ] = None ,
744+ dtype : Optional [torch .dtype ] = None ,
745+ ) -> ScipySparseMatrix :
668746 assert self .dim () == 2
669747 layout = get_layout (layout )
670748
0 commit comments