diff --git a/src/grid/cubic.py b/src/grid/cubic.py index 628035a3..6e8dc71f 100644 --- a/src/grid/cubic.py +++ b/src/grid/cubic.py @@ -142,9 +142,10 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met The interpolation of a function (or of it's derivatives) at a :math:`M` point. """ - if method not in ["cubic", "linear", "nearest"]: + supported_methods = ["linear", "nearest", "slinear", "cubic", "quintic", "pchip"] + if method not in supported_methods: raise ValueError( - f"Argument method should be either cubic, linear, or nearest , got {method}" + f"Argument method should be one of {supported_methods}, got {method}" ) if self.ndim != 3: raise NotImplementedError( @@ -159,12 +160,21 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met if use_log: values = np.log(values) - # Use scipy if linear and nearest is requested and raise error if it's not cubic. - if method in ["linear", "nearest"]: + # Use scipy if no derivatives are requested + if method in supported_methods and (nu_x == 0 and nu_y == 0 and nu_z == 0): x, y, z = self.get_points_along_axes() values = values.reshape(self.shape) interpolate = RegularGridInterpolator((x, y, z), values, method=method) - return interpolate(points) + interpolated = interpolate(points) + if use_log: + return np.exp(interpolated) + return interpolated + + # At this point, derivatives are requested, which requires our custom cubic spline implementation + if method != "cubic": + raise NotImplementedError( + f"Computing analytical derivatives (nu_x={nu_x}, nu_y={nu_y}, nu_z={nu_z}) is only supported for the 'cubic' method." + ) # Interpolate the Z-Axis. def z_spline(z, x_index, y_index, nu_z=nu_z): diff --git a/src/grid/tests/test_cubic.py b/src/grid/tests/test_cubic.py index 026a981a..bff04dba 100644 --- a/src/grid/tests/test_cubic.py +++ b/src/grid/tests/test_cubic.py @@ -119,7 +119,7 @@ def test_raise_error_when_using_interpolation(self): points, values, method="not cubic" ) self.assertEqual( - "Argument method should be either cubic, linear, or nearest , got not cubic", + "Argument method should be one of ['linear', 'nearest', 'slinear', 'cubic', 'quintic', 'pchip'], got not cubic", str(err.exception), ) # Test raises error if dimension is two. @@ -201,10 +201,10 @@ def gaussian(points): num_pts = 500 random_pts = np.random.uniform(-0.9, 0.9, (num_pts, 3)) interpolated = cubic.interpolate(random_pts, gaussian_pts, use_log=False) - assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-6) + assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-5) - interpolated = cubic.interpolate(random_pts, gaussian_pts, use_log=True) - assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-6) + interpolated_log = cubic.interpolate(random_pts, gaussian_pts, use_log=True) + assert_allclose(interpolated_log, gaussian(random_pts), rtol=1e-5, atol=1e-5) def test_interpolation_of_linear_function_using_scipy_linear_method(self): r"""Test interpolation of a linear function using scipy with linear method."""