scatter average (like np.average) (feature request) #494

@bryceForrest

I had an idea for a scatter_average that works like NumPy's average: if weight is None it behaves like a mean reduction, otherwise it computes a weighted average. I'm also including the approach I used to avoid calling scatter twice, though it would probably be much faster if it were implemented in C++.

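import torch
from torch import Tensor
# Assumed import: the post does not show where scatter comes from.
# torch_geometric.utils.scatter accepts the same keyword arguments.
from torch_scatter import scatter


def scatter_average(
    src: Tensor,
    index: Tensor,
    dim: int = 0,
    dim_size: int | None = None,
    weight: Tensor | None = None,
) -> Tensor:
    if weight is None:
        # No weights: plain per-group mean.
        result = scatter(src, index, dim=dim, dim_size=dim_size, reduce="mean")
    else:
        # This branch assumes src has shape [N, F], weight has shape [N], and dim=0.
        tmp_weight = weight.unsqueeze(1)
        # Append the weights as an extra column so a single scatter call sums
        # both the weighted values and the weights.
        tmp_result = scatter(
            torch.concat((src * tmp_weight, tmp_weight), dim=-1),
            index,
            dim=dim,
            dim_size=dim_size,
            reduce="sum",
        )

        # The last column holds the per-group weight totals.
        result = tmp_result[:, :-1] / tmp_result[:, -1:]

    return result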
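
For checking the behaviour on small inputs, here is a plain-Python sketch of the semantics described above, restricted to 1-D data. The name `scatter_average_reference` is hypothetical and not part of any library, and returning NaN for groups whose total weight is zero is a choice made only for this sketch. It accumulates the weighted sum and the weight total in one pass, which mirrors the "scatter once" idea.

```python
from typing import List, Optional, Sequence


def scatter_average_reference(
    src: Sequence[float],
    index: Sequence[int],
    dim_size: Optional[int] = None,
    weight: Optional[Sequence[float]] = None,
) -> List[float]:
    """Hypothetical 1-D reference: per-group weighted average of src."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if weight is None:
        # Like np.average without weights: every element counts equally.
        weight = [1.0] * len(src)
    elif len(weight) != len(src):
        raise ValueError("weight and src must have the same length")
    if dim_size is None:
        dim_size = max(index) + 1 if len(index) > 0 else 0

    weighted_sum = [0.0] * dim_size  # per-group sum of w * x
    weight_total = [0.0] * dim_size  # per-group sum of w
    for value, group, w in zip(src, index, weight):
        weighted_sum[group] += w * value
        weight_total[group] += w

    # Sketch-only choice: groups with zero total weight yield NaN.
    return [
        s / t if t != 0 else float("nan")
        for s, t in zip(weighted_sum, weight_total)
    ]


values = [1.0, 3.0, 10.0]
groups = [0, 0, 1]
weighted = scatter_average_reference(values, groups, weight=[3.0, 1.0, 2.0])
unweighted = scatter_average_reference(values, groups)
```

Here group 0 averages 1.0 and 3.0 with weights 3 and 1, and group 1 contains only 10.0, so the weighted and unweighted calls differ only in group 0.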
