@@ -445,26 +445,6 @@ def draw(self, renderer):
445445 self .offsetText .set_ha (align )
446446 self .offsetText .draw (renderer )
447447
448- if self .axes ._draw_grid and len (ticks ):
449- # Grid points where the planes meet
450- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
451- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
452-
453- # Grid lines go from the end of one plane through the plane
454- # intersection (at xyz0) to the end of the other plane. The first
455- # point (0) differs along dimension index-2 and the last (2) along
456- # dimension index-1.
457- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
458- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
459- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
460- self .gridlines .set_segments (lines )
461- gridinfo = info ['grid' ]
462- self .gridlines .set_color (gridinfo ['color' ])
463- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
464- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
465- self .gridlines .do_3d_projection ()
466- self .gridlines .draw (renderer )
467-
468448 # Draw ticks:
469449 tickdir = self ._get_tickdir ()
470450 tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -502,6 +482,45 @@ def draw(self, renderer):
502482 renderer .close_group ('axis3d' )
503483 self .stale = False
504484
485+ @artist .allow_rasterization
486+ def draw_grid (self , renderer ):
487+ if not self .axes ._draw_grid :
488+ return
489+
490+ renderer .open_group ("grid3d" , gid = self .get_gid ())
491+
492+ ticks = self ._update_ticks ()
493+ if len (ticks ):
494+ # Get general axis information:
495+ info = self ._axinfo
496+ index = info ["i" ]
497+
498+ mins , maxs , _ , _ , _ , highs = self ._get_coord_info (renderer )
499+
500+ minmax = np .where (highs , maxs , mins )
501+ maxmin = np .where (~ highs , maxs , mins )
502+
503+ # Grid points where the planes meet
504+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
505+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
506+
507+ # Grid lines go from the end of one plane through the plane
508+ # intersection (at xyz0) to the end of the other plane. The first
509+ # point (0) differs along dimension index-2 and the last (2) along
510+ # dimension index-1.
511+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
512+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
513+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
514+ self .gridlines .set_segments (lines )
515+ gridinfo = info ['grid' ]
516+ self .gridlines .set_color (gridinfo ['color' ])
517+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
518+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
519+ self .gridlines .do_3d_projection ()
520+ self .gridlines .draw (renderer )
521+
522+ renderer .close_group ('grid3d' )
523+
505524 # TODO: Get this to work (more) properly when mplot3d supports the
506525 # transforms framework.
507526 def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments