@@ -355,7 +355,7 @@ def bootstrap(
355355 def create_contrasts (
356356 adata : AnnData ,
357357 groupby : str ,
358- selected_group : str ,
358+ selected_group : str | Sequence [ str ] ,
359359 * ,
360360 groups : Sequence [str ] | None = None ,
361361 split_by : str | Sequence [str ] | None = None ,
@@ -382,7 +382,10 @@ def create_contrasts(
382382 Column in ``adata.obs`` whose levels are compared against
383383 ``selected_group``
384384 selected_group
385- The reference (control) value in the ``groupby`` column
385+ The reference (control) value(s) in the ``groupby`` column.
386+ When a sequence is passed, each target is compared against
387+ every reference, producing one row per (target, reference)
388+ combination.
386389 groups
387390 Specific groups to include. If None, all non-reference groups
388391 are included.
@@ -405,6 +408,12 @@ def create_contrasts(
405408 ... adata, groupby="target_gene", selected_group="Non_target"
406409 ... )
407410
411+ >>> # Multiple references
412+ >>> contrasts = Distance.create_contrasts(
413+ ... adata, groupby="target_gene",
414+ ... selected_group=["Non_target", "Scramble"],
415+ ... )
416+
408417 >>> # Stratified by celltype
409418 >>> contrasts = Distance.create_contrasts(
410419 ... adata, groupby="target_gene", selected_group="Non_target",
@@ -425,10 +434,16 @@ def create_contrasts(
425434 """
426435 import pandas as pd
427436
428- if selected_group not in adata .obs [groupby ].values :
429- raise ValueError (
430- f"Reference '{ selected_group } ' not found in column '{ groupby } '"
431- )
437+ # Normalize to list
438+ if isinstance (selected_group , str ):
439+ selected_groups = [selected_group ]
440+ else :
441+ selected_groups = list (selected_group )
442+
443+ obs_values = set (adata .obs [groupby ].values )
444+ for sg in selected_groups :
445+ if sg not in obs_values :
446+ raise ValueError (f"Reference '{ sg } ' not found in column '{ groupby } '" )
432447
433448 if split_by is None :
434449 split_cols : list [str ] = []
@@ -438,41 +453,46 @@ def create_contrasts(
438453 split_cols = list (split_by )
439454
440455 allowed_groups = set (groups ) if groups is not None else None
456+ selected_set = set (selected_groups )
441457 all_cols = [groupby , * split_cols ]
442458
443- if split_cols :
444- # Get all existing (groupby, *split) combinations in one pass
445- existing = adata .obs [all_cols ].drop_duplicates ().reset_index (drop = True )
459+ parts : list [pd .DataFrame ] = []
460+ for sg in selected_groups :
461+ if split_cols :
462+ existing = adata .obs [all_cols ].drop_duplicates ().reset_index (drop = True )
446463
447- # Find which splits have the reference
448- ref_rows = existing [existing [groupby ] == selected_group ]
449- if len (ref_rows ) == 0 :
450- df = pd .DataFrame (columns = all_cols )
451- else :
452- # Inner join: keep only targets in splits that have reference
464+ ref_rows = existing [existing [groupby ] == sg ]
465+ if len (ref_rows ) == 0 :
466+ continue
453467 ref_splits = ref_rows [split_cols ]
454- targets = existing [existing [groupby ] != selected_group ]
468+ targets = existing [~ existing [groupby ]. isin ( selected_set ) ]
455469 if allowed_groups is not None :
456470 targets = targets [targets [groupby ].isin (allowed_groups )]
457- df = targets .merge (ref_splits , on = split_cols , how = "inner" )
458- df = (
459- df [all_cols ]
460- .sort_values ([* split_cols , groupby ])
461- .reset_index (drop = True )
462- )
463- else :
464- # No split — just all non-reference levels of groupby
465- targets = adata .obs [groupby ].unique ()
466- targets = [
467- t
468- for t in targets
469- if t != selected_group
470- and (allowed_groups is None or t in allowed_groups )
471- ]
472- df = pd .DataFrame ({groupby : targets })
473-
474- # Insert reference column right after groupby
475- df .insert (1 , "reference" , selected_group )
471+ matched = targets .merge (ref_splits , on = split_cols , how = "inner" )
472+ if len (matched ) == 0 :
473+ continue
474+ matched = matched [all_cols ].copy ()
475+ else :
476+ target_vals = [
477+ t
478+ for t in adata .obs [groupby ].unique ()
479+ if t not in selected_set
480+ and (allowed_groups is None or t in allowed_groups )
481+ ]
482+ if not target_vals :
483+ continue
484+ matched = pd .DataFrame ({groupby : target_vals })
485+
486+ matched .insert (1 , "reference" , sg )
487+ parts .append (matched )
488+
489+ if not parts :
490+ cols = [groupby , "reference" , * split_cols ]
491+ return pd .DataFrame (columns = cols )
492+
493+ df = pd .concat (parts , ignore_index = True )
494+ sort_cols = ["reference" , * split_cols , groupby ]
495+ df = df .sort_values (sort_cols ).reset_index (drop = True )
476496
477497 return df
478498
0 commit comments