@@ -711,14 +711,30 @@ def test_distance_default_obsm_key() -> None:
711711# ============================================================================
712712
713713
714- def test_float64_matches_float32_results (small_adata : AnnData ) -> None :
714+ def test_float64_matches_float32_results () -> None :
715715 """Test that float64 and float32 produce similar results (within float32 precision)."""
716- adata_f64 = small_adata .copy ()
717- adata_f64 .obsm ["X_pca" ] = small_adata .obsm ["X_pca" ].astype (cp .float64 )
716+ # Use small dataset to avoid GPU resource exhaustion with float64
717+ rng = np .random .default_rng (0 )
718+ n_groups = 3
719+ cells_per_group = 4
720+ n_features = 5
721+ total_cells = n_groups * cells_per_group
722+
723+ cpu_embedding = rng .normal (size = (total_cells , n_features )).astype (np .float32 )
724+ groups = [f"g{ idx } " for idx in range (n_groups ) for _ in range (cells_per_group )]
725+ obs = pd .DataFrame (
726+ {"group" : pd .Categorical (groups , categories = [f"g{ i } " for i in range (n_groups )])}
727+ )
728+
729+ adata_f32 = AnnData (cpu_embedding , obs = obs )
730+ adata_f32 .obsm ["X_pca" ] = cp .asarray (cpu_embedding , dtype = cp .float32 )
731+
732+ adata_f64 = AnnData (cpu_embedding , obs = obs .copy ())
733+ adata_f64 .obsm ["X_pca" ] = cp .asarray (cpu_embedding , dtype = cp .float64 )
718734
719735 distance = Distance (metric = "edistance" )
720736
721- result_f32 = distance .pairwise (small_adata , groupby = "group" )
737+ result_f32 = distance .pairwise (adata_f32 , groupby = "group" )
722738 result_f64 = distance .pairwise (adata_f64 , groupby = "group" )
723739
724740 np .testing .assert_allclose (
@@ -772,9 +788,10 @@ def test_bootstrap_different_feature_counts(n_features: int) -> None:
772788def test_pairwise_correctness_parametrized (dtype ) -> None :
773789 """Parametrized test for pairwise correctness across dtypes."""
774790 rng = np .random .default_rng (42 )
775- n_groups = 4
776- cells_per_group = 15
777- n_features = 20
791+ # Use smaller sizes for float64 to avoid GPU resource exhaustion
792+ n_groups = 3
793+ cells_per_group = 4 if dtype == np .float64 else 15
794+ n_features = 5 if dtype == np .float64 else 20
778795 total_cells = n_groups * cells_per_group
779796
780797 cpu_embedding = rng .normal (size = (total_cells , n_features )).astype (dtype )
@@ -793,7 +810,7 @@ def test_pairwise_correctness_parametrized(dtype) -> None:
793810 rtol = 1e-10 if dtype == np .float64 else 1e-5
794811 atol = 1e-12 if dtype == np .float64 else 1e-6
795812
796- for g1 , g2 in [("g0" , "g1" ), ("g1 " , "g3 " ), ("g2 " , "g3 " )]:
813+ for g1 , g2 in [("g0" , "g1" ), ("g0 " , "g2 " ), ("g1 " , "g2 " )]:
797814 X = cpu_embedding [np .array (groups ) == g1 ]
798815 Y = cpu_embedding [np .array (groups ) == g2 ]
799816 expected = _compute_energy_distance_cpu (X , Y )
@@ -807,30 +824,59 @@ def test_pairwise_correctness_parametrized(dtype) -> None:
807824 )
808825
809826
810- @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
811827@pytest .mark .parametrize ("n_features" , [50 , 400 ])
812- def test_correctness_dtype_and_features ( dtype , n_features ) -> None :
813- """Test correctness across dtypes and feature counts."""
828+ def test_correctness_float32_features ( n_features ) -> None :
829+ """Test correctness across feature counts with float32 ."""
814830 rng = np .random .default_rng (42 )
815831 n_groups = 3
816832 cells_per_group = 10
817833 total_cells = n_groups * cells_per_group
818834
819- cpu_embedding = rng .normal (size = (total_cells , n_features )).astype (dtype )
835+ cpu_embedding = rng .normal (size = (total_cells , n_features )).astype (np . float32 )
820836 groups = [f"g{ idx } " for idx in range (n_groups ) for _ in range (cells_per_group )]
821837 obs = pd .DataFrame (
822838 {"group" : pd .Categorical (groups , categories = [f"g{ i } " for i in range (n_groups )])}
823839 )
824840
825- adata = AnnData (cpu_embedding .copy (). astype ( np . float32 ) , obs = obs )
826- adata .obsm ["X_pca" ] = cp .asarray (cpu_embedding , dtype = dtype )
841+ adata = AnnData (cpu_embedding .copy (), obs = obs )
842+ adata .obsm ["X_pca" ] = cp .asarray (cpu_embedding , dtype = np . float32 )
827843
828844 distance = Distance (metric = "edistance" )
829845 result_df = distance .pairwise (adata , groupby = "group" )
830846
831- # Check correctness
832- rtol = 1e-10 if dtype == np .float64 else 1e-4
833- atol = 1e-12 if dtype == np .float64 else 1e-5
847+ X = cpu_embedding [np .array (groups ) == "g0" ]
848+ Y = cpu_embedding [np .array (groups ) == "g1" ]
849+ expected = _compute_energy_distance_cpu (X , Y )
850+ actual = result_df .loc ["g0" , "g1" ]
851+
852+ np .testing .assert_allclose (
853+ actual ,
854+ expected ,
855+ rtol = 1e-4 ,
856+ atol = 1e-5 ,
857+ err_msg = f"n_features={ n_features } mismatch" ,
858+ )
859+
860+
861+ def test_correctness_float64_small () -> None :
862+ """Test correctness with float64 using small data to avoid GPU resource exhaustion."""
863+ rng = np .random .default_rng (42 )
864+ n_groups = 3
865+ cells_per_group = 4
866+ n_features = 5
867+ total_cells = n_groups * cells_per_group
868+
869+ cpu_embedding = rng .normal (size = (total_cells , n_features )).astype (np .float64 )
870+ groups = [f"g{ idx } " for idx in range (n_groups ) for _ in range (cells_per_group )]
871+ obs = pd .DataFrame (
872+ {"group" : pd .Categorical (groups , categories = [f"g{ i } " for i in range (n_groups )])}
873+ )
874+
875+ adata = AnnData (cpu_embedding .copy ().astype (np .float32 ), obs = obs )
876+ adata .obsm ["X_pca" ] = cp .asarray (cpu_embedding , dtype = np .float64 )
877+
878+ distance = Distance (metric = "edistance" )
879+ result_df = distance .pairwise (adata , groupby = "group" )
834880
835881 X = cpu_embedding [np .array (groups ) == "g0" ]
836882 Y = cpu_embedding [np .array (groups ) == "g1" ]
@@ -840,9 +886,9 @@ def test_correctness_dtype_and_features(dtype, n_features) -> None:
840886 np .testing .assert_allclose (
841887 actual ,
842888 expected ,
843- rtol = rtol ,
844- atol = atol ,
845- err_msg = f"dtype= { dtype . __name__ } , n_features= { n_features } mismatch" ,
889+ rtol = 1e-10 ,
890+ atol = 1e-12 ,
891+ err_msg = "float64 mismatch" ,
846892 )
847893
848894
@@ -928,7 +974,7 @@ def pertpy_adata() -> AnnData:
928974
929975 This uses the same data pertpy uses in their tests.
930976 """
931- import pertpy as pt
977+ pt = pytest . importorskip ( "pertpy" )
932978 import scanpy as sc
933979
934980 adata = pt .dt .distance_example ()
0 commit comments