Skip to content

Commit 755e122

Browse files
committed
Fix pairwise array slicing (#94)
1 parent 22b5378 commit 755e122

File tree

2 files changed

+104
-3
lines changed

2 files changed

+104
-3
lines changed

src/mudata/_core/mudata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,10 @@ def _init_as_view(self, mudata_ref: "MuData", index):
320320

321321
self._obs = DataFrameView(mudata_ref.obs.iloc[obsidx, :], view_args=(self, "obs"))
322322
self._obsm = mudata_ref.obsm._view(self, (obsidx,))
323-
self._obsp = mudata_ref.obsp._view(self, obsidx)
323+
self._obsp = mudata_ref.obsp._view(self, (obsidx, obsidx))
324324
self._var = DataFrameView(mudata_ref.var.iloc[varidx, :], view_args=(self, "var"))
325325
self._varm = mudata_ref.varm._view(self, (varidx,))
326-
self._varp = mudata_ref.varp._view(self, varidx)
326+
self._varp = mudata_ref.varp._view(self, (varidx, varidx))
327327

328328
for attr, idx in (("obs", obsidx), ("var", varidx)):
329329
posmap = {}
@@ -1297,7 +1297,7 @@ def _update_attr_legacy(
12971297

12981298
# Update .obsp/.varp (size might have changed)
12991299
for mx_key, mx in attrp.items():
1300-
attrp[mx_key] = attrp[mx_key][index_order, index_order]
1300+
attrp[mx_key] = attrp[mx_key][index_order, :][:,index_order]
13011301
attrp[mx_key][index_order == -1, :] = -1
13021302
attrp[mx_key][:, index_order == -1] = -1
13031303

tests/test_view_copy.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pytest
66
from anndata import AnnData
7+
from scipy import sparse
78

89
import mudata
910
from mudata import MuData
@@ -28,6 +29,52 @@ def mdata():
2829
return mdata
2930

3031

32+
@pytest.fixture()
33+
def mdata_with_obsp():
34+
"""Create a MuData object with populated obsp and varp fields."""
35+
rng = np.random.default_rng(42)
36+
mod1 = AnnData(
37+
np.arange(0, 100, 0.1).reshape(-1, 10),
38+
obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)),
39+
)
40+
mod2 = AnnData(
41+
np.arange(101, 2101, 1).reshape(-1, 20),
42+
obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)),
43+
)
44+
mods = {"mod1": mod1, "mod2": mod2}
45+
# Make var_names different in different modalities
46+
for m in ["mod1", "mod2"]:
47+
mods[m].var_names = [f"{m}_var{i}" for i in range(mods[m].n_vars)]
48+
mdata = MuData(mods)
49+
50+
# Create and add sparse matrices to obsp
51+
n_obs = mdata.n_obs
52+
n_var = mdata.n_var
53+
54+
# Create sparse distances matrix (symmetric)
55+
distances = sparse.random(n_obs, n_obs, density=0.2, random_state=42)
56+
distances = sparse.triu(distances)
57+
distances = distances + distances.T
58+
59+
# Create sparse connectivities matrix (symmetric)
60+
connectivities = sparse.random(n_obs, n_obs, density=0.1, random_state=42)
61+
connectivities = sparse.triu(connectivities)
62+
connectivities = connectivities + connectivities.T
63+
64+
# Add to obsp
65+
mdata.obsp["distances"] = distances
66+
mdata.obsp["connectivities"] = connectivities
67+
68+
# Create and add a sparse matrix to varp
69+
varp_matrix = sparse.random(n_var, n_var, density=0.05, random_state=42)
70+
varp_matrix = sparse.triu(varp_matrix)
71+
varp_matrix = varp_matrix + varp_matrix.T
72+
73+
mdata.varp["correlations"] = varp_matrix
74+
75+
return mdata
76+
77+
3178
@pytest.mark.usefixtures("filepath_h5mu", "filepath2_h5mu")
3279
class TestMuData:
3380
def test_copy(self, mdata):
@@ -106,3 +153,57 @@ def test_backed_copy(self, mdata, filepath_h5mu, filepath2_h5mu):
106153
assert mdata_b.n_obs == mdata.n_obs
107154
mdata_b_copy = mdata_b.copy(filepath2_h5mu)
108155
assert mdata_b_copy.file._filename.name == Path(filepath2_h5mu).name
156+
157+
def test_obsp_slicing(self, mdata_with_obsp):
158+
"""Test that obsp matrices are correctly sliced when subsetting a MuData object."""
159+
orig_n_obs = mdata_with_obsp.n_obs
160+
161+
# Check initial shapes
162+
assert mdata_with_obsp.obsp["distances"].shape == (orig_n_obs, orig_n_obs)
163+
assert mdata_with_obsp.obsp["connectivities"].shape == (orig_n_obs, orig_n_obs)
164+
165+
# Slice a subset of cells
166+
n_obs_subset = 50
167+
random_indices = np.random.choice(mdata_with_obsp.obs_names, size=n_obs_subset, replace=False)
168+
169+
# Create a slice view
170+
mdata_slice = mdata_with_obsp[random_indices]
171+
172+
# Check that the sliced obsp matrices have correct shape in the view
173+
assert mdata_slice.obsp["distances"].shape == (n_obs_subset, n_obs_subset), \
174+
f"Expected shape in view: {(n_obs_subset, orig_n_obs)}, got: {mdata_slice.obsp['distances'].shape}"
175+
assert mdata_slice.obsp["connectivities"].shape == (n_obs_subset, n_obs_subset), \
176+
f"Expected shape in view: {(n_obs_subset, orig_n_obs)}, got: {mdata_slice.obsp['connectivities'].shape}"
177+
178+
# Make a copy of the sliced MuData object
179+
mdata_copy = mdata_slice.copy()
180+
# Check shapes after copy - these should be (n_obs_subset, n_obs_subset) if correctly copied
181+
assert mdata_copy.obsp["distances"].shape == (n_obs_subset, n_obs_subset), \
182+
f"Expected shape after copy: {(n_obs_subset, n_obs_subset)}, got: {mdata_copy.obsp['distances'].shape}"
183+
assert mdata_copy.obsp["connectivities"].shape == (n_obs_subset, n_obs_subset), \
184+
f"Expected shape after copy: {(n_obs_subset, n_obs_subset)}, got: {mdata_copy.obsp['connectivities'].shape}"
185+
186+
def test_varp_slicing(self, mdata_with_obsp):
187+
"""Test that varp matrices are correctly sliced when subsetting a MuData object."""
188+
orig_n_var = mdata_with_obsp.n_var
189+
190+
# Check initial shape
191+
assert mdata_with_obsp.varp["correlations"].shape == (orig_n_var, orig_n_var)
192+
193+
# Slice a subset of variables
194+
n_var_subset = 15
195+
all_var_names = mdata_with_obsp.var_names
196+
random_var_indices = np.random.choice(all_var_names, size=n_var_subset, replace=False)
197+
198+
# Create a slice view
199+
mdata_slice = mdata_with_obsp[:, random_var_indices]
200+
201+
# Check that the sliced varp matrix has correct shape in the view
202+
assert mdata_slice.varp["correlations"].shape == (n_var_subset, n_var_subset), \
203+
f"Expected shape in view: {(n_var_subset, orig_n_var)}, got: {mdata_slice.varp['correlations'].shape}"
204+
205+
# Copy the sliced MuData object
206+
mdata_copy = mdata_slice.copy()
207+
# Check shapes after copy
208+
assert mdata_copy.varp["correlations"].shape == (n_var_subset, n_var_subset), \
209+
f"Expected shape after copy: {(n_var_subset, n_var_subset)}, got: {mdata_copy.varp['correlations'].shape}"

0 commit comments

Comments
 (0)