@@ -115,10 +115,20 @@ def timing(dataset):
115115 dim_size = rowptr .size (0 ) - 1
116116 avg_row_len = row .size (0 ) / dim_size
117117
118- def sca_row (x ):
118+ def sca1_row (x ):
119+ out = x .new_zeros (dim_size , * x .size ()[1 :])
120+ row_tmp = row .view (- 1 , 1 ).expand_as (x ) if x .dim () > 1 else row
121+ return out .scatter_add_ (0 , row_tmp , x )
122+
123+ def sca1_col (x ):
124+ out = x .new_zeros (dim_size , * x .size ()[1 :])
125+ row2_tmp = row2 .view (- 1 , 1 ).expand_as (x ) if x .dim () > 1 else row2
126+ return out .scatter_add_ (0 , row2_tmp , x )
127+
128+ def sca2_row (x ):
119129 return scatter (x , row , dim = 0 , dim_size = dim_size , reduce = args .reduce )
120130
121- def sca_col (x ):
131+ def sca2_col (x ):
122132 return scatter (x , row2 , dim = 0 , dim_size = dim_size , reduce = args .reduce )
123133
124134 def seg_coo (x ):
@@ -133,75 +143,81 @@ def dense1(x):
133143 def dense2 (x ):
134144 return getattr (torch , args .reduce )(x , dim = - 1 )
135145
136- t1 , t2 , t3 , t4 , t5 , t6 = [], [], [], [], [], []
146+ t1 , t2 , t3 , t4 , t5 , t6 , t7 , t8 = [], [], [], [], [], [], [], []
137147
138148 for size in sizes :
139149 try :
140150 x = torch .randn ((row .size (0 ), size ), device = args .device )
141151 x = x .squeeze (- 1 ) if size == 1 else x
142152
143- t1 += [time_func (sca_row , x )]
144- t2 += [time_func (sca_col , x )]
145- t3 += [time_func (seg_coo , x )]
146- t4 += [time_func (seg_csr , x )]
153+ t1 += [time_func (sca1_row , x )]
154+ t2 += [time_func (sca1_col , x )]
155+ t3 += [time_func (sca2_row , x )]
156+ t4 += [time_func (sca2_col , x )]
157+ t5 += [time_func (seg_coo , x )]
158+ t6 += [time_func (seg_csr , x )]
147159
148160 del x
149161
150162 except RuntimeError as e :
151163 if 'out of memory' not in str (e ):
152164 raise RuntimeError (e )
153165 torch .cuda .empty_cache ()
154- for t in (t1 , t2 , t3 , t4 ):
166+ for t in (t1 , t2 , t3 , t4 , t5 , t6 ):
155167 t .append (float ('inf' ))
156168
157169 try :
158170 x = torch .randn ((dim_size , int (avg_row_len + 1 ), size ),
159171 device = args .device )
160172
161- t5 += [time_func (dense1 , x )]
173+ t7 += [time_func (dense1 , x )]
162174 x = x .view (dim_size , size , int (avg_row_len + 1 ))
163- t6 += [time_func (dense2 , x )]
175+ t8 += [time_func (dense2 , x )]
164176
165177 del x
166178
167179 except RuntimeError as e :
168180 if 'out of memory' not in str (e ):
169181 raise RuntimeError (e )
170182 torch .cuda .empty_cache ()
171- for t in (t5 , t6 ):
183+ for t in (t7 , t8 ):
172184 t .append (float ('inf' ))
173185
174- ts = torch .tensor ([t1 , t2 , t3 , t4 , t5 , t6 ])
186+ ts = torch .tensor ([t1 , t2 , t3 , t4 , t5 , t6 , t7 , t8 ])
175187 winner = torch .zeros_like (ts , dtype = torch .bool )
176188 winner [ts .argmin (dim = 0 ), torch .arange (len (sizes ))] = 1
177189 winner = winner .tolist ()
178190
179191 name = f'{ group } /{ name } '
180192 print (f'{ bold (name )} (avg row length: { avg_row_len :.2f} ):' )
181193 print ('\t ' .join ([' ' ] + [f'{ size :>5} ' for size in sizes ]))
182- print ('\t ' .join ([bold ('SCA_ROW ' )] +
194+ print ('\t ' .join ([bold ('SCA1_R ' )] +
183195 [bold (f'{ t :.5f} ' , f ) for t , f in zip (t1 , winner [0 ])]))
184- print ('\t ' .join ([bold ('SCA_COL ' )] +
196+ print ('\t ' .join ([bold ('SCA1_C ' )] +
185197 [bold (f'{ t :.5f} ' , f ) for t , f in zip (t2 , winner [1 ])]))
186- print ('\t ' .join ([bold ('SEG_COO ' )] +
198+ print ('\t ' .join ([bold ('SCA2_R ' )] +
187199 [bold (f'{ t :.5f} ' , f ) for t , f in zip (t3 , winner [2 ])]))
188- print ('\t ' .join ([bold ('SEG_CSR ' )] +
200+ print ('\t ' .join ([bold ('SCA2_C ' )] +
189201 [bold (f'{ t :.5f} ' , f ) for t , f in zip (t4 , winner [3 ])]))
190- print ('\t ' .join ([bold ('DENSE1 ' )] +
202+ print ('\t ' .join ([bold ('SEG_COO ' )] +
191203 [bold (f'{ t :.5f} ' , f ) for t , f in zip (t5 , winner [4 ])]))
192- print ('\t ' .join ([bold ('DENSE2 ' )] +
204+ print ('\t ' .join ([bold ('SEG_CSR ' )] +
193205 [bold (f'{ t :.5f} ' , f ) for t , f in zip (t6 , winner [5 ])]))
206+ print ('\t ' .join ([bold ('DENSE1 ' )] +
207+ [bold (f'{ t :.5f} ' , f ) for t , f in zip (t7 , winner [6 ])]))
208+ print ('\t ' .join ([bold ('DENSE2 ' )] +
209+ [bold (f'{ t :.5f} ' , f ) for t , f in zip (t8 , winner [7 ])]))
194210 print ()
195211
196212
197213if __name__ == '__main__' :
198214 parser = argparse .ArgumentParser ()
199215 parser .add_argument ('--reduce' , type = str , required = True ,
200- choices = ['sum' , 'add' , ' mean' , 'min' , 'max' ])
216+ choices = ['sum' , 'mean' , 'min' , 'max' ])
201217 parser .add_argument ('--with_backward' , action = 'store_true' )
202218 parser .add_argument ('--device' , type = str , default = 'cuda' )
203219 args = parser .parse_args ()
204- iters = 1 if args .device == 'cpu' else 20
220+ iters = 1 if args .device == 'cpu' else 50
205221 sizes = [1 , 16 , 32 , 64 , 128 , 256 , 512 ]
206222 sizes = sizes [:3 ] if args .device == 'cpu' else sizes
207223
0 commit comments