33import warnings
44
55import matplotlib as mpl
6+ import matplotlib .colors as cm
67import numpy as np
78import pythreejs as p3
89
10+ from .norm import Normalizer
911from .utils import find_limits , fix_empty_range
1012
1113
1214class Mesh :
13- def __init__ (self , * args , cmap = "viridis" ):
15+ def __init__ (self , * args , cmap : str = "viridis" , norm : str = "linear " ):
1416 if len (args ) not in (1 , 3 ):
1517 raise ValueError (
1618 f"Invalid number of arguments: expected 1 or 3. Got { len (args )} "
@@ -32,8 +34,8 @@ def __init__(self, *args, cmap="viridis"):
3234 self ._y = np .asarray (y )
3335 self ._c = np .asarray (c )
3436
35- self .norm = mpl . colors . Normalize (vmin = np .min (self ._c ), vmax = np .max (self ._c ))
36- self .cmap = mpl .colormaps [cmap ].copy ()
37+ self ._norm = Normalizer (vmin = np .min (self ._c ), vmax = np .max (self ._c ), norm = norm )
38+ self ._cmap = mpl .colormaps [cmap ].copy ()
3739
3840 self ._faces = self ._make_faces ()
3941
@@ -55,7 +57,7 @@ def __init__(self, *args, cmap="viridis"):
5557 self ._mesh = p3 .Mesh (geometry = self ._geometry , material = self ._material )
5658
5759 def _make_colors (self ) -> np .ndarray :
58- colors_rgba = self .cmap (self .norm (self ._c .flatten ()))
60+ colors_rgba = self ._cmap (self ._norm (self ._c .flatten ()))
5961 colors = colors_rgba [:, :3 ].astype ("float32" )
6062 # Assign colors to vertices (each vertex in a cell gets the same color)
6163 return np .repeat (colors , 4 , axis = 0 ) # 4 vertices per cell
@@ -114,6 +116,9 @@ def _make_vertices(self) -> np.ndarray:
114116 def _update_positions (self ) -> None :
115117 self ._geometry .attributes ["position" ].array = self ._make_vertices ()
116118
119+ def _update_colors (self ) -> None :
120+ self ._geometry .attributes ["color" ].array = self ._make_colors ()
121+
117122 def get_bbox (self ) -> dict [str , float ]:
118123 pad = False
119124 left , right = fix_empty_range (find_limits (self ._x , scale = self ._xscale , pad = pad ))
@@ -139,7 +144,7 @@ def set_ydata(self, y: np.ndarray):
139144
140145 def set_array (self , c : np .ndarray ):
141146 self ._c = np .asarray (c )
142- self ._geometry . attributes [ "color" ]. array = self . _make_colors ()
147+ self ._update_colors ()
143148
144149 def _set_xscale (self , scale : str ) -> None :
145150 self ._xscale = scale
@@ -148,3 +153,33 @@ def _set_xscale(self, scale: str) -> None:
148153 def _set_yscale (self , scale : str ) -> None :
149154 self ._yscale = scale
150155 self ._update_positions ()
156+
157+ def set_cmap (self , cmap : str ) -> None :
158+ self ._cmap = mpl .colormaps [cmap ].copy ()
159+ self ._update_colors ()
160+ if self ._colorbar is not None :
161+ self ._colorbar .update ()
162+
163+ @property
164+ def cmap (self ) -> cm .Colormap :
165+ return self ._cmap
166+
167+ @cmap .setter
168+ def cmap (self , cmap : str ) -> None :
169+ self .set_cmap (cmap )
170+
171+ @property
172+ def norm (self ) -> Normalizer :
173+ return self ._norm
174+
175+ @norm .setter
176+ def norm (self , norm : Normalizer | str ) -> None :
177+ if isinstance (norm , str ):
178+ self ._norm = Normalizer (
179+ vmin = np .min (self ._c ), vmax = np .max (self ._c ), norm = norm
180+ )
181+ else :
182+ self ._norm = norm
183+ self ._update_colors ()
184+ if self ._colorbar is not None :
185+ self ._colorbar .update ()
0 commit comments