Description
I just had an idea for a `scatter_average` that works like NumPy's `average`: if `weight` is `None`, it behaves like a plain scatter mean; otherwise, it computes a weighted average. Below is the approach I used to avoid calling `scatter` twice (once for the weighted values and once for the weights), although it would probably be much faster if it were implemented in C++.
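```python
import torch
from torch import Tensor
from torch_scatter import scatter


def scatter_average(
    src: Tensor,
    index: Tensor,
    dim: int = 0,
    dim_size: int | None = None,
    weight: Tensor | None = None,
) -> Tensor:
    if weight is None:
        # Unweighted: plain scatter mean. Keyword arguments are required here,
        # because scatter's fourth positional parameter is `out`, not `dim_size`.
        return scatter(src, index, dim=dim, dim_size=dim_size, reduce="mean")
    # Weighted: append the weights as an extra last column so that a single
    # "sum" scatter accumulates both sum(w * x) and sum(w) for each group.
    # Assumes src has shape (N, F), dim=0, and weight has shape (N,).
    tmp_weight = weight.unsqueeze(1)
    tmp_result = scatter(
        torch.cat((src * tmp_weight, tmp_weight), dim=-1),
        index,
        dim=dim,
        dim_size=dim_size,
        reduce="sum",
    )
    # Divide the weighted sums by the per-group weight totals.
    # Groups whose weights sum to zero (including empty groups) yield NaN/inf.
    return tmp_result[:, :-1] / tmp_result[:, -1:]
```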
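To sanity-check the output, a minimal NumPy reference with the same semantics may help. This is only a sketch: the name `scatter_average_np` and its `axis` parameter are hypothetical and not part of torch_scatter. It accumulates the weighted sums and the weight totals with `np.add.at`, works along any axis, and (as a design choice of this sketch) leaves groups with zero total weight at 0 instead of NaN. Each non-empty group's row should match `np.average(src[index == g], axis=0, weights=weight[index == g])`.

```python
import numpy as np


def scatter_average_np(src, index, axis=0, dim_size=None, weight=None):
    """Hypothetical NumPy reference: group slices of `src` along `axis` by
    `index` and return np.average-style (optionally weighted) means per group."""
    src = np.asarray(src, dtype=float)
    index = np.asarray(index)
    if dim_size is None:
        dim_size = int(index.max()) + 1 if index.size else 0
    # Move the scattered axis to the front so grouping is along axis 0.
    src = np.moveaxis(src, axis, 0)
    if weight is None:
        # No weights: every element counts equally, i.e. a plain mean.
        weight = np.ones(src.shape[0])
    weight = np.asarray(weight, dtype=float)
    # Reshape weights to (N, 1, 1, ...) so they broadcast over trailing dims.
    w = weight.reshape((-1,) + (1,) * (src.ndim - 1))
    num = np.zeros((dim_size,) + src.shape[1:])
    den = np.zeros((dim_size,) + (1,) * (src.ndim - 1))
    # np.add.at accumulates repeated indices (a fancy-indexed += would not).
    np.add.at(num, index, src * w)
    np.add.at(den, index, w)
    # Groups with zero total weight (e.g. empty groups) stay 0 instead of NaN.
    with np.errstate(invalid="ignore", divide="ignore"):
        out = np.where(den != 0, num / den, 0.0)
    # Restore the original axis order.
    return np.moveaxis(out, 0, axis)
```

For example, with `src = [[1, 2], [3, 4], [5, 6]]`, `index = [0, 1, 0]`, and `weight = [1, 2, 3]`, group 0 averages rows 0 and 2 with weights 1 and 3, giving `[[4, 5], [3, 4]]`; without weights, both groups give `[3, 4]`.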