Hello,
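When solving the (trivial) SDE dy = -y dt + 0.2 dW, the diffrax Euler solver is far slower than a homemade Euler loop (5.39 ms vs. 19.7 μs per call in the benchmark below). The gap does not seem to depend on dt or the number of trajectories, and it appears to be specific to SDE solvers.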
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2
# === diffrax euler
brownian_motion = dx.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(ts=jnp.linspace(t0, t1, ndt))
@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat).ys
# === homemade euler
@jax.jit
def homemade_simu():
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))
    def step(y, dW):
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y
    return jax.lax.scan(step, 1.0, dWs)[-1]
# === plot a single trajectory
y = diffrax_simu()
plt.plot(y)
y = homemade_simu()
plt.plot(y)
# === benchmark
%timeit diffrax_simu().block_until_ready()
%timeit homemade_simu().block_until_ready()
5.39 ms ± 261 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 μs ± 899 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
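For anyone reproducing this outside IPython, where `%timeit` is not available, here is a minimal sketch of a timing helper. The name `bench` and its parameters are my own, not part of diffrax or JAX. It makes warm-up calls first so that JIT compilation is not counted, and it waits on `block_until_ready` when the result provides it.

```python
import time


def bench(fn, repeats=100, warmup=1):
    """Return the mean wall-clock time per call of fn, in seconds.

    `bench` is a hypothetical helper, not part of diffrax or JAX.
    """
    def run():
        out = fn()
        # JAX dispatches work asynchronously, so wait for the result
        # when the returned object supports it.
        block = getattr(out, "block_until_ready", None)
        if callable(block):
            block()
        return out

    # Warm-up calls trigger JIT compilation so it is excluded from the timing.
    for _ in range(warmup):
        run()

    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) / repeats
```

With the functions defined above, this would be called as `bench(diffrax_simu)` and `bench(homemade_simu)`. The numbers will not match `%timeit` exactly, since `%timeit` picks its own loop counts.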