Skip to content

Commit 1006514

Browse files
committed
test with pytorch scatter
1 parent f056396 commit 1006514

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

benchmark/scatter_segment.py

Lines changed: 36 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -115,10 +115,20 @@ def timing(dataset):
115115
dim_size = rowptr.size(0) - 1
116116
avg_row_len = row.size(0) / dim_size
117117

118-
def sca_row(x):
118+
def sca1_row(x):
119+
out = x.new_zeros(dim_size, *x.size()[1:])
120+
row_tmp = row.view(-1, 1).expand_as(x) if x.dim() > 1 else row
121+
return out.scatter_add_(0, row_tmp, x)
122+
123+
def sca1_col(x):
124+
out = x.new_zeros(dim_size, *x.size()[1:])
125+
row2_tmp = row2.view(-1, 1).expand_as(x) if x.dim() > 1 else row2
126+
return out.scatter_add_(0, row2_tmp, x)
127+
128+
def sca2_row(x):
119129
return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
120130

121-
def sca_col(x):
131+
def sca2_col(x):
122132
return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
123133

124134
def seg_coo(x):
@@ -133,75 +143,81 @@ def dense1(x):
133143
def dense2(x):
134144
return getattr(torch, args.reduce)(x, dim=-1)
135145

136-
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
146+
t1, t2, t3, t4, t5, t6, t7, t8 = [], [], [], [], [], [], [], []
137147

138148
for size in sizes:
139149
try:
140150
x = torch.randn((row.size(0), size), device=args.device)
141151
x = x.squeeze(-1) if size == 1 else x
142152

143-
t1 += [time_func(sca_row, x)]
144-
t2 += [time_func(sca_col, x)]
145-
t3 += [time_func(seg_coo, x)]
146-
t4 += [time_func(seg_csr, x)]
153+
t1 += [time_func(sca1_row, x)]
154+
t2 += [time_func(sca1_col, x)]
155+
t3 += [time_func(sca2_row, x)]
156+
t4 += [time_func(sca2_col, x)]
157+
t5 += [time_func(seg_coo, x)]
158+
t6 += [time_func(seg_csr, x)]
147159

148160
del x
149161

150162
except RuntimeError as e:
151163
if 'out of memory' not in str(e):
152164
raise RuntimeError(e)
153165
torch.cuda.empty_cache()
154-
for t in (t1, t2, t3, t4):
166+
for t in (t1, t2, t3, t4, t5, t6):
155167
t.append(float('inf'))
156168

157169
try:
158170
x = torch.randn((dim_size, int(avg_row_len + 1), size),
159171
device=args.device)
160172

161-
t5 += [time_func(dense1, x)]
173+
t7 += [time_func(dense1, x)]
162174
x = x.view(dim_size, size, int(avg_row_len + 1))
163-
t6 += [time_func(dense2, x)]
175+
t8 += [time_func(dense2, x)]
164176

165177
del x
166178

167179
except RuntimeError as e:
168180
if 'out of memory' not in str(e):
169181
raise RuntimeError(e)
170182
torch.cuda.empty_cache()
171-
for t in (t5, t6):
183+
for t in (t7, t8):
172184
t.append(float('inf'))
173185

174-
ts = torch.tensor([t1, t2, t3, t4, t5, t6])
186+
ts = torch.tensor([t1, t2, t3, t4, t5, t6, t7, t8])
175187
winner = torch.zeros_like(ts, dtype=torch.bool)
176188
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
177189
winner = winner.tolist()
178190

179191
name = f'{group}/{name}'
180192
print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
181193
print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
182-
print('\t'.join([bold('SCA_ROW')] +
194+
print('\t'.join([bold('SCA1_R ')] +
183195
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
184-
print('\t'.join([bold('SCA_COL')] +
196+
print('\t'.join([bold('SCA1_C ')] +
185197
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
186-
print('\t'.join([bold('SEG_COO')] +
198+
print('\t'.join([bold('SCA2_R ')] +
187199
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
188-
print('\t'.join([bold('SEG_CSR')] +
200+
print('\t'.join([bold('SCA2_C ')] +
189201
[bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
190-
print('\t'.join([bold('DENSE1 ')] +
202+
print('\t'.join([bold('SEG_COO')] +
191203
[bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])]))
192-
print('\t'.join([bold('DENSE2 ')] +
204+
print('\t'.join([bold('SEG_CSR')] +
193205
[bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])]))
206+
print('\t'.join([bold('DENSE1 ')] +
207+
[bold(f'{t:.5f}', f) for t, f in zip(t7, winner[6])]))
208+
print('\t'.join([bold('DENSE2 ')] +
209+
[bold(f'{t:.5f}', f) for t, f in zip(t8, winner[7])]))
194210
print()
195211

196212

197213
if __name__ == '__main__':
198214
parser = argparse.ArgumentParser()
199215
parser.add_argument('--reduce', type=str, required=True,
200-
choices=['sum', 'add', 'mean', 'min', 'max'])
216+
choices=['sum', 'mean', 'min', 'max'])
201217
parser.add_argument('--with_backward', action='store_true')
202218
parser.add_argument('--device', type=str, default='cuda')
203219
args = parser.parse_args()
204-
iters = 1 if args.device == 'cpu' else 20
220+
iters = 1 if args.device == 'cpu' else 50
205221
sizes = [1, 16, 32, 64, 128, 256, 512]
206222
sizes = sizes[:3] if args.device == 'cpu' else sizes
207223

0 commit comments

Comments
 (0)