Skip to content

Commit ca603a3

Browse files
committed
fix ordering of views of views
1 parent 2fd59ca commit ca603a3

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/mudata/_core/mudata.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def _init_common(self):
268268

269269
def _init_as_view(self, mudata_ref: "MuData", index):
270270
from anndata._core.index import _normalize_indices
271+
from anndata._core.views import _resolve_idxs
271272

272273
obsidx, varidx = _normalize_indices(index, mudata_ref.obs.index, mudata_ref.var.index)
273274

@@ -307,9 +308,13 @@ def _init_as_view(self, mudata_ref: "MuData", index):
307308
cvaridx = slice(None)
308309
if a.is_view:
309310
if isinstance(a, MuData):
310-
self.mod[m] = a._mudata_ref[cobsidx, cvaridx]
311+
self.mod[m] = a._mudata_ref[
312+
_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._mudata_ref)
313+
]
311314
else:
312-
self.mod[m] = a._adata_ref[cobsidx, cvaridx]
315+
self.mod[m] = a._adata_ref[
316+
_resolve_idxs((a._oidx, a._vidx), (cobsidx, cvaridx), a._adata_ref)
317+
]
313318
else:
314319
self.mod[m] = a[cobsidx, cvaridx]
315320

@@ -334,6 +339,8 @@ def _init_as_view(self, mudata_ref: "MuData", index):
334339
self.file = mudata_ref.file
335340
self._axis = mudata_ref._axis
336341
self._uns = mudata_ref._uns
342+
self._oidx = obsidx
343+
self._vidx = varidx
337344

338345
if mudata_ref.is_view:
339346
self._mudata_ref = mudata_ref._mudata_ref

tests/test_view_copy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ def test_view_view(self, mdata):
8888
assert mdata_view_view.is_view
8989
assert mdata_view_view.n_obs == view_view_n_obs
9090

91+
for modname, mod in mdata_view_view.mod.items():
92+
ref_obsmap = mdata_view.obsmap[modname][:view_view_n_obs]
93+
ref_obsmap = ref_obsmap[ref_obsmap > 0] - 1
94+
assert (mod.obs_names == mdata_view[modname].obs_names[ref_obsmap]).all()
95+
assert (mod.var_names == mdata_view[modname].var_names).all()
96+
97+
# test reordering
98+
mdata_view_view = mdata_view[:, :]
99+
for modname, mod in mdata_view_view.mod.items():
100+
assert (mod.obs_names == mdata_view[modname].obs_names).all()
101+
assert (mod.var_names == mdata_view[modname].var_names).all()
102+
91103
def test_backed_copy(self, mdata, filepath_h5mu, filepath2_h5mu):
92104
mdata.write(filepath_h5mu)
93105
mdata_b = mudata.read_h5mu(filepath_h5mu, backed="r")

0 commit comments

Comments
 (0)