Skip to content

Commit ad14c64

Browse files
committed
simplify
1 parent baef21d commit ad14c64

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

src/scanpy/preprocessing/_combat.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,14 @@ def combat( # noqa: PLR0915
202202
sanitize_anndata(adata)
203203

204204
# construct a pandas DataFrame of the batch annotation (plus any covariates)
205-
model = adata.obs[[key, *(covariates if covariates else [])]]
206-
batch_info = model.groupby(key, observed=True).indices.values()
205+
model: pd.DataFrame = adata.obs[[key, *(covariates if covariates else [])]]
206+
batch_info = model.groupby(key, observed=True).indices
207207
n_batch = len(batch_info)
208-
n_batches = np.array([len(v) for v in batch_info])
208+
n_batches = np.array([len(v) for v in batch_info.values()])
209209

210210
# check for batches with fewer than 2 cells
211211
small_batches = [
212-
batch
213-
for batch, size in zip(
214-
model.groupby(key, observed=True).indices, n_batches, strict=True
215-
)
216-
if size < 2
212+
batch for batch, size in zip(batch_info, n_batches, strict=True) if size < 2
217213
]
218214
if small_batches:
219215
msg = (
@@ -236,7 +232,9 @@ def combat( # noqa: PLR0915
236232
la.inv(batch_design.T @ batch_design) @ batch_design.T @ s_data.T
237233
).values
238234
# first estimate for the multiplicative batch effect
239-
delta_hat = [s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info]
235+
delta_hat = [
236+
s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info.values()
237+
]
240238

241239
# empirically fix the prior hyperparameters
242240
gamma_bar = gamma_hat.mean(axis=1)
@@ -249,7 +247,7 @@ def combat( # noqa: PLR0915
249247
# gamma star and delta star will be our empirical bayes (EB) estimators
250248
# for the additive and multiplicative batch effect per batch and cell
251249
gamma_star, delta_star = [], []
252-
for i, batch_idxs in enumerate(batch_info):
250+
for i, batch_idxs in enumerate(batch_info.values()):
253251
# temp stores our estimates for the batch effect parameters.
254252
# temp[0] is the additive batch effect
255253
# temp[1] is the multiplicative batch effect
@@ -273,7 +271,7 @@ def combat( # noqa: PLR0915
273271

274272
# we now apply the parametric adjustment to the standardized data from above
275273
# loop over all batches in the data
276-
for j, batch_idxs in enumerate(batch_info):
274+
for j, batch_idxs in enumerate(batch_info.values()):
277275
# we basically subtract the additive batch effect, rescale by the ratio
278276
# of multiplicative batch effect to pooled variance and add the overall gene
279277
# wise mean

0 commit comments

Comments
 (0)