From 5f4deff81d1a9bb0cf8f5c120f3763e37fccfe52 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Tue, 1 Feb 2022 06:26:27 -0800 Subject: [PATCH] Add re-usable interp util (better performances) PiperOrigin-RevId: 425609213 --- etils/enp/README.md | 2 +- etils/enp/interp_utils.py | 52 +++++++++++++++++++++++++++------- etils/enp/interp_utils_test.py | 29 +++++++++++++++++-- 3 files changed, 70 insertions(+), 13 deletions(-) diff --git a/etils/enp/README.md b/etils/enp/README.md index 7fee9b50..05f6c3b1 100644 --- a/etils/enp/README.md +++ b/etils/enp/README.md @@ -30,7 +30,7 @@ Examples: * Normalize `np.uint8` image to `np.float32`: ```python - img = enp.interp(img, (0, 255), (-1, 1)) + img = enp.interp(img, from_=(0, 255), to=(-1, 1)) ``` * Converting normalized 3d coordinates to world coordinates: diff --git a/etils/enp/interp_utils.py b/etils/enp/interp_utils.py index 3f1a52e4..4d74bb6c 100644 --- a/etils/enp/interp_utils.py +++ b/etils/enp/interp_utils.py @@ -14,7 +14,10 @@ """Interpolate utils.""" -from typing import Tuple, Union +from __future__ import annotations + +import functools +from typing import Callable, Optional, Union from etils.array_types import Array, ArrayLike, FloatArray # pylint: disable=g-multiple-import from etils.enp import numpy_utils @@ -26,11 +29,12 @@ def interp( - x: Array['*d'], - from_: Tuple[_MinMaxValue, _MinMaxValue], - to: Tuple[_MinMaxValue, _MinMaxValue], + x: Optional[Array['*d']] = None, + *, + from_: tuple[_MinMaxValue, _MinMaxValue], + to: tuple[_MinMaxValue, _MinMaxValue], axis: int = -1, -) -> FloatArray['*d']: +) -> Union[FloatArray['*d'], Callable[[Array['*d']], FloatArray['*d']]]: """Linearly scale the given value by the given range. Somehow similar to `np.interp` or `scipy.interpolate.inter1d` with some @@ -47,7 +51,7 @@ def interp( [0, 0], [127, 255], ]) - img = enp.interp(img, (0, 255), (0, 1)) + img = enp.interp(img, from_=(0, 255), to=(0, 1)) img == jnp.array([ [-1, -1], [0.498..., 1], @@ -68,15 +72,28 @@ def interp( * `coords[:, 1]` is interpolated from `(-1, 1)` to `(0, w)` * `coords[:, 2]` is interpolated from `(-1, 1)` to `(0, d)` + To apply the same interpolation on multiple arrays, you can use: + + ```python + interp_fn = enp.interp(from_=(-1, 1), to=(0, 255)) + img0 = interp_fn(img0) + img1 = interp_fn(img1) + ... + ``` + Args: - x: Array to scale + x: Array to scale. If not set, this function return a function which can + be applied multiple times (for faster performance). from_: Range of x. to: Range to which normalize x. axis: Axis on which normalizing. Only relevant if `from_` or `to` items contains range value. Returns: - Float tensor with same shape as x, but with normalized coordinates. + If `x` is set: Float tensor with same shape as x, but with normalized + coordinates. + If `x` missing: A callable with signature: `interp_fn(x) -> FloatArray` + to interpolate multiple arrays with the same from/to factors. """ # Could add an `axis` argument. # Could add an `fill_values` argument to indicates the behavior if input @@ -98,7 +115,12 @@ def interp( # `a` can be scalar or array of shape=(x.shape[-1],), same for `b` a, b = _linear_interp_factors(*from_, *to) # pytype: disable=wrong-arg-types - return a * x + b + + if x is None: + fn = functools.partial(_apply_interp, a=a, b=b) + return functools.wraps(_apply_interp)(fn) + else: + return _apply_interp(x, a=a, b=b) def _linear_interp_factors( @@ -106,8 +128,18 @@ def _linear_interp_factors( old_max: _MinMaxValue, new_min: _MinMaxValue, new_max: _MinMaxValue, -) -> Tuple[Union[float, FloatArray['d']], Union[float, FloatArray['d']]]: +) -> tuple[Union[float, FloatArray['d']], Union[float, FloatArray['d']]]: """Resolve the `y = a * x + b` equation and returns the factors.""" a = (new_min - new_max) / (old_min - old_max) b = (old_min * new_max - new_min * old_max) / (old_min - old_max) return a, b + + +def _apply_interp( + x: Array['*d'], + *, + a: _MinMaxValue, + b: _MinMaxValue, +) -> Array['*d']: + """Apply the interpolation with pre-computed factors.""" + return a * x + b diff --git a/etils/enp/interp_utils_test.py b/etils/enp/interp_utils_test.py index ded0edbe..105d8ad0 100644 --- a/etils/enp/interp_utils_test.py +++ b/etils/enp/interp_utils_test.py @@ -101,7 +101,7 @@ def test_interp_coords(xnp): [1, 1], ]) assert xnp.allclose( - enp.interp(coords, (-1, 1), (0, (1024, 256))), + enp.interp(coords, from_=(-1, 1), to=(0, (1024, 256))), xnp.array([ [0, 0], [0, 128], @@ -115,8 +115,33 @@ def test_interp_coords(xnp): [[256, 256], [0, 768]], ]) assert xnp.allclose( - enp.interp(coords, (0, (256, 1024)), (0, 1)), + enp.interp(coords, from_=(0, (256, 1024)), to=(0, 1)), xnp.array([ [[0, 0], [0, 1]], [[1, 0.25], [0, 0.75]], ])) + + +@pytest.mark.parametrize('xnp', [np, jnp, tnp]) +def test_interp_function(xnp): + + vals = xnp.array([ + [-1, -1], + [-1, 0], + [-1, 1], + [0.5, 1], + [1, 1], + ]) + expected = xnp.array([ + [0, 0], + [0, 128], + [0, 256], + [192, 256], + [256, 256], + ]) + + interp_fn = enp.interp(from_=(-1, 1), to=(0, 256)) + + for val, val_expected in zip(vals, expected): + val_out = interp_fn(val) + assert xnp.allclose(val_out, val_expected)