Skip to content

Commit b29c475

Browse files
committed
more comprehensive tests and fixes for _update_attr()
1 parent 5ecd982 commit b29c475

File tree

2 files changed

+91
-54
lines changed

2 files changed

+91
-54
lines changed

src/mudata/_core/mudata.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -614,19 +614,11 @@ def _update_attr(
614614
return
615615

616616
data_global = getattr(self, attr)
617+
prev_index = data_global.index
617618

618619
attr_duplicated = not data_global.index.is_unique or self._check_duplicated_attr_names(attr)
619620
attr_intersecting = self._check_intersecting_attr_names(attr)
620621

621-
if attr_duplicated:
622-
warnings.warn(
623-
f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`."
624-
)
625-
if self._axis == -1:
626-
warnings.warn(
627-
f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first."
628-
)
629-
630622
# Generate unique colnames
631623
(rowcol,) = self._find_unique_colnames(attr, 1)
632624

@@ -708,16 +700,19 @@ def _update_attr(
708700
col.replace(np.nan, 0, inplace=True)
709701
col = col.astype(np.uint32)
710702
data_mod[colname] = col
711-
if mod in attrmap and (
712-
col.shape[0] != data_global.shape[0]
713-
and np.sum(attrmap[mod] > 0)
714-
== getattr(amod, attr).shape[0] # added/removed observations
715-
or col.shape[0] == data_global.shape[0]
716-
and np.array_equal(attrmap[mod], col) # reordered
717-
):
718-
data_mod.set_index(colname, append=True, inplace=True)
719-
data_global.set_index(attrmap[mod].reshape(-1), append=True, inplace=True)
720-
data_global.index.set_names(colname, level=-1, inplace=True)
703+
if mod in attrmap:
704+
modmap = attrmap[mod].reshape(-1)
705+
modmask = modmap > 0
706+
# only use unchanged modalities for ordering
707+
if (
708+
modmask.sum() == getattr(amod, attr).shape[0]
709+
and (
710+
getattr(amod, attr).index[modmap[modmask] - 1] == prev_index[modmask]
711+
).all()
712+
):
713+
data_mod.set_index(colname, append=True, inplace=True)
714+
data_global.set_index(attrmap[mod].reshape(-1), append=True, inplace=True)
715+
data_global.index.set_names(colname, level=-1, inplace=True)
721716

722717
if data_global.shape[0] > 0:
723718
if not data_global.index.is_unique:
@@ -749,9 +744,11 @@ def _update_attr(
749744
== data_global.shape[
750745
0
751746
] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0])
752-
or axis == self._axis
753-
and data_mod.shape[0]
754-
> data_global.shape[0] # new modality added and concacenated
747+
or (
748+
axis == self._axis
749+
and axis != -1
750+
and data_mod.shape[0] > data_global.shape[0]
751+
) # new modality added and concacenated
755752
)
756753

757754
if need_unique:
@@ -770,6 +767,15 @@ def _update_attr(
770767
mdict[m] = data_mod[colname].to_numpy()
771768
data_mod.drop(colname, axis=1, inplace=True)
772769

770+
if not data_mod.index.is_unique:
771+
warnings.warn(
772+
f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`."
773+
)
774+
if self._axis == -1:
775+
warnings.warn(
776+
f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first."
777+
)
778+
773779
setattr(
774780
self,
775781
"_" + attr,

tests/test_update.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,16 @@ def modalities(request, obs_n, obs_across, obs_mod):
3434

3535
if obs_mod:
3636
if obs_mod == "duplicated":
37-
for m in ["mod1", "mod2"]:
38-
# Index does not support mutable operations
39-
obs_names = mods[m].obs_names.values.copy()
40-
obs_names[1] = obs_names[0]
41-
mods[m].obs_names = obs_names
42-
43-
var_names = mods[m].var_names.values.copy()
44-
var_names[1] = var_names[0]
45-
mods[m].var_names = var_names
37+
obsnames2 = mods["mod2"].obs_names.to_numpy()
38+
obsnames3 = mods["mod3"].obs_names.to_numpy()
39+
varnames2 = mods["mod2"].var_names.to_numpy()
40+
varnames3 = mods["mod3"].var_names.to_numpy()
41+
obsnames2[0] = obsnames2[1] = obsnames3[1] = "testobs"
42+
varnames2[0] = varnames2[1] = varnames3[1] = "testvar"
43+
mods["mod2"].obs_names = obsnames2
44+
mods["mod3"].obs_names = obsnames3
45+
mods["mod2"].var_names = varnames2
46+
mods["mod3"].var_names = varnames3
4647

4748
return mods
4849

@@ -64,8 +65,8 @@ def mdata(modalities, axis):
6465
md.obs["batch"] = np.random.choice(["a", "b", "c"], size=md.shape[0], replace=True)
6566
md.var["batch"] = np.random.choice(["d", "e", "f"], size=md.shape[1], replace=True)
6667

67-
md.obsm["test_obsm"] = np.random.normal(size=(md.n_obs, 2))
68-
md.varm["test_varm"] = np.random.normal(size=(md.n_var, 2))
68+
md.obsm["test"] = np.random.normal(size=(md.n_obs, 2))
69+
md.varm["test"] = np.random.normal(size=(md.n_var, 2))
6970

7071
return md
7172

@@ -82,6 +83,14 @@ def new_update(self):
8283
yield
8384
set_options(pull_on_update=None)
8485

86+
@staticmethod
87+
def get_attrm_values(mdata, attr, key, names):
88+
attrm = getattr(mdata, f"{attr}m")
89+
index = getattr(mdata, f"{attr}_names")
90+
return np.concatenate(
91+
[np.atleast_1d(attrm[key][np.nonzero(index == name)[0]]) for name in names]
92+
)
93+
8594
def test_update_simple(self, mdata, axis):
8695
"""
8796
Update should work when
@@ -134,9 +143,25 @@ def test_update_add_modality(self, modalities, axis):
134143
old_attrnames = getattr(mdata, f"{attr}_names")
135144
old_oattrnames = getattr(mdata, f"{oattr}_names")
136145

146+
some_obs_names = mdata.obs_names[:2]
147+
mdata.obsm["test"] = np.random.normal(size=(mdata.n_obs, 1))
148+
true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
149+
137150
mdata.mod[modnames[i]] = modalities[modnames[i]]
138151
mdata.update()
139152

153+
test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
154+
if axis == 1:
155+
assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs
156+
assert np.all(np.isnan(mdata.obsm["test"][-modalities[modnames[i]].n_obs :]))
157+
assert np.all(~np.isnan(mdata.obsm["test"][: -modalities[modnames[i]].n_obs]))
158+
assert (
159+
test_obsm_values[~np.isnan(test_obsm_values)].reshape(-1)
160+
== true_obsm_values.reshape(-1)
161+
).all()
162+
else:
163+
assert (test_obsm_values == true_obsm_values).all()
164+
140165
attrnames = getattr(mdata, f"{attr}_names")
141166
oattrnames = getattr(mdata, f"{oattr}_names")
142167
assert (attrnames[: old_attrnames.size] == old_attrnames).all()
@@ -157,9 +182,13 @@ def test_update_delete_modality(self, mdata, axis):
157182
modnames = list(mdata.mod.keys())
158183
attr = "obs" if axis == 0 else "var"
159184
oattr = "var" if axis == 0 else "obs"
185+
attrm = f"{attr}m"
186+
oattrm = f"{oattr}m"
160187

161188
fullbatch = getattr(mdata, attr)["batch"]
162189
fullobatch = getattr(mdata, oattr)["batch"]
190+
fulltestm = getattr(mdata, attrm)["test"]
191+
fullotestm = getattr(mdata, oattrm)["test"]
163192
keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | (
164193
getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0
165194
)
@@ -171,19 +200,26 @@ def test_update_delete_modality(self, mdata, axis):
171200
mdata.update()
172201

173202
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
174-
assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
175203
assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
204+
assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
205+
assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all()
206+
assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all()
176207

177208
fullbatch = getattr(mdata, attr)["batch"]
178209
fullobatch = getattr(mdata, oattr)["batch"]
210+
fulltestm = getattr(mdata, attrm)["test"]
211+
fullotestm = getattr(mdata, oattrm)["test"]
179212
keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0
180213
keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0
214+
181215
del mdata.mod[modnames[2]]
182216
mdata.update()
183217

184218
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
185219
assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
186220
assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
221+
assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all()
222+
assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all()
187223

188224
def test_update_intersecting(self, modalities, axis):
189225
"""
@@ -216,22 +252,14 @@ def test_update_intersecting(self, modalities, axis):
216252
)
217253
).all()
218254

