Skip to content

Commit 8040ea1

Browse files
committed
fix deleting modalities
1 parent 75f884f commit 8040ea1

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

src/mudata/_core/mudata.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@ def _check_changed_attr_names(self, attr: str, columns: bool = False):
452452
attr_names_changed, attr_columns_changed = False, False
453453
if not hasattr(self, attrhash):
454454
attr_names_changed, attr_columns_changed = True, True
455+
elif len(self.mod) < len(getattr(self, attrhash)):
456+
attr_names_changed, attr_columns_changed = True, None
455457
else:
456458
for m, mod in self.mod.items():
457459
if m in getattr(self, attrhash):
@@ -607,7 +609,13 @@ def _update_attr(
607609
_attrhash = f"_{attr}hash"
608610
attr_changed = self._check_changed_attr_names(attr)
609611

610-
attr_duplicated = self._check_duplicated_attr_names(attr)
612+
if not any(attr_changed):
613+
# Nothing to update
614+
return
615+
616+
data_global = getattr(self, attr)
617+
618+
attr_duplicated = not data_global.index.is_unique or self._check_duplicated_attr_names(attr)
611619
attr_intersecting = self._check_intersecting_attr_names(attr)
612620

613621
if attr_duplicated:
@@ -619,12 +627,6 @@ def _update_attr(
619627
f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first."
620628
)
621629

622-
if not any(attr_changed):
623-
# Nothing to update
624-
return
625-
626-
data_global = getattr(self, attr)
627-
628630
# Generate unique colnames
629631
(rowcol,) = self._find_unique_colnames(attr, 1)
630632

@@ -646,7 +648,6 @@ def _update_attr(
646648
# Join modality .obs/.var tables
647649
#
648650
# Main case: no duplicates and no intersection if the axis is not shared
649-
#
650651
if not attr_duplicated:
651652
# Shared axis
652653
if axis == (1 - self._axis) or self._axis == -1:
@@ -665,7 +666,7 @@ def _update_attr(
665666
data_mod = _make_index_unique(data_mod, force=attr_intersecting)
666667
data_global = _make_index_unique(data_global, force=attr_intersecting)
667668
if data_global.shape[1] > 0:
668-
data_mod = data_global.join(data_mod, how="left", sort=False)
669+
data_mod = data_mod.join(data_global, how="left", sort=False)
669670

670671
if data_global.shape[0] > 0:
671672
# reorder new index to conform to the old index as much as possible
@@ -728,8 +729,11 @@ def _update_attr(
728729
)
729730
# after inserting a new modality with duplicates, but no duplicates before:
730731
# data_mod.index is not unique
731-
data_global = _make_index_unique(data_global, force=not data_mod.index.is_unique)
732-
data_mod = _make_index_unique(data_mod)
732+
# after deleting a modality with duplicates: data_global.index is not unique, but
733+
# data_mod.index is unique
734+
need_unique = data_mod.index.is_unique | data_global.index.is_unique
735+
data_global = _make_index_unique(data_global, force=need_unique)
736+
data_mod = _make_index_unique(data_mod, force=need_unique)
733737
data_mod = data_mod.join(data_global, how="left", sort=False)
734738

735739
# reorder new index to conform to the old index as much as possible
@@ -750,6 +754,10 @@ def _update_attr(
750754
> data_global.shape[0] # new modality added and concacenated
751755
)
752756

757+
if need_unique:
758+
data_mod = _restore_index(data_mod)
759+
data_global = _restore_index(data_global)
760+
753761
data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True)
754762
data_global.reset_index(level=list(range(1, data_global.index.nlevels)), inplace=True)
755763
data_mod.index.set_names(None, inplace=True)

tests/test_update.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,38 @@ def test_update_add_modality(self, modalities, axis):
161161
== old_oattrnames.append(getattr(modalities[modnames[i]], f"{oattr}_names"))
162162
).all()
163163

164+
def test_update_delete_modality(self, mdata, axis):
165+
modnames = list(mdata.mod.keys())
166+
attr = "obs" if axis == 0 else "var"
167+
oattr = "var" if axis == 0 else "obs"
168+
169+
fullbatch = getattr(mdata, attr)["batch"]
170+
fullobatch = getattr(mdata, oattr)["batch"]
171+
keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | (
172+
getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0
173+
)
174+
keptomask = (getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0) | (
175+
getattr(mdata, f"{oattr}map")[modnames[2]].reshape(-1) > 0
176+
)
177+
178+
del mdata.mod[modnames[0]]
179+
mdata.update()
180+
181+
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
182+
assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
183+
assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
184+
185+
fullbatch = getattr(mdata, attr)["batch"]
186+
fullobatch = getattr(mdata, oattr)["batch"]
187+
keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0
188+
keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0
189+
del mdata.mod[modnames[2]]
190+
mdata.update()
191+
192+
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
193+
assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
194+
assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
195+
164196
def test_update_intersecting(self, modalities, axis):
165197
"""
166198
Update should work when

0 commit comments

Comments
 (0)