@@ -202,18 +202,14 @@ def combat( # noqa: PLR0915
202202 sanitize_anndata (adata )
203203
204204 # construct a pandas DataFrame of the batch annotation (plus any covariates)
205- model = adata .obs [[key , * (covariates if covariates else [])]]
206- batch_info = model .groupby (key , observed = True ).indices . values ()
205+ model : pd . DataFrame = adata .obs [[key , * (covariates if covariates else [])]]
206+ batch_info = model .groupby (key , observed = True ).indices
207207 n_batch = len (batch_info )
208- n_batches = np .array ([len (v ) for v in batch_info ])
208+ n_batches = np .array ([len (v ) for v in batch_info . values () ])
209209
210210 # check for batches with fewer than 2 cells
211211 small_batches = [
212- batch
213- for batch , size in zip (
214- model .groupby (key , observed = True ).indices , n_batches , strict = True
215- )
216- if size < 2
212+ batch for batch , size in zip (batch_info , n_batches , strict = True ) if size < 2
217213 ]
218214 if small_batches :
219215 msg = (
@@ -236,7 +232,9 @@ def combat( # noqa: PLR0915
236232 la .inv (batch_design .T @ batch_design ) @ batch_design .T @ s_data .T
237233 ).values
238234 # first estimate for the multiplicative batch effect
239- delta_hat = [s_data .iloc [:, batch_idxs ].var (axis = 1 ) for batch_idxs in batch_info ]
235+ delta_hat = [
236+ s_data .iloc [:, batch_idxs ].var (axis = 1 ) for batch_idxs in batch_info .values ()
237+ ]
240238
241239 # empirically fix the prior hyperparameters
242240 gamma_bar = gamma_hat .mean (axis = 1 )
@@ -249,7 +247,7 @@ def combat( # noqa: PLR0915
249247 # gamma star and delta star will be our empirical bayes (EB) estimators
250248 # for the additive and multiplicative batch effect per batch and cell
251249 gamma_star , delta_star = [], []
252- for i , batch_idxs in enumerate (batch_info ):
250+ for i , batch_idxs in enumerate (batch_info . values () ):
253251 # temp stores our estimates for the batch effect parameters.
254252 # temp[0] is the additive batch effect
255253 # temp[1] is the multiplicative batch effect
@@ -273,7 +271,7 @@ def combat( # noqa: PLR0915
273271
274272 # we now apply the parametric adjustment to the standardized data from above
275273 # loop over all batches in the data
276- for j , batch_idxs in enumerate (batch_info ):
274+ for j , batch_idxs in enumerate (batch_info . values () ):
277275 # we basically subtract the additive batch effect, rescale by the ratio
278276 # of multiplicative batch effect to pooled variance and add the overall gene
279277 # wise mean
0 commit comments