219-
# names along axis are intersected
255+
# names along axis are unioned
220256
axisnames = reduce(
221257
lambda x, y: x.union(y, sort=False),
222258
(getattr(mod, f"{attr}_names") for mod in modalities.values()),
223259
)
224260
assert mdata.shape[axis] == axisnames.shape[0]
225261
assert (getattr(mdata, f"{attr}_names") == axisnames).all()
226262

227-
# Variables are different across modalities
228-
for m, mod in modalities.items():
229-
# Columns are intact in individual modalities
230-
assert "mod" in mod.obs.columns
231-
assert all(mod.obs["mod"] == m)
232-
assert "mod" in mod.var.columns
233-
assert all(mod.var["mod"] == m)
234-
235263
def test_update_after_filter_obs_adata(self, mdata, axis):
236264
"""
237265
Check for muon issue #44.
@@ -242,6 +270,14 @@ def test_update_after_filter_obs_adata(self, mdata, axis):
242270
old_obsnames = mdata.obs_names
243271
old_varnames = mdata.var_names
244272

273+
filtermask = mdata["mod3"].obs["min_count"] < -2
274+
fullfiltermask = mdata.obsmap["mod3"].copy() > 0
275+
fullfiltermask[fullfiltermask] = filtermask
276+
keptmask = (mdata.obsmap["mod1"] > 0) | (mdata.obsmap["mod2"] > 0) | fullfiltermask
277+
278+
some_obs_names = mdata[keptmask, :].obs_names.values[:2]
279+
true_obsm_values = self.get_attrm_values(mdata[keptmask], "obs", "test", some_obs_names)
280+
245281
mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy()
246282
mdata.update()
247283
assert mdata.obs["batch"].isna().sum() == 0
@@ -251,28 +287,23 @@ def test_update_after_filter_obs_adata(self, mdata, axis):
251287
# check if the order is preserved
252288
assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all()
253289

290+
test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
291+
assert (true_obsm_values == test_obsm_values).all()
292+
254293
def test_update_after_obs_reordered(self, mdata):
255294
"""
256295
Update should work if obs are reordered.
257296
"""
258297
some_obs_names = mdata.obs_names.values[:2]
259298

260-
true_obsm_values = [
261-
mdata.obsm["test_obsm"][np.where(mdata.obs_names.values == name)[0][0]]
262-
for name in some_obs_names
263-
]
299+
true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
264300

265301
mdata.mod["mod1"] = mdata["mod1"][::-1].copy()
266302
mdata.update()
267303

268-
test_obsm_values = [
269-
mdata.obsm["test_obsm"][np.where(mdata.obs_names == name)[0][0]]
270-
for name in some_obs_names
271-
]
304+
test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
272305

273-
assert all(
274-
[all(true_obsm_values[i] == test_obsm_values[i]) for i in range(len(true_obsm_values))]
275-
)
306+
assert (true_obsm_values == test_obsm_values).all()
276307

277308

278309
@pytest.mark.usefixtures("filepath_h5mu")

0 commit comments

Comments
 (0)