Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion etils/enp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Examples:
* Normalize `np.uint8` image to `np.float32`:

```python
img = enp.interp(img, (0, 255), (-1, 1))
img = enp.interp(img, from_=(0, 255), to=(-1, 1))
```

* Converting normalized 3d coordinates to world coordinates:
Expand Down
52 changes: 42 additions & 10 deletions etils/enp/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

"""Interpolate utils."""

from typing import Tuple, Union
from __future__ import annotations

import functools
from typing import Callable, Optional, Union

from etils.array_types import Array, ArrayLike, FloatArray # pylint: disable=g-multiple-import
from etils.enp import numpy_utils
Expand All @@ -26,11 +29,12 @@


def interp(
x: Array['*d'],
from_: Tuple[_MinMaxValue, _MinMaxValue],
to: Tuple[_MinMaxValue, _MinMaxValue],
x: Optional[Array['*d']] = None,
*,
from_: tuple[_MinMaxValue, _MinMaxValue],
to: tuple[_MinMaxValue, _MinMaxValue],
axis: int = -1,
) -> FloatArray['*d']:
) -> Union[FloatArray['*d'], Callable[[Array['*d']], FloatArray['*d']]]:
"""Linearly scale the given value by the given range.

Somehow similar to `np.interp` or `scipy.interpolate.inter1d` with some
Expand All @@ -47,7 +51,7 @@ def interp(
[0, 0],
[127, 255],
])
img = enp.interp(img, (0, 255), (0, 1))
img = enp.interp(img, from_=(0, 255), to=(0, 1))
img == jnp.array([
[-1, -1],
[0.498..., 1],
Expand All @@ -68,15 +72,28 @@ def interp(
* `coords[:, 1]` is interpolated from `(-1, 1)` to `(0, w)`
* `coords[:, 2]` is interpolated from `(-1, 1)` to `(0, d)`

To apply the same interpolation on multiple arrays, you can use:

```python
interp_fn = enp.interp(from_=(-1, 1), to=(0, 255))
img0 = interp_fn(img0)
img1 = interp_fn(img1)
...
```

Args:
x: Array to scale
x: Array to scale. If not set, this function return a function which can
be applied multiple times (for faster performance).
from_: Range of x.
to: Range to which normalize x.
axis: Axis on which normalizing. Only relevant if `from_` or `to` items
contains range value.

Returns:
Float tensor with same shape as x, but with normalized coordinates.
If `x` is set: Float tensor with same shape as x, but with normalized
coordinates.
If `x` missing: A callable with signature: `interp_fn(x) -> FloatArray`
to interpolate multiple arrays with the same from/to factors.
"""
# Could add an `axis` argument.
# Could add an `fill_values` argument to indicates the behavior if input
Expand All @@ -98,16 +115,31 @@ def interp(

# `a` can be scalar or array of shape=(x.shape[-1],), same for `b`
a, b = _linear_interp_factors(*from_, *to) # pytype: disable=wrong-arg-types
return a * x + b

if x is None:
fn = functools.partial(_apply_interp, a=a, b=b)
return functools.wraps(_apply_interp)(fn)
else:
return _apply_interp(x, a=a, b=b)


def _linear_interp_factors(
old_min: _MinMaxValue,
old_max: _MinMaxValue,
new_min: _MinMaxValue,
new_max: _MinMaxValue,
) -> Tuple[Union[float, FloatArray['d']], Union[float, FloatArray['d']]]:
) -> tuple[Union[float, FloatArray['d']], Union[float, FloatArray['d']]]:
"""Resolve the `y = a * x + b` equation and returns the factors."""
a = (new_min - new_max) / (old_min - old_max)
b = (old_min * new_max - new_min * old_max) / (old_min - old_max)
return a, b


def _apply_interp(
x: Array['*d'],
*,
a: _MinMaxValue,
b: _MinMaxValue,
) -> Array['*d']:
"""Apply the interpolation with pre-computed factors."""
return a * x + b
29 changes: 27 additions & 2 deletions etils/enp/interp_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_interp_coords(xnp):
[1, 1],
])
assert xnp.allclose(
enp.interp(coords, (-1, 1), (0, (1024, 256))),
enp.interp(coords, from_=(-1, 1), to=(0, (1024, 256))),
xnp.array([
[0, 0],
[0, 128],
Expand All @@ -115,8 +115,33 @@ def test_interp_coords(xnp):
[[256, 256], [0, 768]],
])
assert xnp.allclose(
enp.interp(coords, (0, (256, 1024)), (0, 1)),
enp.interp(coords, from_=(0, (256, 1024)), to=(0, 1)),
xnp.array([
[[0, 0], [0, 1]],
[[1, 0.25], [0, 0.75]],
]))


@pytest.mark.parametrize('xnp', [np, jnp, tnp])
def test_interp_function(xnp):

vals = xnp.array([
[-1, -1],
[-1, 0],
[-1, 1],
[0.5, 1],
[1, 1],
])
expected = xnp.array([
[0, 0],
[0, 128],
[0, 256],
[192, 256],
[256, 256],
])

interp_fn = enp.interp(from_=(-1, 1), to=(0, 256))

for val, val_expected in zip(vals, expected):
val_out = interp_fn(val)
assert xnp.allclose(val_out, val_expected)