1515from .grid_finder import GridFinder
1616
1717
18+ def _value_and_jacobian (func , xs , ys , xlims , ylims ):
19+ """
20+ Compute *func* and its derivatives along x and y at positions *xs*, *ys*,
21+ while ensuring that finite difference calculations don't try to evaluate
22+ values outside of *xlims*, *ylims*.
23+ """
24+ eps = np .finfo (float ).eps ** (1 / 2 ) # see e.g. scipy.optimize.approx_fprime
25+ val = func (xs , ys )
26+ # Take the finite difference step in the direction where the bound is the
27+ # furthest; the step size is min of epsilon and distance to that bound.
28+ xlo , xhi = sorted (xlims )
29+ dxlo = xs - xlo
30+ dxhi = xhi - xs
31+ xeps = (np .take ([- 1 , 1 ], dxhi >= dxlo )
32+ * np .minimum (eps , np .maximum (dxlo , dxhi )))
33+ val_dx = func (xs + xeps , ys )
34+ ylo , yhi = sorted (ylims )
35+ dylo = ys - ylo
36+ dyhi = yhi - ys
37+ yeps = (np .take ([- 1 , 1 ], dyhi >= dylo )
38+ * np .minimum (eps , np .maximum (dylo , dyhi )))
39+ val_dy = func (xs , ys + yeps )
40+ return (val , (val_dx - val ) / xeps , (val_dy - val ) / yeps )
41+
42+
1843class FixedAxisArtistHelper (AxisArtistHelper .Fixed ):
1944 """
2045 Helper class for a fixed axis.
@@ -121,31 +146,23 @@ def get_axislabel_transform(self, axes):
121146 return Affine2D () # axes.transData
122147
123148 def get_axislabel_pos_angle (self , axes ):
149+ def trf_xy (x , y ):
150+ trf = self .grid_helper .grid_finder .get_transform () + axes .transData
151+ return trf .transform ([x , y ]).T
124152
125- extremes = self ._grid_info ["extremes" ]
126-
153+ xmin , xmax , ymin , ymax = self ._grid_info ["extremes" ]
127154 if self .nth_coord == 0 :
128155 xx0 = self .value
129- yy0 = (extremes [2 ] + extremes [3 ]) / 2
130- dxx = 0
131- dyy = abs (extremes [2 ] - extremes [3 ]) / 1000
156+ yy0 = (ymin + ymax ) / 2
132157 elif self .nth_coord == 1 :
133- xx0 = (extremes [ 0 ] + extremes [ 1 ] ) / 2
158+ xx0 = (xmin + xmax ) / 2
134159 yy0 = self .value
135- dxx = abs (extremes [0 ] - extremes [1 ]) / 1000
136- dyy = 0
137-
138- grid_finder = self .grid_helper .grid_finder
139- (xx1 ,), (yy1 ,) = grid_finder .transform_xy ([xx0 ], [yy0 ])
140-
141- data_to_axes = axes .transData - axes .transAxes
142- p = data_to_axes .transform ([xx1 , yy1 ])
143-
160+ xy1 , dxy1_dx , dxy1_dy = _value_and_jacobian (
161+ trf_xy , xx0 , yy0 , (xmin , xmax ), (ymin , ymax ))
162+ p = axes .transAxes .inverted ().transform (xy1 )
144163 if 0 <= p [0 ] <= 1 and 0 <= p [1 ] <= 1 :
145- xx1c , yy1c = axes .transData .transform ([xx1 , yy1 ])
146- (xx2 ,), (yy2 ,) = grid_finder .transform_xy ([xx0 + dxx ], [yy0 + dyy ])
147- xx2c , yy2c = axes .transData .transform ([xx2 , yy2 ])
148- return (xx1c , yy1c ), np .rad2deg (np .arctan2 (yy2c - yy1c , xx2c - xx1c ))
164+ d = [dxy1_dy , dxy1_dx ][self .nth_coord ]
165+ return xy1 , np .rad2deg (np .arctan2 (* d [::- 1 ]))
149166 else :
150167 return None , None
151168
@@ -155,78 +172,48 @@ def get_tick_transform(self, axes):
155172 def get_tick_iterators (self , axes ):
156173 """tick_loc, tick_angle, tick_label, (optionally) tick_label"""
157174
158- grid_finder = self .grid_helper .grid_finder
159-
160175 lat_levs , lat_n , lat_factor = self ._grid_info ["lat_info" ]
161176 yy0 = lat_levs / lat_factor
162- dy = 0.01 / lat_factor
163177
164178 lon_levs , lon_n , lon_factor = self ._grid_info ["lon_info" ]
165179 xx0 = lon_levs / lon_factor
166- dx = 0.01 / lon_factor
167180
168181 e0 , e1 = self ._extremes
169182
170- if self .nth_coord == 0 :
171- mask = (e0 <= yy0 ) & (yy0 <= e1 )
172- # xx0, yy0 = xx0[mask], yy0[mask]
173- yy0 = yy0 [mask ]
174- elif self .nth_coord == 1 :
175- mask = (e0 <= xx0 ) & (xx0 <= e1 )
176- # xx0, yy0 = xx0[mask], yy0[mask]
177- xx0 = xx0 [mask ]
178-
179- def transform_xy (x , y ):
180- trf = grid_finder .get_transform () + axes .transData
181- return trf .transform (np .column_stack ([x , y ])).T
183+ def trf_xy (x , y ):
184+ trf = self .grid_helper .grid_finder .get_transform () + axes .transData
185+ return trf .transform (np .column_stack (np .broadcast_arrays (x , y ))).T
182186
183187 # find angles
184188 if self .nth_coord == 0 :
185- xx0 = np .full_like (yy0 , self .value )
186-
187- xx1 , yy1 = transform_xy (xx0 , yy0 )
188-
189- xx00 = xx0 .copy ()
190- xx00 [xx0 + dx > e1 ] -= dx
191- xx1a , yy1a = transform_xy (xx00 , yy0 )
192- xx1b , yy1b = transform_xy (xx00 + dx , yy0 )
193-
194- xx2a , yy2a = transform_xy (xx0 , yy0 )
195- xx2b , yy2b = transform_xy (xx0 , yy0 + dy )
196-
189+ mask = (e0 <= yy0 ) & (yy0 <= e1 )
190+ (xx1 , yy1 ), (dxx1 , dyy1 ), (dxx2 , dyy2 ) = _value_and_jacobian (
191+ trf_xy , self .value , yy0 [mask ], (- np .inf , np .inf ), (e0 , e1 ))
197192 labels = self ._grid_info ["lat_labels" ]
198- labels = [l for l , m in zip (labels , mask ) if m ]
199193
200194 elif self .nth_coord == 1 :
201- yy0 = np .full_like (xx0 , self .value )
202-
203- xx1 , yy1 = transform_xy (xx0 , yy0 )
195+ mask = (e0 <= xx0 ) & (xx0 <= e1 )
196+ (xx1 , yy1 ), (dxx2 , dyy2 ), (dxx1 , dyy1 ) = _value_and_jacobian (
197+ trf_xy , xx0 [mask ], self .value , (- np .inf , np .inf ), (e0 , e1 ))
198+ labels = self ._grid_info ["lon_labels" ]
204199
205- xx1a , yy1a = transform_xy (xx0 , yy0 )
206- xx1b , yy1b = transform_xy (xx0 , yy0 + dy )
200+ labels = [l for l , m in zip (labels , mask ) if m ]
207201
208- xx00 = xx0 . copy ( )
209- xx00 [ xx0 + dx > e1 ] -= dx
210- xx2a , yy2a = transform_xy ( xx00 , yy0 )
211- xx2b , yy2b = transform_xy ( xx00 + dx , yy0 )
202+ angle_normal = np . arctan2 ( dyy1 , dxx1 )
203+ angle_tangent = np . arctan2 ( dyy2 , dxx2 )
204+ mm = ( dyy1 == 0 ) & ( dxx1 == 0 ) # points with degenerate normal
205+ angle_normal [ mm ] = angle_tangent [ mm ] + np . pi / 2
212206
213- labels = self ._grid_info ["lon_labels" ]
214- labels = [l for l , m in zip (labels , mask ) if m ]
207+ tick_to_axes = self .get_tick_transform (axes ) - axes .transAxes
208+ in_01 = functools .partial (
209+ mpl .transforms ._interval_contains_close , (0 , 1 ))
215210
216211 def f1 ():
217- dd = np .arctan2 (yy1b - yy1a , xx1b - xx1a ) # angle normal
218- dd2 = np .arctan2 (yy2b - yy2a , xx2b - xx2a ) # angle tangent
219- mm = (yy1b == yy1a ) & (xx1b == xx1a ) # mask where dd not defined
220- dd [mm ] = dd2 [mm ] + np .pi / 2
221-
222- tick_to_axes = self .get_tick_transform (axes ) - axes .transAxes
223- in_01 = functools .partial (
224- mpl .transforms ._interval_contains_close , (0 , 1 ))
225- for x , y , d , d2 , lab in zip (xx1 , yy1 , dd , dd2 , labels ):
212+ for x , y , normal , tangent , lab \
213+ in zip (xx1 , yy1 , angle_normal , angle_tangent , labels ):
226214 c2 = tick_to_axes .transform ((x , y ))
227215 if in_01 (c2 [0 ]) and in_01 (c2 [1 ]):
228- d1 , d2 = np .rad2deg ([d , d2 ])
229- yield [x , y ], d1 , d2 , lab
216+ yield [x , y ], * np .rad2deg ([normal , tangent ]), lab
230217
231218 return f1 (), iter ([])
232219
0 commit comments