Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]
python-version: ["3.10"]

steps:

Expand Down
34 changes: 34 additions & 0 deletions test/interpolation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
__author__ = 'aymgal'

import pytest
import os
import numpy as np
import numpy.testing as npt

import jax
jax.config.update("jax_enable_x64", True)

import utax
from utax.interpolation import *


class TestBilinearInterpolator(object):

def setup_method(self):
utax_path = os.path.dirname(utax.__path__[0])
data_path = os.path.join(utax_path, 'test', 'data')

# load some test images obtained using from pysparse (sparse2d) from PySAP
self.image = np.load(os.path.join(data_path, 'galaxy_image.npy'))
# reduce size for faster computations
nx, ny = 20, 20
image = image[30:30+ny, 30:30+nx]

# create a coordinate grid
self.pix_scl = 0.08
self.x_coord = np.arange(-3., +3., nx) * self.pix_scl
self.x_coord = np.arange(-3., +3., ny) * self.pix_scl
self.x_grid, self.y_grid = np.meshgrid(self.x_coord, self.y_coord)


# TODO: add tests
19 changes: 18 additions & 1 deletion test/wavelet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from utax.wavelet import *


class TestWaveletTransform(object):
class TestStarletTransform(object):

def setup_method(self):
utax_path = os.path.dirname(utax.__path__[0])
Expand Down Expand Up @@ -62,3 +62,20 @@ def test_scale_norms(self):
assert norms.size == self.n_scales + 1
# check that values are decreasing (a Dirac impulse has more power in high frequencies)
assert np.all(norms[:-1] > norms[1:])


# class TestBLWTransform(object):

# def setup_method(self):
# utax_path = os.path.dirname(utax.__path__[0])
# data_path = os.path.join(utax_path, 'test', 'data')

# # load some test images obtained using from pysparse (sparse2d) from PySAP
# self.image = np.load(os.path.join(data_path, 'galaxy_image.npy'))
# self.coeffs_bl1 = np.load(os.path.join(data_path, 'galaxy_battle-lemarie-1_coeffs_gen1_pysparse.npy'))
# self.n_scales = self.coeffs_bl1.shape[0]-1

# def test_decomposition(self):
# starlet = WaveletTransform(self.n_scales, wavelet_type='battle-lemarie-1', second_gen=False)
# coeffs = starlet.decompose(self.image)
# npt.assert_almost_equal(coeffs, self.coeffs_bl1, decimal=6)
8 changes: 4 additions & 4 deletions utax/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

# Set the package release version
version_info = (0, 0, 1)
version_info = (0, 0, 2)
__version__ = '.'.join(str(c) for c in version_info)

# Set the package details
Expand All @@ -14,10 +14,10 @@
__year__ = '2022'
__url__ = 'https://github.com/aymgal/utax'
__description__ = 'Utility functions for signal processing, compatible with the differentable programming library JAX.'
__python__ = '>=3.7'
__python__ = '>=3.10'
__requires__ = [
'jax>=0.3.14',
'jaxlib>=0.3.14',
'jax>=0.5.0',
'jaxlib>=0.5.0',
] # Package dependencies

# Default package properties
Expand Down
2 changes: 1 addition & 1 deletion utax/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BilinearInterpolator(object):

"""
def __init__(self, x, y, z, allow_extrapolation=True):
self.z = jnp.array(z)
self.z = jnp.array(z)

# Sort x if not increasing
x = jnp.array(x)
Expand Down
6 changes: 3 additions & 3 deletions utax/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _decompose_1st_gen(self, image):
return image

# Preparations
image = jnp.copy(image)
# image = jnp.copy(image)
kernel = self._h.copy()

# Compute the first scale:
Expand Down Expand Up @@ -181,7 +181,7 @@ def _decompose_2nd_gen(self, image):
return image

# Preparations
image = jnp.copy(image)
# image = jnp.copy(image)
kernel = self._h.copy()

# Compute the first scale:
Expand Down Expand Up @@ -216,7 +216,7 @@ def _reconstruct_1st_gen(self, coeffs):
@partial(jit, static_argnums=(0,))
def _reconstruct_2nd_gen(self, coeffs):
# Validate input
assert coeffs.shape[0] == self._n_scales+1, \
assert coeffs.shape[-3] == self._n_scales+1, \
"Wavelet coefficients are not consistent with number of scales"
if self._n_scales == 0:
return coeffs[0, :, :]
Expand Down