2828
2929if TYPE_CHECKING :
3030 from collections .abc import Callable
31+ from typing import Any
3132
3233 from matplotlib .axes import Axes
3334
@@ -58,71 +59,53 @@ def test_highest_expr_genes(image_comparer, col, layer):
5859
5960
6061@needs .leidenalg
61- def test_heatmap (image_comparer ):
62+ @pytest .mark .parametrize (
63+ ("params" , "key" ),
64+ [
65+ pytest .param ({}, "heatmap" , id = "default" ),
66+ pytest .param (
67+ dict (swap_axes = True , figsize = (10 , 3 ), cmap = "YlGnBu" ),
68+ "heatmap_swap_axes" ,
69+ id = "swap" ,
70+ ),
71+ pytest .param (
72+ dict (
73+ groupby = "numeric_value" ,
74+ num_categories = 4 ,
75+ figsize = (4.5 , 5 ),
76+ dendrogram = False ,
77+ ),
78+ "heatmap2" ,
79+ id = "numeric" ,
80+ ),
81+ pytest .param (
82+ dict (standard_scale = "var" , layer = "test" ),
83+ "heatmap_std_scale_var" ,
84+ id = "std_scale=var" ,
85+ ),
86+ pytest .param (
87+ dict (standard_scale = "obs" ),
88+ "heatmap_std_scale_obs" ,
89+ id = "std_scale=obs" ,
90+ ),
91+ ],
92+ )
93+ def test_heatmap (image_comparer , params : dict [str , Any ], key : str ) -> None :
6294 save_and_compare_images = partial (image_comparer , ROOT , tol = 15 )
6395
6496 adata = krumsiek11 ()
65- sc .pl .heatmap (
66- adata , adata .var_names , "cell_type" , use_raw = False , show = False , dendrogram = True
67- )
68- save_and_compare_images ("heatmap" )
69-
70- # test swap axes
71- sc .pl .heatmap (
72- adata ,
73- adata .var_names ,
74- "cell_type" ,
75- use_raw = False ,
76- show = False ,
77- dendrogram = True ,
78- swap_axes = True ,
79- figsize = (10 , 3 ),
80- cmap = "YlGnBu" ,
81- )
82- save_and_compare_images ("heatmap_swap_axes" )
83-
84- # test heatmap numeric column():
85-
86- # set as numeric column the vales for the first gene on the matrix
8797 adata .obs ["numeric_value" ] = adata .X [:, 0 ]
88- sc .pl .heatmap (
89- adata ,
90- adata .var_names ,
91- "numeric_value" ,
92- use_raw = False ,
93- num_categories = 4 ,
94- figsize = (4.5 , 5 ),
95- show = False ,
96- )
97- save_and_compare_images ("heatmap2" )
98-
99- # test var/obs standardization and layer
10098 adata .layers ["test" ] = - 1 * adata .X .copy ()
101- sc .pl .heatmap (
102- adata ,
103- adata .var_names ,
104- "cell_type" ,
105- use_raw = False ,
106- dendrogram = True ,
107- show = False ,
108- standard_scale = "var" ,
109- layer = "test" ,
110- )
111- save_and_compare_images ("heatmap_std_scale_var" )
11299
113- # test standard_scale_obs
114- sc .pl .heatmap (
115- adata ,
116- adata .var_names ,
117- "cell_type" ,
118- use_raw = False ,
119- dendrogram = True ,
120- show = False ,
121- standard_scale = "obs" ,
122- )
123- save_and_compare_images ("heatmap_std_scale_obs" )
100+ params = dict (groupby = "cell_type" , dendrogram = True ) | params
101+ sc .pl .heatmap (adata , adata .var_names , ** params , use_raw = False , show = False )
102+ save_and_compare_images (key )
103+
104+
105+ @needs .leidenalg
106+ def test_heatmap_var_as_dict (image_comparer ) -> None :
107+ save_and_compare_images = partial (image_comparer , ROOT , tol = 15 )
124108
125- # test var_names as dict
126109 pbmc = pbmc68k_reduced ()
127110 sc .tl .leiden (
128111 pbmc ,
@@ -154,8 +137,13 @@ def test_heatmap(image_comparer):
154137 )
155138 save_and_compare_images ("heatmap_var_as_dict" )
156139
157- # test that plot elements are well aligned
158- # small
140+
141+ @needs .leidenalg
142+ @pytest .mark .parametrize ("swap_axes" , [True , False ])
143+ def test_heatmap_alignment (* , image_comparer , swap_axes : bool ) -> None :
144+ """Test that plot elements are well aligned."""
145+ save_and_compare_images = partial (image_comparer , ROOT , tol = 15 )
146+
159147 a = AnnData (
160148 np .array ([[0 , 0.3 , 0.5 ], [1 , 1.3 , 1.5 ], [2 , 2.3 , 2.5 ]]),
161149 obs = {"foo" : ["a" , "b" , "c" ]},
@@ -166,21 +154,11 @@ def test_heatmap(image_comparer):
166154 a ,
167155 var_names = a .var_names ,
168156 groupby = "foo" ,
169- swap_axes = True ,
170- figsize = (4 , 4 ),
171- show = False ,
172- )
173- save_and_compare_images ("heatmap_small_swap_alignment" )
174-
175- sc .pl .heatmap (
176- a ,
177- var_names = a .var_names ,
178- groupby = "foo" ,
179- swap_axes = False ,
157+ swap_axes = swap_axes ,
180158 figsize = (4 , 4 ),
181159 show = False ,
182160 )
183- save_and_compare_images ("heatmap_small_alignment " )
161+ save_and_compare_images (f"heatmap_small { '_swap' if swap_axes else '' } _alignment " )
184162
185163
186164@pytest .mark .skipif (
0 commit comments