Skip to content

Commit 3e409bf

Browse files
committed
scatter_std
1 parent d305ecc commit 3e409bf

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

torch_scatter/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33
from .mul import scatter_mul
44
from .div import scatter_div
55
from .mean import scatter_mean
6+
from .std import scatter_std
67
from .max import scatter_max
78
from .min import scatter_min
89

910
__version__ = '1.0.4'
1011

1112
__all__ = [
12-
'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_div', 'scatter_mean',
13-
'scatter_max', 'scatter_min', '__version__'
13+
'scatter_add',
14+
'scatter_sub',
15+
'scatter_mul',
16+
'scatter_div',
17+
'scatter_mean',
18+
'scatter_std',
19+
'scatter_max',
20+
'scatter_min',
21+
'__version__',
1422
]

torch_scatter/std.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
3+
from torch_scatter import scatter_add
4+
from torch_scatter.utils.gen import gen
5+
6+
7+
def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
8+
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)
9+
10+
tmp = scatter_add(src, index, dim, None, dim_size)
11+
count = scatter_add(torch.ones_like(src), index, dim, None, dim_size)
12+
mean = tmp / count.clamp(min=1)
13+
14+
var = (src - mean.gather(dim, index))**2
15+
out = scatter_add(var, index, dim, out, dim_size)
16+
out = out / (count - 1 if unbiased else count).clamp(min=1)
17+
out = torch.sqrt(out)
18+
19+
return out

0 commit comments

Comments
 (0)