This repository was archived by the owner on Apr 19, 2026. It is now read-only.
Traceback (most recent call last):
  File "test.py", line 15, in <module>
    f2(x)
  File "test.py", line 9, in f1
    y = soft_rank(x)[0]
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 80, in soft_rank
    return jnp.vstack([func(val) for val in values])
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 80, in <listcomp>
    return jnp.vstack([func(val) for val in values])
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 35, in _func_fwd
    values = np.array(values)
jax._src.traceback_util.FilteredStackTrace: Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
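The mechanism described in the error message can be reproduced in plain JAX. The following is a minimal sketch, assuming a recent JAX release where the failure is raised as `jax.errors.TracerArrayConversionError` (the traceback above, from an older release, shows a plain `Exception` with the same message). The functions `uses_numpy`, `uses_jnp` and `lookup` and the array `table` are hypothetical illustrations, not part of `fast_soft_sort`.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration only; they are not part of fast_soft_sort.


def uses_numpy(x):
    # np.asarray calls x.__array__(), which a JAX tracer refuses to implement.
    return np.asarray(x) * 2.0


def uses_jnp(x):
    # jnp functions accept tracers, so this version can be traced and compiled.
    return jnp.asarray(x) * 2.0


x = jnp.arange(3.0)

# Called eagerly, x is a concrete array and the NumPy conversion succeeds.
eager = uses_numpy(x)

# Under jax.jit, x is replaced by a tracer and the NumPy conversion fails.
try:
    jax.jit(uses_numpy)(x)
except jax.errors.TracerArrayConversionError:
    # Recent JAX raises this dedicated type; the older release in the traceback
    # raised a plain Exception carrying the same message.
    jit_numpy_failed = True
else:
    jit_numpy_failed = False

# The jnp version traces, compiles and runs.
jitted = jax.jit(uses_jnp)(x)

# Indexing a raw NumPy array with a tracer: move the array onto the device
# first, as the error message suggests.
table = np.array([10.0, 20.0, 30.0])


@jax.jit
def lookup(idx):
    return jax.device_put(table)[idx]


picked = lookup(jnp.array([2, 0]))
```

In the traceback, however, the `np.array(values)` call sits inside `fast_soft_sort/jax_ops.py` (in `_func_fwd`), not in `test.py`, so switching the imports in `test.py` to `jnp` would not avoid it. Assuming that code path, `soft_rank` appears to need concrete arrays, i.e. it has to be called outside traced code such as a `jax.jit`-compiled function.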
Thanks for the work on this! Here's an example: