Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/grid/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
The interpolation of a function (or of it's derivatives) at a :math:`M` point.

"""
if method not in ["cubic", "linear", "nearest"]:
supported_methods = ["linear", "nearest", "slinear", "cubic", "quintic", "pchip"]
if method not in supported_methods:
raise ValueError(
f"Argument method should be either cubic, linear, or nearest , got {method}"
f"Argument method should be one of {supported_methods}, got {method}"
)
if self.ndim != 3:
raise NotImplementedError(
Expand All @@ -159,12 +160,21 @@ def interpolate(self, points, values, use_log=False, nu_x=0, nu_y=0, nu_z=0, met
if use_log:
values = np.log(values)

# Use scipy if linear and nearest is requested and raise error if it's not cubic.
if method in ["linear", "nearest"]:
# Use scipy if no derivatives are requested
if method in supported_methods and (nu_x == 0 and nu_y == 0 and nu_z == 0):
x, y, z = self.get_points_along_axes()
values = values.reshape(self.shape)
interpolate = RegularGridInterpolator((x, y, z), values, method=method)
return interpolate(points)
interpolated = interpolate(points)
if use_log:
return np.exp(interpolated)
return interpolated

# At this point, derivatives are requested, which requires our custom cubic spline implementation
if method != "cubic":
raise NotImplementedError(
f"Computing analytical derivatives (nu_x={nu_x}, nu_y={nu_y}, nu_z={nu_z}) is only supported for the 'cubic' method."
)

# Interpolate the Z-Axis.
def z_spline(z, x_index, y_index, nu_z=nu_z):
Expand Down
8 changes: 4 additions & 4 deletions src/grid/tests/test_cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_raise_error_when_using_interpolation(self):
points, values, method="not cubic"
)
self.assertEqual(
"Argument method should be either cubic, linear, or nearest , got not cubic",
"Argument method should be one of ['linear', 'nearest', 'slinear', 'cubic', 'quintic', 'pchip'], got not cubic",
str(err.exception),
)
# Test raises error if dimension is two.
Expand Down Expand Up @@ -201,10 +201,10 @@ def gaussian(points):
num_pts = 500
random_pts = np.random.uniform(-0.9, 0.9, (num_pts, 3))
interpolated = cubic.interpolate(random_pts, gaussian_pts, use_log=False)
assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-6)
assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-5)

interpolated = cubic.interpolate(random_pts, gaussian_pts, use_log=True)
assert_allclose(interpolated, gaussian(random_pts), rtol=1e-5, atol=1e-6)
interpolated_log = cubic.interpolate(random_pts, gaussian_pts, use_log=True)
assert_allclose(interpolated_log, gaussian(random_pts), rtol=1e-5, atol=1e-5)

def test_interpolation_of_linear_function_using_scipy_linear_method(self):
r"""Test interpolation of a linear function using scipy with linear method."""
Expand Down