Nov 2025: Added "ePC" (Goemaere et al., 2025)
Oct 2025: Added bidirectional PC (bPC, Oliviers et al., 2025)
May 2025: Added "μPC" (Innocenti et al., 2025)
JPC is a JAX library for training neural networks with Predictive Coding (PC).
JPC provides a simple, fast and flexible API for training a variety of PC networks (PCNs), including discriminative, generative and hybrid models.
- Like JAX, JPC is completely functional in design, and the core library is <1000 lines of code.
- Unlike existing implementations, JPC provides a wide range of discrete and continuous optimisers for solving the inference dynamics of PC, including ordinary differential equation (ODE) solvers.
- JPC also provides some analytical tools that can be used to study and potentially diagnose issues with PCNs.
If you're new to JPC, we recommend starting from the example notebooks and checking the documentation.
Clone the repo and, from the project's directory, run
pip install .
Requires Python 3.10+ and JAX 0.4.38–0.5.2 (inclusive). For GPU usage, upgrade JAX to the appropriate CUDA version (CUDA 12 in the example below).
pip install --upgrade "jax[cuda12]"
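After installing the CUDA build, you can confirm that JAX sees the GPU with a quick check using plain JAX (independent of JPC):
import jax
print(jax.devices())  # should list a GPU device, e.g. [CudaDevice(id=0)], rather than only a CPU device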
Available at https://thebuckleylab.github.io/jpc/.
Use jpc.make_pc_step() to update the parameters of any neural network compatible with PC updates (see the example notebooks).
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc
# toy data
x = jnp.array([1., 1., 1.])
y = -x
# define model and optimiser
key = jr.PRNGKey(0)
model = jpc.make_mlp(
key,
input_dim=3,
width=50,
depth=5,
output_dim=3,
act_fn="relu"
)
optim = optax.adam(1e-3)
opt_state = optim.init(
(eqx.filter(model, eqx.is_array), None)
)
# perform one training step with PC
result = jpc.make_pc_step(
model=model,
optim=optim,
opt_state=opt_state,
output=y,
input=x
)
# updated model and optimiser
model, opt_state = result["model"], result["opt_state"]

Under the hood, jpc.make_pc_step():
- integrates the inference (activity) dynamics using a diffrax ODE solver, and
- updates model parameters at the numerical solution of the activities with a given optax optimiser.
See the documentation for more details.
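If you want to experiment with the inference solver itself, the snippet below is a hedged sketch reusing the model and data from the example above. It assumes make_pc_step() exposes the diffrax solver through a keyword argument (called ode_solver here purely for illustration); check the API docs for the exact argument name and defaults.
import diffrax as dfx
import jpc

# Illustration only: `ode_solver` is an assumed keyword; the actual argument
# for selecting a diffrax solver is given in the API docs.
result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=x,
    ode_solver=dfx.Heun()  # any explicit diffrax solver, e.g. Heun or Euler
)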
NOTE: All convenience training and test functions such as make_pc_step() are already jitted for optimised performance.
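Because these functions are pre-jitted, make_pc_step() can simply be called inside a plain Python training loop. Below is a minimal sketch reusing the toy data, model and optimiser defined above; the final line assumes that the last activity returned by jpc.init_activities_with_ffwd() is the network's output prediction.
import jpc

# train for a number of steps, threading the updated model and optimiser state
for _ in range(100):
    result = jpc.make_pc_step(
        model=model,
        optim=optim,
        opt_state=opt_state,
        output=y,
        input=x
    )
    model, opt_state = result["model"], result["opt_state"]

# feedforward pass; we assume the last activity is the model's prediction
pred = jpc.init_activities_with_ffwd(model=model, input=x)[-1]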
Advanced users can access all the underlying functions of jpc.make_pc_step()
as well as additional features. A custom PC training step looks like the
following:
import jpc
# 1. initialise activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(model=model, input=x)
# 2. perform inference (state optimisation) with a separate optimiser for the activities
activity_optim = optax.sgd(1e-2)  # example choice of inference optimiser (optax is imported above)
activity_opt_state = activity_optim.init(activities)
for _ in range(len(model)):
activity_update_result = jpc.update_pc_activities(
params=(model, None),
activities=activities,
optim=activity_optim,
opt_state=activity_opt_state,
output=y,
input=x
)
activities = activity_update_result["activities"]
activity_opt_state = activity_update_result["opt_state"]
# 3. update parameters at the activities' solution with PC
result = jpc.update_params(
params=(model, None),
activities=activities,
optim=optim,
opt_state=opt_state,
output=y,
input=x
)

This custom step can be embedded in a jitted function together with any other computations. Again, see the docs for details.
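As a minimal sketch (not part of JPC's API), the three steps above could be wrapped with equinox.filter_jit; the helper name custom_pc_step and the fixed number of inference iterations are illustrative.
import equinox as eqx
import jpc

@eqx.filter_jit
def custom_pc_step(model, optim, opt_state, activity_optim, y, x, n_inference_steps=20):
    # 1. initialise activities with a feedforward pass
    activities = jpc.init_activities_with_ffwd(model=model, input=x)
    # 2. inference (state optimisation) for a fixed number of steps
    activity_opt_state = activity_optim.init(activities)
    for _ in range(n_inference_steps):
        activity_update_result = jpc.update_pc_activities(
            params=(model, None),
            activities=activities,
            optim=activity_optim,
            opt_state=activity_opt_state,
            output=y,
            input=x
        )
        activities = activity_update_result["activities"]
        activity_opt_state = activity_update_result["opt_state"]
    # 3. parameter update at the inferred activities
    return jpc.update_params(
        params=(model, None),
        activities=activities,
        optim=optim,
        opt_state=opt_state,
        output=y,
        input=x
    )
Passing the optimisers as arguments keeps the sketch self-contained; in practice they could equally be closed over or stored alongside the model.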
Contributions are welcome! Fork the repo, install in editable mode (pip install -e .), then:
- Run ruff check . before committing (auto-fix with ruff check --fix .)
- Ensure all tests pass: pytest tests/
- Add docstrings to public functions and update docs/ for user-facing changes
- Open a PR with a clear description
For major features, open an issue first to discuss.
If you found this library useful in your work, please cite the paper:
@article{innocenti2024jpc,
title={JPC: Flexible Inference for Predictive Coding Networks in JAX},
author={Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Varona, Miguel De Llanza and Singh, Ryan and Buckley, Christopher L},
journal={arXiv preprint arXiv:2412.03676},
year={2024}
}

Also consider starring the repo! ⭐️
We are grateful to Patrick Kidger for early advice on how to use Diffrax.