@@ -34,15 +34,16 @@ def modalities(request, obs_n, obs_across, obs_mod):
3434
3535 if obs_mod :
3636 if obs_mod == "duplicated" :
37- for m in ["mod1" , "mod2" ]:
38- # Index does not support mutable operations
39- obs_names = mods [m ].obs_names .values .copy ()
40- obs_names [1 ] = obs_names [0 ]
41- mods [m ].obs_names = obs_names
42-
43- var_names = mods [m ].var_names .values .copy ()
44- var_names [1 ] = var_names [0 ]
45- mods [m ].var_names = var_names
37+ obsnames2 = mods ["mod2" ].obs_names .to_numpy ()
38+ obsnames3 = mods ["mod3" ].obs_names .to_numpy ()
39+ varnames2 = mods ["mod2" ].var_names .to_numpy ()
40+ varnames3 = mods ["mod3" ].var_names .to_numpy ()
41+ obsnames2 [0 ] = obsnames2 [1 ] = obsnames3 [1 ] = "testobs"
42+ varnames2 [0 ] = varnames2 [1 ] = varnames3 [1 ] = "testvar"
43+ mods ["mod2" ].obs_names = obsnames2
44+ mods ["mod3" ].obs_names = obsnames3
45+ mods ["mod2" ].var_names = varnames2
46+ mods ["mod3" ].var_names = varnames3
4647
4748 return mods
4849
@@ -64,8 +65,8 @@ def mdata(modalities, axis):
6465 md .obs ["batch" ] = np .random .choice (["a" , "b" , "c" ], size = md .shape [0 ], replace = True )
6566 md .var ["batch" ] = np .random .choice (["d" , "e" , "f" ], size = md .shape [1 ], replace = True )
6667
67- md .obsm ["test_obsm " ] = np .random .normal (size = (md .n_obs , 2 ))
68- md .varm ["test_varm " ] = np .random .normal (size = (md .n_var , 2 ))
68+ md .obsm ["test " ] = np .random .normal (size = (md .n_obs , 2 ))
69+ md .varm ["test " ] = np .random .normal (size = (md .n_var , 2 ))
6970
7071 return md
7172
@@ -82,6 +83,14 @@ def new_update(self):
8283 yield
8384 set_options (pull_on_update = None )
8485
86+ @staticmethod
87+ def get_attrm_values (mdata , attr , key , names ):
88+ attrm = getattr (mdata , f"{ attr } m" )
89+ index = getattr (mdata , f"{ attr } _names" )
90+ return np .concatenate (
91+ [np .atleast_1d (attrm [key ][np .nonzero (index == name )[0 ]]) for name in names ]
92+ )
93+
8594 def test_update_simple (self , mdata , axis ):
8695 """
8796 Update should work when
@@ -134,9 +143,25 @@ def test_update_add_modality(self, modalities, axis):
134143 old_attrnames = getattr (mdata , f"{ attr } _names" )
135144 old_oattrnames = getattr (mdata , f"{ oattr } _names" )
136145
146+ some_obs_names = mdata .obs_names [:2 ]
147+ mdata .obsm ["test" ] = np .random .normal (size = (mdata .n_obs , 1 ))
148+ true_obsm_values = self .get_attrm_values (mdata , "obs" , "test" , some_obs_names )
149+
137150 mdata .mod [modnames [i ]] = modalities [modnames [i ]]
138151 mdata .update ()
139152
153+ test_obsm_values = self .get_attrm_values (mdata , "obs" , "test" , some_obs_names )
154+ if axis == 1 :
155+ assert np .isnan (mdata .obsm ["test" ]).sum () == modalities [modnames [i ]].n_obs
156+ assert np .all (np .isnan (mdata .obsm ["test" ][- modalities [modnames [i ]].n_obs :]))
157+ assert np .all (~ np .isnan (mdata .obsm ["test" ][: - modalities [modnames [i ]].n_obs ]))
158+ assert (
159+ test_obsm_values [~ np .isnan (test_obsm_values )].reshape (- 1 )
160+ == true_obsm_values .reshape (- 1 )
161+ ).all ()
162+ else :
163+ assert (test_obsm_values == true_obsm_values ).all ()
164+
140165 attrnames = getattr (mdata , f"{ attr } _names" )
141166 oattrnames = getattr (mdata , f"{ oattr } _names" )
142167 assert (attrnames [: old_attrnames .size ] == old_attrnames ).all ()
@@ -157,9 +182,13 @@ def test_update_delete_modality(self, mdata, axis):
157182 modnames = list (mdata .mod .keys ())
158183 attr = "obs" if axis == 0 else "var"
159184 oattr = "var" if axis == 0 else "obs"
185+ attrm = f"{ attr } m"
186+ oattrm = f"{ oattr } m"
160187
161188 fullbatch = getattr (mdata , attr )["batch" ]
162189 fullobatch = getattr (mdata , oattr )["batch" ]
190+ fulltestm = getattr (mdata , attrm )["test" ]
191+ fullotestm = getattr (mdata , oattrm )["test" ]
163192 keptmask = (getattr (mdata , f"{ attr } map" )[modnames [1 ]].reshape (- 1 ) > 0 ) | (
164193 getattr (mdata , f"{ attr } map" )[modnames [2 ]].reshape (- 1 ) > 0
165194 )
@@ -171,19 +200,26 @@ def test_update_delete_modality(self, mdata, axis):
171200 mdata .update ()
172201
173202 assert mdata .shape [1 - axis ] == sum (mod .shape [1 - axis ] for mod in mdata .mod .values ())
174- assert (getattr (mdata , oattr )["batch" ] == fullobatch [keptomask ]).all ()
175203 assert (getattr (mdata , attr )["batch" ] == fullbatch [keptmask ]).all ()
204+ assert (getattr (mdata , oattr )["batch" ] == fullobatch [keptomask ]).all ()
205+ assert (getattr (mdata , attrm )["test" ] == fulltestm [keptmask , :]).all ()
206+ assert (getattr (mdata , oattrm )["test" ] == fullotestm [keptomask , :]).all ()
176207
177208 fullbatch = getattr (mdata , attr )["batch" ]
178209 fullobatch = getattr (mdata , oattr )["batch" ]
210+ fulltestm = getattr (mdata , attrm )["test" ]
211+ fullotestm = getattr (mdata , oattrm )["test" ]
179212 keptmask = getattr (mdata , f"{ attr } map" )[modnames [1 ]].reshape (- 1 ) > 0
180213 keptomask = getattr (mdata , f"{ oattr } map" )[modnames [1 ]].reshape (- 1 ) > 0
214+
181215 del mdata .mod [modnames [2 ]]
182216 mdata .update ()
183217
184218 assert mdata .shape [1 - axis ] == sum (mod .shape [1 - axis ] for mod in mdata .mod .values ())
185219 assert (getattr (mdata , oattr )["batch" ] == fullobatch [keptomask ]).all ()
186220 assert (getattr (mdata , attr )["batch" ] == fullbatch [keptmask ]).all ()
221+ assert (getattr (mdata , attrm )["test" ] == fulltestm [keptmask , :]).all ()
222+ assert (getattr (mdata , oattrm )["test" ] == fullotestm [keptomask , :]).all ()
187223
188224 def test_update_intersecting (self , modalities , axis ):
189225 """
@@ -216,22 +252,14 @@ def test_update_intersecting(self, modalities, axis):
216252 )
217253 ).all ()
218254
219- # names along axis are intersected
255+ # names along axis are unioned
220256 axisnames = reduce (
221257 lambda x , y : x .union (y , sort = False ),
222258 (getattr (mod , f"{ attr } _names" ) for mod in modalities .values ()),
223259 )
224260 assert mdata .shape [axis ] == axisnames .shape [0 ]
225261 assert (getattr (mdata , f"{ attr } _names" ) == axisnames ).all ()
226262
227- # Variables are different across modalities
228- for m , mod in modalities .items ():
229- # Columns are intact in individual modalities
230- assert "mod" in mod .obs .columns
231- assert all (mod .obs ["mod" ] == m )
232- assert "mod" in mod .var .columns
233- assert all (mod .var ["mod" ] == m )
234-
235263 def test_update_after_filter_obs_adata (self , mdata , axis ):
236264 """
237265 Check for muon issue #44.
@@ -242,6 +270,14 @@ def test_update_after_filter_obs_adata(self, mdata, axis):
242270 old_obsnames = mdata .obs_names
243271 old_varnames = mdata .var_names
244272
273+ filtermask = mdata ["mod3" ].obs ["min_count" ] < - 2
274+ fullfiltermask = mdata .obsmap ["mod3" ].copy () > 0
275+ fullfiltermask [fullfiltermask ] = filtermask
276+ keptmask = (mdata .obsmap ["mod1" ] > 0 ) | (mdata .obsmap ["mod2" ] > 0 ) | fullfiltermask
277+
278+ some_obs_names = mdata [keptmask , :].obs_names .values [:2 ]
279+ true_obsm_values = self .get_attrm_values (mdata [keptmask ], "obs" , "test" , some_obs_names )
280+
245281 mdata .mod ["mod3" ] = mdata ["mod3" ][mdata ["mod3" ].obs ["min_count" ] < - 2 ].copy ()
246282 mdata .update ()
247283 assert mdata .obs ["batch" ].isna ().sum () == 0
@@ -251,28 +287,23 @@ def test_update_after_filter_obs_adata(self, mdata, axis):
251287 # check if the order is preserved
252288 assert (mdata .obs_names == old_obsnames [old_obsnames .isin (mdata .obs_names )]).all ()
253289
290+ test_obsm_values = self .get_attrm_values (mdata , "obs" , "test" , some_obs_names )
291+ assert (true_obsm_values == test_obsm_values ).all ()
292+
254293 def test_update_after_obs_reordered (self , mdata ):
255294 """
256295 Update should work if obs are reordered.
257296 """
258297 some_obs_names = mdata .obs_names .values [:2 ]
259298
260- true_obsm_values = [
261- mdata .obsm ["test_obsm" ][np .where (mdata .obs_names .values == name )[0 ][0 ]]
262- for name in some_obs_names
263- ]
299+ true_obsm_values = self .get_attrm_values (mdata , "obs" , "test" , some_obs_names )
264300
265301 mdata .mod ["mod1" ] = mdata ["mod1" ][::- 1 ].copy ()
266302 mdata .update ()
267303
268- test_obsm_values = [
269- mdata .obsm ["test_obsm" ][np .where (mdata .obs_names == name )[0 ][0 ]]
270- for name in some_obs_names
271- ]
304+ test_obsm_values = self .get_attrm_values (mdata , "obs" , "test" , some_obs_names )
272305
273- assert all (
274- [all (true_obsm_values [i ] == test_obsm_values [i ]) for i in range (len (true_obsm_values ))]
275- )
306+ assert (true_obsm_values == test_obsm_values ).all ()
276307
277308
278309@pytest .mark .usefixtures ("filepath_h5mu" )
0 commit comments