1515from .widgets import ClickableHTML
1616
1717
18+ def min_with_none (a , b ):
19+ return a if b is None else min (a , b )
20+
21+
22+ def max_with_none (a , b ):
23+ return a if b is None else max (a , b )
24+
25+
1826class Axes (ipw .GridBox ):
1927 def __init__ (self , * , ax : MplAxes , figure = None ) -> None :
2028 self .background_color = "#ffffff"
@@ -290,35 +298,43 @@ def height(self, h):
290298 # self._margins["rightspine"].height = h
291299
292300 def autoscale (self ):
293- xmin = np . inf
294- xmax = - np . inf
295- ymin = np . inf
296- ymax = - np . inf
301+ xmin = None
302+ xmax = None
303+ ymin = None
304+ ymax = None
297305 for artist in self ._artists :
298306 lims = artist .get_bbox ()
299- xmin = min (lims ["left" ], xmin )
300- xmax = max (lims ["right" ], xmax )
301- ymin = min (lims ["bottom" ], ymin )
302- ymax = max (lims ["top" ], ymax )
303- self ._xmin = xmin
304- self ._xmax = xmax
305- self ._ymin = ymin
306- self ._ymax = ymax
307-
308- # self._background_mesh.geometry = p3.BoxGeometry(
309- # width=2 * (self._xmax - self._xmin),
310- # height=2 * (self._ymax - self._ymin),
311- # widthSegments=1,
312- # heightSegments=1,
313- # )
307+ xmin = min_with_none (lims ["left" ], xmin )
308+ xmax = max_with_none (lims ["right" ], xmax )
309+ ymin = min_with_none (lims ["bottom" ], ymin )
310+ ymax = max_with_none (lims ["top" ], ymax )
311+ self ._xmin = (
312+ xmin
313+ if xmin is not None
314+ else (0.0 if self .get_xscale () == "linear" else 1.0 )
315+ )
316+ self ._xmax = (
317+ xmax
318+ if xmax is not None
319+ else (1.0 if self .get_xscale () == "linear" else 10.0 )
320+ )
321+ self ._ymin = (
322+ ymin
323+ if ymin is not None
324+ else (0.0 if self .get_yscale () == "linear" else 1.0 )
325+ )
326+ self ._ymax = (
327+ ymax
328+ if ymax is not None
329+ else (1.0 if self .get_yscale () == "linear" else 10.0 )
330+ )
331+
314332 self ._background_mesh .geometry = p3 .PlaneGeometry (
315333 width = 2 * (self ._xmax - self ._xmin ),
316334 height = 2 * (self ._ymax - self ._ymin ),
317335 widthSegments = 1 ,
318336 heightSegments = 1 ,
319337 )
320- # self._background_mesh.geometry.width = 2 * (self._xmax - self._xmin)
321- # self._background_mesh.geometry.height = 2 * (self._ymax - self._ymin)
322338
323339 self ._background_mesh .position = [
324340 0.5 * (self ._xmin + self ._xmax ),
@@ -523,6 +539,10 @@ def get_xscale(self):
523539 return self ._ax .get_xscale ()
524540
525541 def set_xscale (self , scale ):
542+ if scale not in ("linear" , "log" ):
543+ raise ValueError ("Scale must be 'linear' or 'log'" )
544+ if scale == self .get_xscale ():
545+ return
526546 self ._ax .set_xscale (scale )
527547 for artist in self ._artists :
528548 artist ._set_xscale (scale )
@@ -533,6 +553,10 @@ def get_yscale(self):
533553 return self ._ax .get_yscale ()
534554
535555 def set_yscale (self , scale ):
556+ if scale not in ("linear" , "log" ):
557+ raise ValueError ("Scale must be 'linear' or 'log'" )
558+ if scale == self .get_yscale ():
559+ return
536560 self ._ax .set_yscale (scale )
537561 for artist in self ._artists :
538562 artist ._set_yscale (scale )
@@ -661,13 +685,35 @@ def get_title(self):
661685 def plot (self , * args , color = None , ** kwargs ):
662686 if color is None :
663687 color = f"C{ len (self .lines )} "
664- line = Line (* args , color = color , ** kwargs )
688+ line = Line (
689+ * args ,
690+ color = color ,
691+ xscale = self .get_xscale (),
692+ yscale = self .get_yscale (),
693+ ** kwargs ,
694+ )
665695 line .axes = self
666696 self .lines .append (line )
667697 self .add_artist (line )
668698 self .autoscale ()
669699 return line
670700
701+ def semilogx (self , * args , ** kwargs ):
702+ out = self .plot (* args , ** kwargs )
703+ self .set_xscale ("log" )
704+ return out
705+
706+ def semilogy (self , * args , ** kwargs ):
707+ out = self .plot (* args , ** kwargs )
708+ self .set_yscale ("log" )
709+ return out
710+
711+ def loglog (self , * args , ** kwargs ):
712+ out = self .plot (* args , ** kwargs )
713+ self .set_xscale ("log" )
714+ self .set_yscale ("log" )
715+ return out
716+
671717 def scatter (self , * args , c = None , ** kwargs ):
672718 if c is None :
673719 c = f"C{ len (self .collections )} "
0 commit comments