Skip to content

Commit 2a3dca8

Browse files
committed
numerical stability
1 parent 3e409bf commit 2a3dca8

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torch_scatter/std.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66

77
def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
88
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
9+
print('src', src.mean())
910

1011
tmp = scatter_add(src, index, dim, None, dim_size)
1112
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
1213
mean = tmp / count.clamp(min=1)
1314

14-
var = (src - mean.gather(dim, index))**2
15+
var = (src - mean.gather(dim, index))
16+
var = var * var
1517
out = scatter_add(var, index, dim, out, dim_size)
1618
out = out / (count - 1 if unbiased else count).clamp(min=1)
17-
out = torch.sqrt(out)
19+
out = torch.sqrt(out.clamp(min=1e-12))
1820

1921
return out

0 commit comments

Comments
 (0)