diff --git a/drjax/_src/primitives_test.py b/drjax/_src/primitives_test.py index 87f4246..569a4c1 100644 --- a/drjax/_src/primitives_test.py +++ b/drjax/_src/primitives_test.py @@ -21,7 +21,8 @@ from drjax._src import impls from drjax._src import primitives import jax -from jax import numpy as jnp +from jax.extend import core as jex_core +import jax.numpy as jnp from jax.sharding import AxisType # pylint: disable=g-importing-member import numpy as np @@ -31,7 +32,13 @@ def _jaxpr_has_primitive(jaxpr, prim_name: str): for eqn in jaxpr.eqns: if prim_name in eqn.primitive.name: return True - for subjaxpr in jax.core.subjaxprs(jaxpr): + try: + # JAX v0.10.0 and newer + subjaxprs = jex_core.subjaxprs + except AttributeError: + # JAX v0.9.2 and older + subjaxprs = jax.core.subjaxprs + for subjaxpr in subjaxprs(jaxpr): if _jaxpr_has_primitive(subjaxpr, prim_name): return True return False