diff --git a/.github/workflows/build-ci.yml b/.github/workflows/build-ci.yml index 0c67763..d064339 100644 --- a/.github/workflows/build-ci.yml +++ b/.github/workflows/build-ci.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.13" cache: "pip" cache-dependency-path: pyproject.toml @@ -39,7 +39,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.13" cache: "pip" cache-dependency-path: pyproject.toml @@ -56,55 +56,12 @@ jobs: run: | darglint src --strictness=short --ignore-raise=ValueError - tests-jax-latest: + tests: runs-on: ubuntu-latest - timeout-minutes: 2 - steps: - - name: Checkout repository - uses: actions/checkout@v5 + strategy: + matrix: + python-version: ["3.10", "3.13"] - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: "3.10" - cache: "pip" - cache-dependency-path: pyproject.toml - - - name: Setup environment - run: | - python -m pip install --upgrade pip - pip install ".[tests,dev]" - - - name: Run Python tests - run: | - pytest --cov=agjax tests - - tests-jax-0_4_35: - runs-on: ubuntu-latest - timeout-minutes: 2 - steps: - - name: Checkout repository - uses: actions/checkout@v5 - - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: "3.10" - cache: "pip" - cache-dependency-path: pyproject.toml - - - name: Setup environment - run: | - python -m pip install --upgrade pip - pip install --upgrade "jax[cpu]==0.4.35" - pip install ".[tests,dev]" - - - name: Run Python tests - run: | - pytest --cov=agjax tests - - tests-jax-0_4_27: - runs-on: ubuntu-latest timeout-minutes: 2 steps: - name: Checkout repository @@ -113,21 +70,20 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: ${{ matrix.python-version }} cache: "pip" cache-dependency-path: pyproject.toml - name: Setup environment run: | python -m pip install --upgrade pip - pip install --upgrade "jax[cpu]==0.4.27" pip install ".[tests,dev]" - name: Run Python tests run: | pytest --cov=agjax tests - test_docs: + test-docs: runs-on: ubuntu-latest timeout-minutes: 3 needs: [pre-commit] @@ -138,7 +94,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: '3.10' + python-version: '3.13' cache: "pip" cache-dependency-path: pyproject.toml diff --git a/tests/experimental/test_wrapper.py b/tests/experimental/test_wrapper.py index 00bce8e..e123272 100644 --- a/tests/experimental/test_wrapper.py +++ b/tests/experimental/test_wrapper.py @@ -7,17 +7,12 @@ import autograd.numpy as npa import jax import jax.numpy as jnp -import jaxlib import numpy as onp from parameterized import parameterized from agjax import utils from agjax.experimental import wrapper -if hasattr(jaxlib, "xla_extension"): - JaxError = jaxlib.xla_extension.XlaRuntimeError -else: - JaxError = jaxlib._jax.XlaRuntimeError TEST_FNS_AND_ARGS = ( ( # Basic scalar-valued function, real outputs. @@ -79,7 +74,7 @@ def fn(x, y): ) with self.assertRaisesRegex(ValueError, "Found out of bounds"): wrapped(1.0, 2.0) - with self.assertRaisesRegex(JaxError, "Found out of bounds"): + with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "Found out of bounds"): jax.grad(wrapped)(1.0, 2.0) @parameterized.expand(([2], [-3])) @@ -94,7 +89,7 @@ def fn(x, y): ) with self.assertRaisesRegex(ValueError, "Found out of bounds"): wrapped(1.0, 2.0) - with self.assertRaisesRegex(JaxError, "Found out of bounds"): + with self.assertRaisesRegex(jax.errors.JaxRuntimeError, "Found out of bounds"): jax.grad(wrapped)(1.0, 2.0) @parameterized.expand(([(1, 1)], [(1, -1)])) @@ -121,9 +116,13 @@ def fn(x, y): ) with self.assertRaisesRegex(ValueError, "At least one differentiable output"): wrapped(1.0, 2.0) - with self.assertRaisesRegex(JaxError, "At least one differentiable output"): + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, "At least one differentiable output" + ): jax.grad(wrapped)(1.0, 2.0) - with self.assertRaisesRegex(JaxError, "At least one differentiable output"): + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, "At least one differentiable output" + ): jax.value_and_grad(wrapped)(1.0, 2.0) @parameterized.expand(TEST_FNS_AND_ARGS)