@@ -73,29 +73,22 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
7373 is_sorted = True )
7474
7575 @classmethod
76- def eye (self , M : int , N : Optional [int ] = None ,
77- options : Optional [torch . Tensor ] = None , has_value : bool = True ,
76+ def eye (self , M : int , N : Optional [int ] = None , has_value : bool = True ,
77+ dtype : Optional [int ] = None , device : Optional [ torch . device ] = None ,
7878 fill_cache : bool = False ):
7979
8080 N = M if N is None else N
8181
82- if options is not None :
83- row = torch .arange (min (M , N ), device = options .device )
84- else :
85- row = torch .arange (min (M , N ))
82+ row = torch .arange (min (M , N ), device = device )
8683 col = row
8784
88- rowptr = torch .arange (M + 1 , dtype = torch . long , device = row .device )
85+ rowptr = torch .arange (M + 1 , device = row .device )
8986 if M > N :
9087 rowptr [N + 1 :] = N
9188
9289 value : Optional [torch .Tensor ] = None
9390 if has_value :
94- if options is not None :
95- value = torch .ones (row .numel (), dtype = options .dtype ,
96- device = row .device )
97- else :
98- value = torch .ones (row .numel (), device = row .device )
91+ value = torch .ones (row .numel (), dtype = dtype , device = row .device )
9992
10093 rowcount : Optional [torch .Tensor ] = None
10194 colptr : Optional [torch .Tensor ] = None
@@ -131,7 +124,7 @@ def clone(self):
131124 return self .from_storage (self .storage .clone ())
132125
133126 def type_as (self , tensor = torch .Tensor ):
134- value = self .storage ._value
127+ value = self .storage .value ()
135128 if value is None or tensor .dtype == value .dtype :
136129 return self
137130 return self .from_storage (self .storage .type_as (tensor ))
@@ -199,24 +192,14 @@ def clear_cache_(self):
199192
200193 # Utility functions #######################################################
201194
202- def fill_value_ (self , fill_value : float ,
203- options : Optional [torch .Tensor ] = None ):
204- if options is not None :
205- value = torch .full ((self .nnz (), ), fill_value , dtype = options .dtype ,
206- device = self .device ())
207- else :
208- value = torch .full ((self .nnz (), ), fill_value ,
209- device = self .device ())
195+ def fill_value_ (self , fill_value : float , dtype : Optional [int ] = None ):
196+ value = torch .full ((self .nnz (), ), fill_value , dtype = dtype ,
197+ device = self .device ())
210198 return self .set_value_ (value , layout = 'coo' )
211199
212- def fill_value (self , fill_value : float ,
213- options : Optional [torch .Tensor ] = None ):
214- if options is not None :
215- value = torch .full ((self .nnz (), ), fill_value , dtype = options .dtype ,
216- device = self .device ())
217- else :
218- value = torch .full ((self .nnz (), ), fill_value ,
219- device = self .device ())
200+ def fill_value (self , fill_value : float , dtype : Optional [int ] = None ):
201+ value = torch .full ((self .nnz (), ), fill_value , dtype = dtype ,
202+ device = self .device ())
220203 return self .set_value (value , layout = 'coo' )
221204
222205 def sizes (self ) -> List [int ]:
@@ -320,9 +303,9 @@ def requires_grad(self) -> bool:
320303 return False
321304
322305 def requires_grad_ (self , requires_grad : bool = True ,
323- options : Optional [torch . Tensor ] = None ):
306+ dtype : Optional [int ] = None ):
324307 if requires_grad and not self .has_value ():
325- self .fill_value_ (1. , options = options )
308+ self .fill_value_ (1. , dtype )
326309
327310 value = self .storage .value ()
328311 if value is not None :
@@ -335,36 +318,25 @@ def pin_memory(self):
335318 def is_pinned (self ) -> bool :
336319 return self .storage .is_pinned ()
337320
338- def options (self ) -> torch .Tensor :
339- value = self .storage .value ()
340- if value is not None :
341- return value
342- else :
343- return torch .tensor (0. , dtype = torch .float ,
344- device = self .storage .col ().device )
345-
346321 def device (self ):
347322 return self .storage .col ().device
348323
349324 def cpu (self ):
350- return self .device_as (torch .tensor (0. ), non_blocking = False )
325+ return self .device_as (torch .tensor (0 ), non_blocking = False )
351326
352- def cuda (self , options : Optional [torch .Tensor ] = None ,
353- non_blocking : bool = False ):
354- if options is not None :
355- return self .device_as (options , non_blocking )
356- else :
357- options = torch .tensor (0. ).cuda ()
358- return self .device_as (options , non_blocking )
327+ def cuda (self ):
328+ return self .from_storage (self .storage .cuda ())
359329
360330 def is_cuda (self ) -> bool :
361331 return self .storage .col ().is_cuda
362332
363333 def dtype (self ):
364- return self .options ().dtype
334+ value = self .storage .value ()
335+ return value .dtype if value is not None else torch .float
365336
366337 def is_floating_point (self ) -> bool :
367- return torch .is_floating_point (self .options ())
338+ value = self .storage .value ()
339+ return torch .is_floating_point (value ) if value is not None else True
368340
369341 def bfloat16 (self ):
370342 return self .type_as (
@@ -408,17 +380,14 @@ def long(self):
408380
409381 # Conversions #############################################################
410382
411- def to_dense (self , options : Optional [torch . Tensor ] = None ) -> torch .Tensor :
383+ def to_dense (self , dtype : Optional [int ] = None ) -> torch .Tensor :
412384 row , col , value = self .coo ()
413385
414386 if value is not None :
415387 mat = torch .zeros (self .sizes (), dtype = value .dtype ,
416388 device = self .device ())
417- elif options is not None :
418- mat = torch .zeros (self .sizes (), dtype = options .dtype ,
419- device = self .device ())
420389 else :
421- mat = torch .zeros (self .sizes (), device = self .device ())
390+ mat = torch .zeros (self .sizes (), dtype = dtype , device = self .device ())
422391
423392 if value is not None :
424393 mat [row , col ] = value
@@ -428,25 +397,18 @@ def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor:
428397
429398 return mat
430399
431- def to_torch_sparse_coo_tensor (self ,
432- options : Optional [torch .Tensor ] = None ):
400+ def to_torch_sparse_coo_tensor (self , dtype : Optional [int ] = None ):
433401 row , col , value = self .coo ()
434402 index = torch .stack ([row , col ], dim = 0 )
403+
435404 if value is None :
436- if options is not None :
437- value = torch .ones (self .nnz (), dtype = options .dtype ,
438- device = self .device ())
439- else :
440- value = torch .ones (self .nnz (), device = self .device ())
405+ value = torch .ones (self .nnz (), dtype = dtype , device = self .device ())
441406
442407 return torch .sparse_coo_tensor (index , value , self .sizes ())
443408
444409
445410# Python Bindings #############################################################
446411
447- Dtype = Optional [torch .dtype ]
448- Device = Optional [Union [torch .device , str ]]
449-
450412
451413def share_memory_ (self : SparseTensor ) -> SparseTensor :
452414 self .storage .share_memory_ ()
0 commit comments