@@ -368,84 +368,208 @@ def test_contrast_distances_matches_compute_distance(
368368
369369 d = EDistanceMetric (obsm_key = "X_pca" )
370370
371- contrasts = {
372- "drugA_vs_ctrl_T" : (
373- {"treatment" : "drugA" , "celltype" : "T" },
374- {"treatment" : "ctrl" , "celltype" : "T" },
375- ),
376- "drugA_vs_ctrl_B" : (
377- {"treatment" : "drugA" , "celltype" : "B" },
378- {"treatment" : "ctrl" , "celltype" : "B" },
379- ),
380- }
371+ contrasts = Distance .create_contrasts (
372+ contrast_adata ,
373+ groupby = "treatment" ,
374+ selected_group = "ctrl" ,
375+ split_by = "celltype" ,
376+ )
381377
382378 result = d .contrast_distances (contrast_adata , contrasts = contrasts )
383379
384- assert isinstance (result , pd .Series )
385- assert result . name == "edistance"
386- assert len (result ) == 2
380+ assert isinstance (result , pd .DataFrame )
381+ assert "edistance" in result . columns
382+ assert len (result ) == len ( contrasts )
387383
388384 # Verify each contrast against compute_distance
389- for name , ( cond_a , cond_b ) in contrasts . items ():
390- mask_a = np . ones ( len ( contrast_adata ), dtype = bool )
391- mask_b = np . ones ( len ( contrast_adata ), dtype = bool )
392- for col , val in cond_a . items ():
393- mask_a &= contrast_adata .obs [col ].values == val
394- for col , val in cond_b . items ():
395- mask_b &= contrast_adata . obs [ col ]. values == val
396-
397- X = contrast_adata .obsm ["X_pca" ][mask_a ]
398- Y = contrast_adata .obsm ["X_pca" ][mask_b ]
385+ for _ , row in result . iterrows ():
386+ mask_target = ( contrast_adata . obs [ "treatment" ]. values == row [ "treatment" ]) & (
387+ contrast_adata . obs [ "celltype" ]. values == row [ "celltype" ]
388+ )
389+ mask_ref = ( contrast_adata .obs ["treatment" ].values == row [ "reference" ]) & (
390+ contrast_adata . obs [ "celltype" ]. values == row [ "celltype" ]
391+ )
392+
393+ X = contrast_adata .obsm ["X_pca" ][mask_target ]
394+ Y = contrast_adata .obsm ["X_pca" ][mask_ref ]
399395 expected = d .compute_distance (X , Y )
400396
401397 np .testing .assert_allclose (
402- result [ name ],
398+ row [ "edistance" ],
403399 expected ,
404400 atol = 1e-6 ,
405- err_msg = f"Contrast { name } mismatch" ,
401+ err_msg = f"Contrast { row ['treatment' ]} vs { row ['reference' ]} "
402+ f"in { row ['celltype' ]} mismatch" ,
406403 )
407404
408405
409406def test_contrast_distances_shared_condition (contrast_adata : AnnData ) -> None :
410407 """Test that contrasts sharing a condition (e.g. same control) work."""
408+ distance = Distance (metric = "edistance" )
409+
410+ contrasts = Distance .create_contrasts (
411+ contrast_adata ,
412+ groupby = "treatment" ,
413+ selected_group = "ctrl" ,
414+ split_by = "celltype" ,
415+ )
416+
417+ result = distance .contrast_distances (contrast_adata , contrasts = contrasts )
418+
419+ assert isinstance (result , pd .DataFrame )
420+ assert "edistance" in result .columns
421+ # All distances should be finite
422+ assert np .all (np .isfinite (result ["edistance" ].values ))
423+
424+
425+ def test_contrast_distances_self_distance_zero (contrast_adata : AnnData ) -> None :
426+ """Test that self-distance (same group vs itself) is zero."""
427+ distance = Distance (metric = "edistance" )
428+
429+ # Manually create a contrast where target == reference
430+ contrasts = pd .DataFrame (
431+ {
432+ "treatment" : ["ctrl" ],
433+ "reference" : ["ctrl" ],
434+ "celltype" : ["T" ],
435+ }
436+ )
437+
438+ result = distance .contrast_distances (contrast_adata , contrasts = contrasts )
439+ assert result ["edistance" ].iloc [0 ] == pytest .approx (0.0 , abs = 1e-7 )
440+
441+
442+ def test_contrast_distances_no_split (contrast_adata : AnnData ) -> None :
443+ """Test contrast_distances without split_by columns."""
444+ distance = Distance (metric = "edistance" )
445+
446+ contrasts = Distance .create_contrasts (
447+ contrast_adata ,
448+ groupby = "treatment" ,
449+ selected_group = "ctrl" ,
450+ )
451+
452+ result = distance .contrast_distances (contrast_adata , contrasts = contrasts )
453+
454+ assert isinstance (result , pd .DataFrame )
455+ assert "edistance" in result .columns
456+ assert len (result ) == 1 # only drugA vs ctrl
457+ assert np .all (np .isfinite (result ["edistance" ].values ))
458+
459+
460+ def test_contrast_distances_filtered (contrast_adata : AnnData ) -> None :
461+ """Test that filtering a contrasts DataFrame before computing works."""
411462 from rapids_singlecell .pertpy_gpu ._metrics ._edistance import EDistanceMetric
412463
413464 d = EDistanceMetric (obsm_key = "X_pca" )
465+ distance = Distance (metric = "edistance" )
414466
415- # Both contrasts share the ctrl_T condition
416- contrasts = {
417- "drugA_vs_ctrl_T" : (
418- {"treatment" : "drugA" , "celltype" : "T" },
419- {"treatment" : "ctrl" , "celltype" : "T" },
420- ),
421- "ctrl_T_self" : (
422- {"treatment" : "ctrl" , "celltype" : "T" },
423- {"treatment" : "ctrl" , "celltype" : "T" },
424- ),
425- }
467+ # Create full contrasts, then drop one celltype
468+ contrasts = Distance .create_contrasts (
469+ contrast_adata ,
470+ groupby = "treatment" ,
471+ selected_group = "ctrl" ,
472+ split_by = "celltype" ,
473+ )
474+ assert len (contrasts ) == 2 # drugA-T, drugA-B
426475
427- result = d .contrast_distances (contrast_adata , contrasts = contrasts )
476+ # Keep only celltype == "T"
477+ filtered = contrasts [contrasts ["celltype" ] == "T" ].reset_index (drop = True )
478+ assert len (filtered ) == 1
479+
480+ result = distance .contrast_distances (contrast_adata , contrasts = filtered )
481+
482+ assert isinstance (result , pd .DataFrame )
483+ assert len (result ) == 1
484+ assert result ["celltype" ].iloc [0 ] == "T"
485+
486+ # Verify the distance matches compute_distance
487+ mask_target = (contrast_adata .obs ["treatment" ].values == "drugA" ) & (
488+ contrast_adata .obs ["celltype" ].values == "T"
489+ )
490+ mask_ref = (contrast_adata .obs ["treatment" ].values == "ctrl" ) & (
491+ contrast_adata .obs ["celltype" ].values == "T"
492+ )
493+ expected = d .compute_distance (
494+ contrast_adata .obsm ["X_pca" ][mask_target ],
495+ contrast_adata .obsm ["X_pca" ][mask_ref ],
496+ )
497+ np .testing .assert_allclose (result ["edistance" ].iloc [0 ], expected , atol = 1e-6 )
428498
429- # Self-distance should be 0
430- assert result ["ctrl_T_self" ] == pytest .approx (0.0 , abs = 1e-7 )
499+ # Also verify it differs from the full (unfiltered) result
500+ full_result = distance .contrast_distances (contrast_adata , contrasts = contrasts )
501+ assert len (full_result ) == 2
431502
503+ # The T-cell row should match between filtered and full
504+ full_t = full_result [full_result ["celltype" ] == "T" ]["edistance" ].iloc [0 ]
505+ np .testing .assert_allclose (result ["edistance" ].iloc [0 ], full_t , atol = 1e-10 )
506+
507+
508+ def test_contrast_distances_two_split_by () -> None :
509+ """Test contrast_distances with two split_by columns."""
510+ rng = np .random .default_rng (42 )
511+ n = 10
512+ cpu_emb = rng .normal (size = (n * 6 , 5 )).astype (np .float32 )
513+ obs = pd .DataFrame (
514+ {
515+ "treatment" : pd .Categorical (
516+ ["ctrl" ] * n
517+ + ["drugA" ] * n
518+ + ["ctrl" ] * n
519+ + ["drugA" ] * n
520+ + ["ctrl" ] * n
521+ + ["drugA" ] * n
522+ ),
523+ "celltype" : pd .Categorical (["T" ] * n * 2 + ["B" ] * n * 2 + ["T" ] * n * 2 ),
524+ "batch" : pd .Categorical (["b1" ] * n * 4 + ["b2" ] * n * 2 ),
525+ }
526+ )
527+ adata = AnnData (cpu_emb .copy (), obs = obs )
528+ adata .obsm ["X_pca" ] = cp .asarray (cpu_emb , dtype = cp .float32 )
432529
433- def test_contrast_distances_empty_condition (contrast_adata : AnnData ) -> None :
434- """Test that a condition matching no cells is handled."""
435530 from rapids_singlecell .pertpy_gpu ._metrics ._edistance import EDistanceMetric
436531
437532 d = EDistanceMetric (obsm_key = "X_pca" )
533+ distance = Distance (metric = "edistance" )
438534
439- contrasts = {
440- "nonexistent" : (
441- { "treatment" : "drugX" , "celltype" : "T" } ,
442- { "treatment" : " ctrl", "celltype" : "T" } ,
443- ) ,
444- }
535+ contrasts = Distance . create_contrasts (
536+ adata ,
537+ groupby = "treatment" ,
538+ selected_group = " ctrl" ,
539+ split_by = [ "celltype" , "batch" ] ,
540+ )
445541
446- # Should not crash — group will have 0 cells
447- result = d .contrast_distances (contrast_adata , contrasts = contrasts )
448- assert isinstance (result , pd .Series )
542+ assert "celltype" in contrasts .columns
543+ assert "batch" in contrasts .columns
544+
545+ result = distance .contrast_distances (adata , contrasts = contrasts )
546+
547+ assert isinstance (result , pd .DataFrame )
548+ assert "edistance" in result .columns
549+
550+ # Verify each contrast against compute_distance
551+ for _ , row in result .iterrows ():
552+ mask_target = (
553+ (adata .obs ["treatment" ].values == row ["treatment" ])
554+ & (adata .obs ["celltype" ].values == row ["celltype" ])
555+ & (adata .obs ["batch" ].values == row ["batch" ])
556+ )
557+ mask_ref = (
558+ (adata .obs ["treatment" ].values == row ["reference" ])
559+ & (adata .obs ["celltype" ].values == row ["celltype" ])
560+ & (adata .obs ["batch" ].values == row ["batch" ])
561+ )
562+
563+ X = adata .obsm ["X_pca" ][mask_target ]
564+ Y = adata .obsm ["X_pca" ][mask_ref ]
565+ expected = d .compute_distance (X , Y )
566+
567+ np .testing .assert_allclose (
568+ row ["edistance" ],
569+ expected ,
570+ rtol = 1e-5 ,
571+ atol = 1e-5 ,
572+ )
449573
450574
451575# ============================================================================
0 commit comments