@@ -285,21 +285,17 @@ def test_plot_kernel():
285285 plt .plot (kernel )
286286
287287
288- @check_figures_equal (extensions = ['png' ])
289- def test_unit_axis_label (fig_test , fig_ref ):
288+ @pytest .mark .parametrize ('plot_meth_name' , ['scatter' , 'plot' ])
289+ def test_unit_axis_label (plot_meth_name ):
290+ # Check that the correct Axis labels are set on plots with units
290291 import matplotlib .testing .jpl_units as units
291292 units .register ()
292293
293- data = [0 * units .km , 1 * units .km , 2 * units .km ]
294-
295- ax_test = fig_test .subplots ()
296- ax_ref = fig_ref .subplots ()
297- axs = [ax_test , ax_ref ]
298-
299- for ax in axs :
300- ax .yaxis .set_units ('km' )
301- ax .set_xlim (10 , 20 )
302- ax .set_ylim (10 , 20 )
303-
304- ax_test .scatter ([1 , 2 , 3 ], data , edgecolors = 'none' )
305- ax_ref .plot ([1 , 2 , 3 ], data , marker = 'o' , linewidth = 0 )
294+ fig , ax = plt .subplots ()
295+ ax .xaxis .set_units ('m' )
296+ ax .yaxis .set_units ('sec' )
297+ plot_method = getattr (ax , plot_meth_name )
298+ plot_method (np .arange (3 ) * units .m , np .arange (3 ) * units .sec )
299+ assert ax .get_xlabel () == 'm'
300+ assert ax .get_ylabel () == 'sec'
301+ plt .close ('all' )
0 commit comments