Skip to content

Commit d8610cb

Browse files
committed
add tests for dtype
1 parent 826c8ce commit d8610cb

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

tests/test_update.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def modalities(request, obs_n, obs_across, obs_mod):
4444
mods["mod3"].obs_names = obsnames3
4545
mods["mod2"].var_names = varnames2
4646
mods["mod3"].var_names = varnames3
47-
elif obs_mod == "extreme_duplicated": # integer overflow: https://github.com/scverse/mudata/issues/107
47+
elif (
48+
obs_mod == "extreme_duplicated"
49+
): # integer overflow: https://github.com/scverse/mudata/issues/107
4850
obsnames2 = mods["mod2"].obs_names.to_numpy()
4951
varnames2 = mods["mod2"].var_names.to_numpy()
5052
obsnames2[:-1] = obsnames2[0] = "testobs"
@@ -107,6 +109,10 @@ def test_update_simple(self, mdata, axis):
107109
attr = "obs" if axis == 0 else "var"
108110
oattr = "var" if axis == 0 else "obs"
109111

112+
for mod in mdata.mod.keys():
113+
assert mdata.obsmap[mod].dtype.kind == "u"
114+
assert mdata.varmap[mod].dtype.kind == "u"
115+
110116
# names along non-axis are concatenated
111117
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
112118
assert (
@@ -157,6 +163,10 @@ def test_update_add_modality(self, modalities, axis):
157163
mdata.mod[modnames[i]] = modalities[modnames[i]]
158164
mdata.update()
159165

166+
for mod in mdata.mod.keys():
167+
assert mdata.obsmap[mod].dtype.kind == "u"
168+
assert mdata.varmap[mod].dtype.kind == "u"
169+
160170
test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
161171
if axis == 1:
162172
assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs
@@ -206,6 +216,10 @@ def test_update_delete_modality(self, mdata, axis):
206216
del mdata.mod[modnames[0]]
207217
mdata.update()
208218

219+
for mod in mdata.mod.keys():
220+
assert mdata.obsmap[mod].dtype.kind == "u"
221+
assert mdata.varmap[mod].dtype.kind == "u"
222+
209223
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())
210224
assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all()
211225
assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all()
@@ -249,6 +263,10 @@ def test_update_intersecting(self, modalities, axis):
249263

250264
mdata = MuData(modalities, axis=axis)
251265

266+
for mod in mdata.mod.keys():
267+
assert mdata.obsmap[mod].dtype.kind == "u"
268+
assert mdata.varmap[mod].dtype.kind == "u"
269+
252270
# names along non-axis are concatenated
253271
assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values())
254272
assert (
@@ -287,6 +305,11 @@ def test_update_after_filter_obs_adata(self, mdata, axis):
287305

288306
mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy()
289307
mdata.update()
308+
309+
for mod in mdata.mod.keys():
310+
assert mdata.obsmap[mod].dtype.kind == "u"
311+
assert mdata.varmap[mod].dtype.kind == "u"
312+
290313
assert mdata.obs["batch"].isna().sum() == 0
291314

292315
assert (mdata.var_names == old_varnames).all()
@@ -308,6 +331,10 @@ def test_update_after_obs_reordered(self, mdata):
308331
mdata.mod["mod1"] = mdata["mod1"][::-1].copy()
309332
mdata.update()
310333

334+
for mod in mdata.mod.keys():
335+
assert mdata.obsmap[mod].dtype.kind == "u"
336+
assert mdata.varmap[mod].dtype.kind == "u"
337+
311338
test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names)
312339

313340
assert (true_obsm_values == test_obsm_values).all()

0 commit comments

Comments
 (0)