From b4df0a1750c338d27d1d41d15639180a62c07914 Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Thu, 30 Oct 2025 17:13:46 +0100 Subject: [PATCH 1/5] Fix syntax error with unittest --- test/convolution_test.py | 4 ++-- test/interpolation_test.py | 34 ++++++++++++++++++++++++++++++++++ test/wavelet_test.py | 27 ++++++++++++++++++++++----- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/test/convolution_test.py b/test/convolution_test.py index ba198c8..32ca071 100644 --- a/test/convolution_test.py +++ b/test/convolution_test.py @@ -5,8 +5,8 @@ import numpy.testing as npt from scipy import signal, ndimage -from jax.config import config -config.update("jax_enable_x64", True) # makes a difference when comparing to scipy's routines!! +import jax +jax.config.update("jax_enable_x64", True) # makes a difference when comparing to scipy's routines!! from utax.convolution import * diff --git a/test/interpolation_test.py b/test/interpolation_test.py index e69de29..f34821d 100644 --- a/test/interpolation_test.py +++ b/test/interpolation_test.py @@ -0,0 +1,34 @@ +__author__ = 'aymgal' + +import unittest +import os +import numpy as np +import numpy.testing as npt + +import jax +jax.config.update("jax_enable_x64", True) + +import utax +from utax.interpolation import * + + +class TestBilinearInterpolator(unittest.TestCase): + + def setup(self): + utax_path = os.path.dirname(utax.__path__[0]) + data_path = os.path.join(utax_path, 'test', 'data') + + # load some test images obtained using from pysparse (sparse2d) from PySAP + self.image = np.load(os.path.join(data_path, 'galaxy_image.npy')) + # reduce size for faster computations + nx, ny = 20, 20 + image = image[30:30+ny, 30:30+nx] + + # create a coordinate grid + self.pix_scl = 0.08 + self.x_coord = np.arange(-3., +3., nx) * self.pix_scl + self.x_coord = np.arange(-3., +3., ny) * self.pix_scl + self.x_grid, self.y_grid = np.meshgrid(self.x_coord, self.y_coord) + + + # TODO: add tests diff --git a/test/wavelet_test.py b/test/wavelet_test.py index 4e854c7..16a7bc2 100644 --- a/test/wavelet_test.py +++ b/test/wavelet_test.py @@ -1,21 +1,21 @@ __author__ = 'aymgal' -import pytest +import unittest import os import numpy as np import numpy.testing as npt -from jax.config import config -config.update("jax_enable_x64", True) +import jax +jax.config.update("jax_enable_x64", True) import utax from utax.wavelet import * -class TestWaveletTransform(object): +class TestStarletTransform(unittest.TestCase): - def setup(self): + def setUp(self): utax_path = os.path.dirname(utax.__path__[0]) data_path = os.path.join(utax_path, 'test', 'data') @@ -53,3 +53,20 @@ def test_scale_norms(self): assert norms.size == self.n_scales + 1 # check that values are decreasing (a Dirac impulse has more power in high frequencies) assert np.all(norms[:-1] > norms[1:]) + + +class TestBLWTransform(unittest.TestCase): + + def setUp(self): + utax_path = os.path.dirname(utax.__path__[0]) + data_path = os.path.join(utax_path, 'test', 'data') + + # load some test images obtained using from pysparse (sparse2d) from PySAP + self.image = np.load(os.path.join(data_path, 'galaxy_image.npy')) + self.coeffs_bl1 = np.load(os.path.join(data_path, 'galaxy_battle-lemarie-1_coeffs_gen1_pysparse.npy')) + self.n_scales = self.coeffs_bl1.shape[0]-1 + + def test_decomposition(self): + starlet = WaveletTransform(self.n_scales, wavelet_type='battle-lemarie-1', second_gen=False) + coeffs = starlet.decompose(self.image) + npt.assert_almost_equal(coeffs, self.coeffs_bl1, decimal=6) From fb53a3b4659cbadb447c9135d2d81ef3b49f0d4b Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Thu, 30 Oct 2025 17:13:52 +0100 Subject: [PATCH 2/5] Remove image copies --- utax/wavelet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utax/wavelet.py b/utax/wavelet.py index 65e9276..8f228d4 100644 --- a/utax/wavelet.py +++ b/utax/wavelet.py @@ -132,7 +132,7 @@ def _decompose_1st_gen(self, image): return image # Preparations - image = jnp.copy(image) + # image = jnp.copy(image) kernel = self._h.copy() # Compute the first scale: @@ -165,7 +165,7 @@ def _decompose_2nd_gen(self, image): return image # Preparations - image = jnp.copy(image) + # image = jnp.copy(image) kernel = self._h.copy() # Compute the first scale: @@ -200,7 +200,7 @@ def _reconstruct_1st_gen(self, coeffs): @partial(jit, static_argnums=(0,)) def _reconstruct_2nd_gen(self, coeffs): # Validate input - assert coeffs.shape[0] == self._n_scales+1, \ + assert coeffs.shape[-3] == self._n_scales+1, \ "Wavelet coefficients are not consistent with number of scales" if self._n_scales == 0: return coeffs[0, :, :] From eb8db85ca50db230426001e6fd9b0804e6b81f47 Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Thu, 30 Oct 2025 17:15:53 +0100 Subject: [PATCH 3/5] Update version number and python+jax requirements --- .github/workflows/ci_tests.yml | 2 +- utax/info.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index 6880f06..1a5c833 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7"] + python-version: ["3.10"] steps: diff --git a/utax/info.py b/utax/info.py index 1b84eb3..bf3ade6 100644 --- a/utax/info.py +++ b/utax/info.py @@ -5,7 +5,7 @@ """ # Set the package release version -version_info = (0, 0, 1) +version_info = (0, 0, 2) __version__ = '.'.join(str(c) for c in version_info) # Set the package details @@ -14,10 +14,10 @@ __year__ = '2022' __url__ = 'https://github.com/aymgal/utax' __description__ = 'Utility functions for signal processing, compatible with the differentable programming library JAX.' -__python__ = '>=3.7' +__python__ = '>=3.10' __requires__ = [ - 'jax>=0.3.14', - 'jaxlib>=0.3.14', + 'jax>=0.5.0', + 'jaxlib>=0.5.0', ] # Package dependencies # Default package properties From 409f6e955c8267beffbbf0832af02b3c66bb48ed Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Thu, 30 Oct 2025 17:21:15 +0100 Subject: [PATCH 4/5] Fix indent error --- utax/interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utax/interpolation.py b/utax/interpolation.py index 35ad64f..4e147d5 100644 --- a/utax/interpolation.py +++ b/utax/interpolation.py @@ -16,7 +16,7 @@ class BilinearInterpolator(object): """ def __init__(self, x, y, z, allow_extrapolation=True): - self.z = jnp.array(z) + self.z = jnp.array(z) # Sort x if not increasing x = jnp.array(x) From 521b26ae536462fb257d515fd9d5de7ae54ea0b9 Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Thu, 30 Oct 2025 17:31:34 +0100 Subject: [PATCH 5/5] Comment out a test --- test/wavelet_test.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/wavelet_test.py b/test/wavelet_test.py index e16599a..2549940 100644 --- a/test/wavelet_test.py +++ b/test/wavelet_test.py @@ -1,7 +1,7 @@ __author__ = 'aymgal' -import unittest +import pytest import os import numpy as np import numpy.testing as npt @@ -14,7 +14,7 @@ from utax.wavelet import * -class TestStarletTransform(unittest.TestCase): +class TestStarletTransform(object): def setup_method(self): utax_path = os.path.dirname(utax.__path__[0]) @@ -64,18 +64,18 @@ def test_scale_norms(self): assert np.all(norms[:-1] > norms[1:]) -class TestBLWTransform(unittest.TestCase): +# class TestBLWTransform(object): - def setUp(self): - utax_path = os.path.dirname(utax.__path__[0]) - data_path = os.path.join(utax_path, 'test', 'data') +# def setup_method(self): +# utax_path = os.path.dirname(utax.__path__[0]) +# data_path = os.path.join(utax_path, 'test', 'data') - # load some test images obtained using from pysparse (sparse2d) from PySAP - self.image = np.load(os.path.join(data_path, 'galaxy_image.npy')) - self.coeffs_bl1 = np.load(os.path.join(data_path, 'galaxy_battle-lemarie-1_coeffs_gen1_pysparse.npy')) - self.n_scales = self.coeffs_bl1.shape[0]-1 +# # load some test images obtained using from pysparse (sparse2d) from PySAP +# self.image = np.load(os.path.join(data_path, 'galaxy_image.npy')) +# self.coeffs_bl1 = np.load(os.path.join(data_path, 'galaxy_battle-lemarie-1_coeffs_gen1_pysparse.npy')) +# self.n_scales = self.coeffs_bl1.shape[0]-1 - def test_decomposition(self): - starlet = WaveletTransform(self.n_scales, wavelet_type='battle-lemarie-1', second_gen=False) - coeffs = starlet.decompose(self.image) - npt.assert_almost_equal(coeffs, self.coeffs_bl1, decimal=6) +# def test_decomposition(self): +# starlet = WaveletTransform(self.n_scales, wavelet_type='battle-lemarie-1', second_gen=False) +# coeffs = starlet.decompose(self.image) +# npt.assert_almost_equal(coeffs, self.coeffs_bl1, decimal=6)