@@ -285,17 +285,21 @@ def test_plot_kernel():
285285 plt .plot (kernel )
286286
287287
288- @pytest .mark .parametrize ('plot_meth_name' , ['scatter' , 'plot' ])
289- def test_unit_axis_label (plot_meth_name ):
290- # Check that the correct Axis labels are set on plots with units
288+ @check_figures_equal (extensions = ['png' ])
289+ def test_unit_axis_label (fig_test , fig_ref ):
291290 import matplotlib .testing .jpl_units as units
292291 units .register ()
293292
294- fig , ax = plt .subplots ()
295- ax .xaxis .set_units ('m' )
296- ax .yaxis .set_units ('sec' )
297- plot_method = getattr (ax , plot_meth_name )
298- plot_method (np .arange (3 ) * units .m , np .arange (3 ) * units .sec )
299- assert ax .get_xlabel () == 'm'
300- assert ax .get_ylabel () == 'sec'
301- plt .close ('all' )
293+ data = [0 * units .km , 1 * units .km , 2 * units .km ]
294+
295+ ax_test = fig_test .subplots ()
296+ ax_ref = fig_ref .subplots ()
297+ axs = [ax_test , ax_ref ]
298+
299+ for ax in axs :
300+ ax .yaxis .set_units ('km' )
301+ ax .set_xlim (10 , 20 )
302+ ax .set_ylim (10 , 20 )
303+
304+ ax_test .scatter ([1 , 2 , 3 ], data , edgecolors = 'none' )
305+ ax_ref .plot ([1 , 2 , 3 ], data , marker = 'o' , linewidth = 0 )
0 commit comments