From 35163777b3be942127a8f425da364dcd78cbb091 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 2 Apr 2026 12:49:04 -0700 Subject: [PATCH] Finalize deprecation of several `jax.core` APIs for JAX v0.10.0 As part of this, add three new APIs to `jax.extend.core`: - `jex.core.AbstractToken` - `jex.core.call_impl` - `jex.core.subjaxprs` PiperOrigin-RevId: 893650230 --- drjax/_src/primitives_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/drjax/_src/primitives_test.py b/drjax/_src/primitives_test.py index 87f4246..569a4c1 100644 --- a/drjax/_src/primitives_test.py +++ b/drjax/_src/primitives_test.py @@ -21,7 +21,8 @@ from drjax._src import impls from drjax._src import primitives import jax -from jax import numpy as jnp +from jax.extend import core as jex_core +import jax.numpy as jnp from jax.sharding import AxisType # pylint: disable=g-importing-member import numpy as np @@ -31,7 +32,13 @@ def _jaxpr_has_primitive(jaxpr, prim_name: str): for eqn in jaxpr.eqns: if prim_name in eqn.primitive.name: return True - for subjaxpr in jax.core.subjaxprs(jaxpr): + try: + # JAX v0.10.0 and newer + subjaxprs = jex_core.subjaxprs + except AttributeError: + # JAX v0.9.2 and older + subjaxprs = jax.core.subjaxprs + for subjaxpr in subjaxprs(jaxpr): if _jaxpr_has_primitive(subjaxpr, prim_name): return True return False