Skip to content

Commit 22dcc45

Browse files
author
Mohsen Naghipourfar
committed
Setting obs_names for corrected adata #4
1 parent b92bd4f commit 22dcc45

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

scgen/models/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def batch_removal(network, adata, batch_key="batch", cell_label_key="cell_type")
302302
corrected = anndata.AnnData(network.reconstruct(all_shared_ann.X, use_data=True))
303303
corrected.obs = all_shared_ann.obs.copy(deep=True)
304304
corrected.var_names = adata.var_names.tolist()
305+
corrected.obs_names = adata.obs_names.tolist()
305306
return corrected
306307
else:
307308
all_not_shared_ann = anndata.AnnData.concatenate(*not_shared_ct, batch_key="concat_batch")
@@ -311,6 +312,7 @@ def batch_removal(network, adata, batch_key="batch", cell_label_key="cell_type")
311312
corrected = anndata.AnnData(network.reconstruct(all_corrected_data.X, use_data=True), )
312313
corrected.obs = pd.concat([all_shared_ann.obs, all_not_shared_ann.obs])
313314
corrected.var_names = adata.var_names.tolist()
315+
corrected.obs_names = adata.obs_names.tolist()
314316
return corrected
315317

316318

tests/test_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44

55
def test_batch_removal():
6-
train = sc.read("./tests/data/pancreas.h5ad", backup_url="https://goo.gl/V29FNk")
6+
train = sc.read("./data/pancreas.h5ad", backup_url="https://goo.gl/V29FNk")
77
train.obs["cell_type"] = train.obs["celltype"].tolist()
88
network = scgen.VAEArith(x_dimension=train.shape[1], model_path="./models/batch")
9-
network.train(train_data=train, n_epochs=0)
9+
network.train(train_data=train, n_epochs=1, verbose=1)
1010
corrected_adata = scgen.batch_removal(network, train)
11+
print(corrected_adata.obs)
1112
network.sess.close()
1213

1314

15+
test_batch_removal()

0 commit comments

Comments
 (0)