Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3824.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add column `"n_obs_aggregated"` to {func}`scanpy.get.aggregate` output to show the total number of observations aggregated per group.
8 changes: 6 additions & 2 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def aggregate( # noqa: PLR0912
... )
>>> aggregated
AnnData object with n_obs × n_vars = 8 × 13714
obs: 'louvain'
obs: 'louvain', 'n_obs_aggregated'
var: 'n_cells'
layers: 'mean', 'count_nonzero'

Expand All @@ -253,7 +253,7 @@ def aggregate( # noqa: PLR0912
... pbmc, by=["louvain", "percent_mito_binned"], func=["mean", "count_nonzero"]
... )
AnnData object with n_obs × n_vars = 40 × 13714
obs: 'louvain', 'percent_mito_binned'
obs: 'louvain', 'percent_mito_binned', 'n_obs_aggregated'
var: 'n_cells'
layers: 'mean', 'count_nonzero'

Expand Down Expand Up @@ -295,6 +295,10 @@ def aggregate( # noqa: PLR0912

dim_df = getattr(adata, axis_name)
categorical, new_label_df = _combine_categories(dim_df, by)

# Add number of obs aggregated into each group
group_sizes = pd.Series(categorical).value_counts().reindex(new_label_df.index)
new_label_df["n_obs_aggregated"] = group_sizes.values
Comment on lines +300 to +301
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
group_sizes = pd.Series(categorical).value_counts().reindex(new_label_df.index)
new_label_df["n_obs_aggregated"] = group_sizes.values
new_label_df["n_obs_aggregated"] = pd.Series(categorical).value_counts().reindex(new_label_df.index)

no?

# Actual computation
layers = _aggregate(
data,
Expand Down
95 changes: 88 additions & 7 deletions tests/test_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,13 @@ def test_aggregate_axis_specification(axis_name):
["count_nonzero"], # , "sum", "mean"],
ad.AnnData(
obs=pd.DataFrame(
{"a": ["a", "a", "b"], "b": ["c", "d", "d"]},
{
"a": pd.Categorical(["a", "a", "b"]),
"b": pd.Categorical(["c", "d", "d"]),
"n_obs_aggregated": [1, 1, 2],
},
index=["a_c", "a_d", "b_d"],
).astype("category"),
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(4)]),
layers={
"count_nonzero": np.array([
Expand Down Expand Up @@ -323,9 +327,13 @@ def test_aggregate_axis_specification(axis_name):
["sum", "mean", "count_nonzero"],
ad.AnnData(
obs=pd.DataFrame(
{"a": ["a", "a", "b"], "b": ["c", "d", "d"]},
{
"a": pd.Categorical(["a", "a", "b"]),
"b": pd.Categorical(["c", "d", "d"]),
"n_obs_aggregated": [1, 1, 2],
},
index=["a_c", "a_d", "b_d"],
).astype("category"),
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(4)]),
layers={
"sum": np.array([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 2, 2]]),
Expand Down Expand Up @@ -355,9 +363,13 @@ def test_aggregate_axis_specification(axis_name):
["mean"],
ad.AnnData(
obs=pd.DataFrame(
{"a": ["a", "a", "b"], "b": ["c", "d", "d"]},
{
"a": pd.Categorical(["a", "a", "b"]),
"b": pd.Categorical(["c", "d", "d"]),
"n_obs_aggregated": [1, 1, 2],
},
index=["a_c", "a_d", "b_d"],
).astype("category"),
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(4)]),
layers={
"mean": np.array([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1]]),
Expand Down Expand Up @@ -516,7 +528,10 @@ def test_aggregate_obsm_labels():
)

expected = ad.AnnData(
obs=pd.DataFrame({"labels": pd.Categorical(list("abc"))}, index=list("abc")),
obs=pd.DataFrame(
{"labels": pd.Categorical(list("abc")), "n_obs_aggregated": [5, 3, 4]},
index=list("abc"),
),
Comment on lines +531 to +534
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Construct from label_counts variable please!

Suggested change
obs=pd.DataFrame(
{"labels": pd.Categorical(list("abc")), "n_obs_aggregated": [5, 3, 4]},
index=list("abc"),
),
obs=pd.DataFrame(
{"labels": pd.Categorical([lc[0] for lc in label_counts]), "n_obs_aggregated": [lc[1] for lc in label_counts]},
index=[lc[0] for lc in label_counts],
),

var=pd.DataFrame(index=[f"dim_{i}" for i in range(3)]),
layers={
"sum": np.diag([n for _, n in label_counts]),
Expand Down Expand Up @@ -544,3 +559,69 @@ def test_factors():

res = sc.get.aggregate(adata, by=["a", "b", "c", "d"], func="sum")
np.testing.assert_equal(res.layers["sum"], adata.X)


def test_aggregate_n_obs_aggregated_single_key():
"""Test n_obs_aggregated with single grouping key using known ground truth."""
# Create data where we KNOW the exact counts
adata = ad.AnnData(
X=np.random.rand(10, 5),
obs=pd.DataFrame(
{"cluster": ["A", "A", "A", "B", "B", "B", "B", "C", "C", "C"]},
index=[f"cell_{i}" for i in range(10)],
),
)

result = sc.get.aggregate(adata, by="cluster", func="mean")

# Verify column exists
assert "n_obs_aggregated" in result.obs

# Check known ground truth counts
assert result.obs.loc["A", "n_obs_aggregated"] == 3
assert result.obs.loc["B", "n_obs_aggregated"] == 4
assert result.obs.loc["C", "n_obs_aggregated"] == 3

# Total should equal original n_obs
assert result.obs["n_obs_aggregated"].sum() == 10


def test_aggregate_n_obs_aggregated_multiple_keys():
"""Test n_obs_aggregated with multiple grouping keys using known ground truth."""
# Create data with known combinations
adata = ad.AnnData(
X=np.random.rand(12, 5),
obs=pd.DataFrame(
{
"cluster": ["A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C"],
"batch": ["1", "1", "2", "2", "1", "1", "2", "2", "1", "1", "2", "2"],
},
index=[f"cell_{i}" for i in range(12)],
),
)

result = sc.get.aggregate(adata, by=["cluster", "batch"], func="mean")

# Verify column exists
assert "n_obs_aggregated" in result.obs

# Check each combination has exactly 2 observations
# This is our ground truth - we designed the data this way
expected_combinations = [
("A", "1", 2),
("A", "2", 2),
("B", "1", 2),
("B", "2", 2),
("C", "1", 2),
("C", "2", 2),
]

for cluster, batch, expected_count in expected_combinations:
mask = (result.obs["cluster"] == cluster) & (result.obs["batch"] == batch)
actual_count = result.obs.loc[mask, "n_obs_aggregated"].values[0]
assert actual_count == expected_count, (
f"Expected {cluster}+{batch} to have {expected_count} obs, got {actual_count}"
)

# Total should equal original n_obs
assert result.obs["n_obs_aggregated"].sum() == 12
Comment on lines +562 to +627
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these tests are redundant with the other ones

Suggested change
def test_aggregate_n_obs_aggregated_single_key():
"""Test n_obs_aggregated with single grouping key using known ground truth."""
# Create data where we KNOW the exact counts
adata = ad.AnnData(
X=np.random.rand(10, 5),
obs=pd.DataFrame(
{"cluster": ["A", "A", "A", "B", "B", "B", "B", "C", "C", "C"]},
index=[f"cell_{i}" for i in range(10)],
),
)
result = sc.get.aggregate(adata, by="cluster", func="mean")
# Verify column exists
assert "n_obs_aggregated" in result.obs
# Check known ground truth counts
assert result.obs.loc["A", "n_obs_aggregated"] == 3
assert result.obs.loc["B", "n_obs_aggregated"] == 4
assert result.obs.loc["C", "n_obs_aggregated"] == 3
# Total should equal original n_obs
assert result.obs["n_obs_aggregated"].sum() == 10
def test_aggregate_n_obs_aggregated_multiple_keys():
"""Test n_obs_aggregated with multiple grouping keys using known ground truth."""
# Create data with known combinations
adata = ad.AnnData(
X=np.random.rand(12, 5),
obs=pd.DataFrame(
{
"cluster": ["A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C"],
"batch": ["1", "1", "2", "2", "1", "1", "2", "2", "1", "1", "2", "2"],
},
index=[f"cell_{i}" for i in range(12)],
),
)
result = sc.get.aggregate(adata, by=["cluster", "batch"], func="mean")
# Verify column exists
assert "n_obs_aggregated" in result.obs
# Check each combination has exactly 2 observations
# This is our ground truth - we designed the data this way
expected_combinations = [
("A", "1", 2),
("A", "2", 2),
("B", "1", 2),
("B", "2", 2),
("C", "1", 2),
("C", "2", 2),
]
for cluster, batch, expected_count in expected_combinations:
mask = (result.obs["cluster"] == cluster) & (result.obs["batch"] == batch)
actual_count = result.obs.loc[mask, "n_obs_aggregated"].values[0]
assert actual_count == expected_count, (
f"Expected {cluster}+{batch} to have {expected_count} obs, got {actual_count}"
)
# Total should equal original n_obs
assert result.obs["n_obs_aggregated"].sum() == 12

Loading