diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index e7cc4b3..1a5c833 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12"] + python-version: ["3.10"] steps: diff --git a/test/interpolation_test.py b/test/interpolation_test.py index e69de29..96ef9be 100644 --- a/test/interpolation_test.py +++ b/test/interpolation_test.py @@ -0,0 +1,34 @@ +__author__ = 'aymgal' + +import pytest +import os +import numpy as np +import numpy.testing as npt + +import jax +jax.config.update("jax_enable_x64", True) + +import utax +from utax.interpolation import * + + +class TestBilinearInterpolator(object): + + def setup_method(self): + utax_path = os.path.dirname(utax.__path__[0]) + data_path = os.path.join(utax_path, 'test', 'data') + + # load some test images obtained using from pysparse (sparse2d) from PySAP + self.image = np.load(os.path.join(data_path, 'galaxy_image.npy')) + # reduce size for faster computations + nx, ny = 20, 20 + image = image[30:30+ny, 30:30+nx] + + # create a coordinate grid + self.pix_scl = 0.08 + self.x_coord = np.arange(-3., +3., nx) * self.pix_scl + self.x_coord = np.arange(-3., +3., ny) * self.pix_scl + self.x_grid, self.y_grid = np.meshgrid(self.x_coord, self.y_coord) + + + # TODO: add tests diff --git a/test/wavelet_test.py b/test/wavelet_test.py index 7dd97f3..2549940 100644 --- a/test/wavelet_test.py +++ b/test/wavelet_test.py @@ -14,7 +14,7 @@ from utax.wavelet import * -class TestWaveletTransform(object): +class TestStarletTransform(object): def setup_method(self): utax_path = os.path.dirname(utax.__path__[0]) @@ -62,3 +62,20 @@ def test_scale_norms(self): assert norms.size == self.n_scales + 1 # check that values are decreasing (a Dirac impulse has more power in high frequencies) assert np.all(norms[:-1] > norms[1:]) + + +# class TestBLWTransform(object): + +# def setup_method(self): +# utax_path = os.path.dirname(utax.__path__[0]) +# data_path = os.path.join(utax_path, 'test', 'data') + +# # load some test images obtained using from pysparse (sparse2d) from PySAP +# self.image = np.load(os.path.join(data_path, 'galaxy_image.npy')) +# self.coeffs_bl1 = np.load(os.path.join(data_path, 'galaxy_battle-lemarie-1_coeffs_gen1_pysparse.npy')) +# self.n_scales = self.coeffs_bl1.shape[0]-1 + +# def test_decomposition(self): +# starlet = WaveletTransform(self.n_scales, wavelet_type='battle-lemarie-1', second_gen=False) +# coeffs = starlet.decompose(self.image) +# npt.assert_almost_equal(coeffs, self.coeffs_bl1, decimal=6) diff --git a/utax/info.py b/utax/info.py index 1b84eb3..bf3ade6 100644 --- a/utax/info.py +++ b/utax/info.py @@ -5,7 +5,7 @@ """ # Set the package release version -version_info = (0, 0, 1) +version_info = (0, 0, 2) __version__ = '.'.join(str(c) for c in version_info) # Set the package details @@ -14,10 +14,10 @@ __year__ = '2022' __url__ = 'https://github.com/aymgal/utax' __description__ = 'Utility functions for signal processing, compatible with the differentable programming library JAX.' -__python__ = '>=3.7' +__python__ = '>=3.10' __requires__ = [ - 'jax>=0.3.14', - 'jaxlib>=0.3.14', + 'jax>=0.5.0', + 'jaxlib>=0.5.0', ] # Package dependencies # Default package properties diff --git a/utax/interpolation.py b/utax/interpolation.py index 35ad64f..4e147d5 100644 --- a/utax/interpolation.py +++ b/utax/interpolation.py @@ -16,7 +16,7 @@ class BilinearInterpolator(object): """ def __init__(self, x, y, z, allow_extrapolation=True): - self.z = jnp.array(z) + self.z = jnp.array(z) # Sort x if not increasing x = jnp.array(x) diff --git a/utax/wavelet.py b/utax/wavelet.py index 2d8d307..52637b0 100644 --- a/utax/wavelet.py +++ b/utax/wavelet.py @@ -148,7 +148,7 @@ def _decompose_1st_gen(self, image): return image # Preparations - image = jnp.copy(image) + # image = jnp.copy(image) kernel = self._h.copy() # Compute the first scale: @@ -181,7 +181,7 @@ def _decompose_2nd_gen(self, image): return image # Preparations - image = jnp.copy(image) + # image = jnp.copy(image) kernel = self._h.copy() # Compute the first scale: @@ -216,7 +216,7 @@ def _reconstruct_1st_gen(self, coeffs): @partial(jit, static_argnums=(0,)) def _reconstruct_2nd_gen(self, coeffs): # Validate input - assert coeffs.shape[0] == self._n_scales+1, \ + assert coeffs.shape[-3] == self._n_scales+1, \ "Wavelet coefficients are not consistent with number of scales" if self._n_scales == 0: return coeffs[0, :, :]