Skip to content

Commit ee8aa4a

Browse files
committed
feat: extend interpolation methods using SciPy RegularGridInterpolator (closes #261)
- Add support for 'slinear', 'quintic', and 'pchip' methods in _HyperRectangleGrid.interpolate in addition to existing 'linear', 'nearest', and 'cubic' methods. - Route all non-derivative interpolations (nu_x == nu_y == nu_z == 0) to SciPy's RegularGridInterpolator, enabling vectorized evaluation on multiple points at once for all methods, including 'cubic'. - The custom CubicSpline-based derivative path is preserved and only used when nu_x, nu_y, or nu_z > 0. - Correctly apply np.exp() when use_log=True in the SciPy path. - Update tests: fix error message assertion and relax atol from 1e-6 to 1e-5 to match SciPy cubic's numerical precision characteristics.
1 parent 99910ed commit ee8aa4a

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

src/grid/cubic.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,10 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
142142
The interpolation of a function (or of it's derivatives) at a :math:`M` point.
143143
144144
"""
145-
if method not in ["cubic", "linear", "nearest"]:
145+
supported_methods = ["linear", "nearest", "slinear", "cubic", "quintic", "pchip"]
146+
if method not in supported_methods:
146147
raise ValueError(
147-
f"Argument method should be either cubic, linear, or nearest , got {method}"
148+
f"Argument method should be one of {supported_methods}, got {method}"
148149
)
149150
if self.ndim != 3:
150151
raise NotImplementedError(
@@ -159,12 +160,21 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
159160
if use_log:
160161
values = np.log(values)
161162

162-
# Use scipy if linear and nearest is requested and raise error if it's not cubic.
163-
if method in ["linear", "nearest"]:
163+
# Use scipy if no derivatives are requested
164+
if method in supported_methods and (nu_x == 0 and nu_y == 0 and nu_z == 0):
164165
x, y, z = self.get_points_along_axes()
165166
values = values.reshape(self.shape)
166167
interpolate = RegularGridInterpolator((x, y, z), values, method=method)
167-
return interpolate(points)
168+
interpolated = interpolate(points)
169+
if use_log:
170+
return np.exp(interpolated)
171+
return interpolated
172+
173+
# At this point, derivatives are requested, which requires our custom cubic spline implementation
174+
if method != "cubic":
175+
raise NotImplementedError(
176+
f"Computing analytical derivatives (nu_x={nu_x}, nu_y={nu_y}, nu_z={nu_z}) is only supported for the 'cubic' method."
177+
)
168178

169179
# Interpolate the Z-Axis.
170180
def z_spline(z, x_index, y_index, nu_z=nu_z):

src/grid/tests/test_cubic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_raise_error_when_using_interpolation(self):
119119
points, values, method="not cubic"
120120
)
121121
self.assertEqual(
122-
"Argument method should be either cubic, linear, or nearest , got not cubic",
122+
"Argument method should be one of ['linear', 'nearest', 'slinear', 'cubic', 'quintic', 'pchip'], got not cubic",
123123
str(err.exception),
124124
)
125125
# Test raises error if dimension is two.
@@ -201,10 +201,10 @@ def gaussian(points):
201201
num_pts = 500
202202
random_pts = np.random.uniform(-0.9, 0.9, (num_pts, 3))
203203
interpolated = cubic.interpolate(random_pts, gaussian_pts, use_log=False)
204-
assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-6)
204+
assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-5)
205205

206-
interpolated = cubic.interpolate(random_pts, gaussian_pts, use_log=True)
207-
assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-6)
206+
interpolated_log = cubic.interpolate(random_pts, gaussian_pts, use_log=True)
207+
assert_allclose(interpolated_log, gaussian(random_pts), rtol=1e-5, atol=1e-5)
208208

209209
def test_interpolation_of_linear_function_using_scipy_linear_method(self):
210210
r"""Test interpolation of a linear function using scipy with linear method."""

0 commit comments

Comments
 (0)