33import warnings
44
55import matplotlib as mpl
6- import matplotlib .colors as mplc
6+ import matplotlib .colors as cm
77import numpy as np
88import pythreejs as p3
99
10+ from .norm import Normalizer
1011from .utils import find_limits , fix_empty_range
1112
1213SHADER_LIBRARY = {
5152
5253
5354class Points :
54- def __init__ (self , x , y , c = "C0" , s = 3 , marker = "s" , zorder = 0 , cmap = "viridis" ) -> None :
55+ def __init__ (
56+ self ,
57+ x ,
58+ y ,
59+ c = "C0" ,
60+ s = 3 ,
61+ marker = "s" ,
62+ zorder = 0 ,
63+ cmap = "viridis" ,
64+ norm : str = "linear" ,
65+ ) -> None :
5566 self .axes = None
5667 self ._x = np .asarray (x )
5768 self ._y = np .asarray (y )
@@ -61,18 +72,21 @@ def __init__(self, x, y, c="C0", s=3, marker="s", zorder=0, cmap="viridis") -> N
6172
6273 if not isinstance (c , str ) or not np .isscalar (s ) or marker != "s" :
6374 if isinstance (c , str ):
64- rgba = mplc .LinearSegmentedColormap .from_list ("tmp" , [c , c ])(
65- np .ones_like (self ._x )
66- )
75+ self ._c = np .ones_like (self ._x )
76+ self ._norm = Normalizer (vmin = 1 , vmax = 1 )
77+ self ._cmap = cm .LinearSegmentedColormap .from_list ("tmp" , [c , c ])
78+ # (
79+ # np.ones_like(self._x)
80+ # )
6781 else :
6882 self ._c = np .asarray (c )
69- self .norm = mpl . colors . Normalize (
70- vmin = np .min (self ._c ), vmax = np .max (self ._c )
83+ self ._norm = Normalizer (
84+ vmin = np .min (self ._c ), vmax = np .max (self ._c ), norm = norm
7185 )
72- self .cmap = mpl .colormaps [cmap ].copy ()
73- rgba = self .cmap (self .norm (self ._c ))
86+ self ._cmap = mpl .colormaps [cmap ].copy ()
87+ # rgba = self.cmap(self.norm(self._c))
7488
75- colors = rgba [:, : 3 ]. astype ( np . float32 ) # Take only RGB, drop alpha
89+ colors = self . _make_colors ()
7690
7791 if np .isscalar (s ):
7892 sizes = np .full_like (self ._x , s , dtype = np .float32 )
@@ -123,10 +137,16 @@ def __init__(self, x, y, c="C0", s=3, marker="s", zorder=0, cmap="viridis") -> N
123137 }
124138 )
125139
126- self ._material = p3 .PointsMaterial (color = mplc .to_hex (c ), size = s )
140+ self ._material = p3 .PointsMaterial (color = cm .to_hex (c ), size = s )
127141
128142 self ._points = p3 .Points (geometry = self ._geometry , material = self ._material )
129143
144+ def _make_colors (self ) -> np .ndarray :
145+ return self ._cmap (self .norm (self ._c ))[..., :3 ].astype ("float32" )
146+
147+ def _update_colors (self ) -> None :
148+ self ._geometry .attributes ["customColor" ].array = self ._make_colors ()
149+
130150 def get_bbox (self ):
131151 pad = 0.03
132152 left , right = fix_empty_range (find_limits (self ._x , scale = self ._xscale , pad = pad ))
@@ -170,3 +190,33 @@ def _set_xscale(self, scale):
170190 def _set_yscale (self , scale ):
171191 self ._yscale = scale
172192 self ._update ()
193+
194+ def set_cmap (self , cmap : str ) -> None :
195+ self ._cmap = mpl .colormaps [cmap ].copy ()
196+ self ._update_colors ()
197+ if self ._colorbar is not None :
198+ self ._colorbar .update ()
199+
200+ @property
201+ def cmap (self ) -> cm .Colormap :
202+ return self ._cmap
203+
204+ @cmap .setter
205+ def cmap (self , cmap : str ) -> None :
206+ self .set_cmap (cmap )
207+
208+ @property
209+ def norm (self ) -> Normalizer :
210+ return self ._norm
211+
212+ @norm .setter
213+ def norm (self , norm : Normalizer | str ) -> None :
214+ if isinstance (norm , str ):
215+ self ._norm = Normalizer (
216+ vmin = np .min (self ._c ), vmax = np .max (self ._c ), norm = norm
217+ )
218+ else :
219+ self ._norm = norm
220+ self ._update_colors ()
221+ if self ._colorbar is not None :
222+ self ._colorbar .update ()
0 commit comments