
Performance issue with SDE solver #517

@pierreguilmin


Hello,

When solving the (trivial) SDE $dy_t = -y_t\,dt + 0.2\,dW_t$, the Diffrax `Euler` solver is more than 200x slower than a naive hand-written Euler loop built on `jax.lax.scan` (5.39 ms vs 19.7 μs per call in the benchmark below, about 270x). Am I doing something wrong? The gap is consistent across the various SDEs, solvers, time steps `dt`, and numbers of trajectories I tried, and it appears to be specific to SDE solvers.
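For reference, both implementations discretise this SDE with the Euler–Maruyama scheme. With $t_0 = 0$, $t_1 = 1$, $y_0 = 1$ and step $\Delta t = (t_1 - t_0)/(n_{dt} - 1) = 0.01$, each step computes:

$$
y_{n+1} = y_n - y_n\,\Delta t + 0.2\,\Delta W_n, \qquad \Delta W_n = W_{t_{n+1}} - W_{t_n} \sim \mathcal{N}(0, \Delta t).
$$

In the homemade version the increments are drawn up front as $\Delta W_n = \sqrt{\Delta t}\,\xi_n$ with $\xi_n \sim \mathcal{N}(0, 1)$; in the Diffrax version they are obtained by querying the `VirtualBrownianTree` over each step interval.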

```python
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt

# === simulation parameters
key = jax.random.PRNGKey(42)
t0 = 0
t1 = 1
y0 = 1.0
ndt = 101  # number of saved time points, i.e. ndt - 1 steps
dt = (t1 - t0) / (ndt - 1)
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.2

# === diffrax euler
brownian_motion = dx.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
solver = dx.Euler()
terms = dx.MultiTerm(dx.ODETerm(drift), dx.ControlTerm(diffusion, brownian_motion))
saveat = dx.SaveAt(ts=jnp.linspace(t0, t1, ndt))

@jax.jit
def diffrax_simu():
    return dx.diffeqsolve(terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat).ys

# === homemade euler
@jax.jit
def homemade_simu():
    # Brownian increments dW ~ N(0, dt), all drawn in one call;
    # the last one only updates the final carry, which is discarded
    dWs = jnp.sqrt(dt) * jax.random.normal(key, (ndt,))

    def step(y, dW):
        # one Euler-Maruyama step: y + f(y) dt + g(y) dW
        dy = drift(None, y, None) * dt + diffusion(None, y, None) * dW
        return y + dy, y  # (new carry, saved state before the step)

    # scan returns (final_carry, stacked_outputs); keep the ndt saved states
    return jax.lax.scan(step, y0, dWs)[-1]

# === plot a single trajectory
# (the two use different Brownian samples, so the paths are not expected to match)
y = diffrax_simu()
plt.plot(y, label="diffrax")
y = homemade_simu()
plt.plot(y, label="homemade")
plt.legend()

# === benchmark (IPython/Jupyter %timeit magic; both functions were
# already compiled by the plotting calls above)
%timeit diffrax_simu().block_until_ready()
%timeit homemade_simu().block_until_ready()
```

```
5.39 ms ± 261 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 μs ± 899 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
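`%timeit` is an IPython/Jupyter magic, so the benchmark lines above only run in a notebook. Below is a minimal standard-library sketch of the same measurement for plain scripts. `bench` is a hypothetical helper (not part of JAX or Diffrax), and it assumes the timed function returns either a JAX array or some object without a `block_until_ready` method. It makes one untimed warm-up call so that `jax.jit` compilation is excluded (in the notebook, the plotting calls play this role), and it waits on each result because JAX dispatches work asynchronously.

```python
import statistics
import time


def bench(fn, *, repeat=7, number=100):
    """Hypothetical helper: (mean, stdev) seconds per call of fn().

    repeat must be >= 2 so that statistics.stdev is defined.
    """

    def run():
        out = fn()
        # JAX arrays are computed asynchronously; wait for the result
        # so that the timing includes the actual computation.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()

    run()  # untimed warm-up call: triggers jax.jit compilation

    per_call = []
    for _ in range(repeat):
        start = time.perf_counter()
        for _ in range(number):
            run()
        per_call.append((time.perf_counter() - start) / number)
    return statistics.mean(per_call), statistics.stdev(per_call)
```

With the functions from the script, `bench(diffrax_simu)` and `bench(homemade_simu)` each return a `(mean, stdev)` pair in seconds per call, mirroring the `7 runs` layout that `%timeit` reports.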
