Stochastic PDE solvers built on top of Exponax.
exponax-spde extends the Exponax library with first-class support for stochastic partial differential equations (SPDEs).
It follows exponax's design philosophy: every stochastic stepper is a differentiable Equinox Module that is JIT-compilable, vmappable, and GPU/TPU-ready.
On top of that, it adds the noise infrastructure, ensemble utilities, and hybrid coupling tools that deterministic exponax does not need.
Note: this package requires `exponax >= 0.2.0` as a dependency. It was developed following Felix Köhler's suggestion to build custom SPDE solvers on top of exponax rather than inside it. See the related exponax PR #101 for context.
```bash
pip install exponax-spde
```

For development (includes test and notebook dependencies):

```bash
git clone https://github.com/GVourvachakis/exponax-spde.git
cd exponax-spde
pip install -e ".[dev,test]"
```

Requires Python 3.10+ and JAX 0.4.13+.

For 3-D volume rendering in the validation notebook (optional):

```bash
pip install "exponax-spde[vape4d]"
```

Simulate the stochastic Allen-Cahn equation in 1-D with additive Q-Wiener noise. One stepper object and one call roll out 500 steps across 64 independent ensemble members:
```python
import jax
import exponax as ex
import exponax_spde as spde

# Build the stepper
stepper = spde.StochasticAllenCahn(
    num_spatial_dims=1,
    domain_extent=1.0,
    num_points=128,
    dt=5e-4,
    diffusivity=0.01,
    lambda_=1.0,
    sigma=0.1,
    noise_alpha=1.5,  # colour index: α=0 white noise, α>d/2 smooth
    noise_type="additive",
    use_taming=True,
)

# Initial condition
u0 = ex.ic.GaussianRandomField(1, powerlaw_exponent=3.0)(
    128, key=jax.random.PRNGKey(0)
)  # shape (1, 128)

# Single trajectory
trajectory = jax.jit(
    spde.stochastic_rollout(stepper, T=500, include_init=True)
)(u0, jax.random.PRNGKey(1))  # shape (501, 1, 128)

# Ensemble of M=64 independent trajectories
ensemble = jax.jit(
    spde.stochastic_ensemble_rollout(stepper, T=500, M=64, include_init=False)
)(u0, jax.random.PRNGKey(2))  # shape (64, 500, 1, 128)

# Ensemble-averaged power spectrum S(k)
S_k = spde.structure_factor(ensemble, burn_in_fraction=0.5)
```

Because every stepper is a differentiable JAX function, you can freely compose it with `jax.grad`, `jax.vmap`, and `jax.jit`:

```python
# Gradient of the summed one-step output w.r.t. the initial condition
grad_u0 = jax.jit(
    jax.grad(lambda u: stepper(u, key=jax.random.PRNGKey(0)).sum())
)(u0)
```

1-D stochastic Allen-Cahn: space-time diagram showing phase separation and coarsening under additive Q-Wiener noise (σ=0.1, α=1.5, λ=1).
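The quickstart reduces the ensemble to a single spectrum with `structure_factor`. The NumPy sketch below shows one way to compute an ensemble- and time-averaged power spectrum from 1-D data of shape `(M, T, C, N)`. The function name `structure_factor_sketch` and the $|\hat u_k|^2/N$ normalization are assumptions for illustration, not the package's implementation.

```python
import numpy as np


def structure_factor_sketch(ensemble, burn_in_fraction=0.5):
    """Hypothetical ensemble/time-averaged power spectrum for 1-D data.

    ensemble: real array of shape (M, T, C, N): M members, T snapshots,
    C channels, N grid points. The first burn_in_fraction of the snapshots
    is discarded; |rfft(u)|^2 / N (assumed normalization) is then averaged
    over members, the remaining snapshots and channels.
    """
    ensemble = np.asarray(ensemble)
    _, num_snapshots, _, n = ensemble.shape
    start = int(burn_in_fraction * num_snapshots)
    kept = ensemble[:, start:]                  # drop the transient
    u_hat = np.fft.rfft(kept, axis=-1)          # (M, T - start, C, N // 2 + 1)
    power = np.abs(u_hat) ** 2 / n
    return power.mean(axis=(0, 1, 2))           # S(k) for k = 0 .. N // 2


# Usage: unit-variance spatial white noise has a flat spectrum, S(k) close to 1.
rng = np.random.default_rng(0)
fake_ensemble = rng.standard_normal((32, 40, 1, 64))
S = structure_factor_sketch(fake_ensemble, burn_in_fraction=0.5)
print(S.shape)  # (33,)
```

Averaging over snapshots after a burn-in assumes the process has reached its invariant measure by then. That is why `burn_in_fraction` matters.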
Solves the
where
| Parameter | Description | Default |
|---|---|---|
| `diffusivity` | Diffusivity (sets the interface width) | `0.01` |
| `lambda_` | Reaction rate | `1.0` |
| `sigma` | Noise amplitude | `0.1` |
| `noise_alpha` | Spectral colour index | `1.0` |
| `noise_type` | `"additive"` or `"multiplicative"` | `"additive"` |
| `use_taming` | Hutzenthaler-Jentzen taming of the cubic nonlinearity | `True` |
| `use_milstein` | First-order Milstein correction (multiplicative only) | `False` |
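`use_taming` refers to Hutzenthaler-Jentzen taming. The minimal NumPy sketch below illustrates the idea under two assumptions: the tamed-Euler form $\Delta t\, f(u) / (1 + \Delta t\,|f(u)|)$ applied pointwise, and a reaction term $\lambda(u - u^3)$. The package's `TamedPolynomialNonlinearFun` may differ. The tamed increment is bounded by 1 in magnitude, so a large excursion cannot blow up. For small $\Delta t\,|f|$ it agrees with the plain Euler increment.

```python
import numpy as np


def reaction(u, lam=1.0):
    # Assumed Allen-Cahn reaction term lam * (u - u**3), for illustration only
    return lam * (u - u**3)


def euler_increment(u, dt, lam=1.0):
    # Plain explicit-Euler drift increment
    return dt * reaction(u, lam)


def tamed_increment(u, dt, lam=1.0):
    # Tamed increment dt*f / (1 + dt*|f|): its magnitude is always below 1
    f = reaction(u, lam)
    return dt * f / (1.0 + dt * np.abs(f))


def run(increment, u0, dt, num_steps):
    # Drift-only iteration (noise omitted to isolate the effect of taming)
    u = np.float64(u0)
    with np.errstate(all="ignore"):  # let the untamed run overflow quietly
        for _ in range(num_steps):
            u = u + increment(u, dt)
    return u


u_euler = run(euler_increment, 50.0, 0.1, 200)  # overflows to inf/nan
u_tamed = run(tamed_increment, 50.0, 0.1, 200)  # relaxes to the attractor u = 1
```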
Special cases:

- `lambda_=0` → stochastic heat equation (analytically tractable invariant measure)
- `sigma=0` → recovers `exponax.stepper.reaction.AllenCahn` to machine precision
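Why the `lambda_=0` case is tractable: assume additive noise whose covariance is diagonal in the Fourier basis with eigenvalues $q_k$. Then each nonzero Fourier mode evolves as an independent Ornstein-Uhlenbeck process, $d\hat u_k = L_k \hat u_k\,dt + \sigma\sqrt{q_k}\,dW_k$ with $L_k = -\nu k^2 < 0$, and its invariant variance is $\sigma^2 q_k / (2|L_k|)$. The NumPy sketch below samples that measure with the exact one-step OU update and compares the empirical variance with the formula. It uses real-valued amplitudes and hypothetical helper names.

```python
import numpy as np


def ou_exact_step(u, L, sigma, q, dt, rng):
    """Hypothetical helper: exact update of du = L*u dt + sigma*sqrt(q) dW, L < 0."""
    decay = np.exp(L * dt)
    # Variance of the stochastic convolution over one step: q*(e^{2L dt} - 1)/(2L)
    std = sigma * np.sqrt(q * np.expm1(2.0 * L * dt) / (2.0 * L))
    return decay * u + std * rng.standard_normal(u.shape)


def invariant_variance(L, sigma, q):
    """Stationary variance sigma^2 * q / (2|L|) of one OU mode."""
    return sigma**2 * q / (2.0 * abs(L))


# First nonzero Fourier mode of the stochastic heat equation on a unit periodic domain.
nu, sigma, q, dt = 0.01, 0.1, 1.0, 0.05
k = 2.0 * np.pi                      # angular wavenumber of mode 1
L = -nu * k**2                       # about -0.395
rng = np.random.default_rng(0)
u = np.zeros(20_000)                 # 20k independent realizations
for _ in range(400):                 # t = 20, many relaxation times 1/|L|
    u = ou_exact_step(u, L, sigma, q, dt, rng)

empirical = u.var()
theory = invariant_variance(L, sigma, q)
```

The exact OU update is unbiased for any `dt`, so a large step is used here to reach stationarity quickly.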
Calling convention (a `PRNGKey` is required):

```python
u_next = stepper(u0, key=jax.random.PRNGKey(0))

# Batched over M ensemble members:
M = 64
keys = jax.random.split(jax.random.PRNGKey(0), M)
u_batch = jax.vmap(lambda k: stepper(u0, key=k))(keys)  # shape (M, 1, 128) for the quickstart stepper
```

| Function | Description |
|---|---|
| `stochastic_rollout(stepper, T, *, include_init)` | Single trajectory via `jax.lax.scan`; JIT-compatible |
| `stochastic_ensemble_rollout(stepper, T, M, *, include_init)` | `M` independent trajectories via `jax.vmap` |
| `structure_factor(ensemble, *, burn_in_fraction)` | Ensemble-averaged power spectrum $S(k)$ |
| `richardson_weak_extrapolation(stepper_coarse, stepper_fine, u0, num_steps_coarse, key, *, num_samples)` | Richardson extrapolation to reduce the weak (time-step) bias |
| `strang_split_step(spectral_stepper, ssa_step_fn, u, ssa_state, dt, key, *, ...)` | Second-order Strang splitting for hybrid PDE/SSA coupling |
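`richardson_weak_extrapolation` combines a coarse and a fine run. For a scheme of weak order 1, $\mathbb{E}[f(u^{\Delta t})] = \mathbb{E}[f(u)] + c\,\Delta t + O(\Delta t^2)$, so $2\,\mathbb{E}[f(u^{\Delta t/2})] - \mathbb{E}[f(u^{\Delta t})]$ cancels the leading bias. The sketch below shows this standard two-level combination, not necessarily the exact formula the package uses. To isolate the bias from Monte-Carlo noise, it uses the closed-form mean $x_0(1 + a\,\Delta t)^n$ of Euler-Maruyama on the linear SDE $dX = aX\,dt + b\,dW$. The helper names are hypothetical.

```python
import numpy as np


def em_mean(a, x0, T, n_steps):
    """Exact mean of Euler-Maruyama for dX = a X dt + b dW after n_steps.

    Each step multiplies the mean by (1 + a*dt); the noise amplitude b drops out.
    """
    dt = T / n_steps
    return x0 * (1.0 + a * dt) ** n_steps


def richardson_weak(mean_coarse, mean_fine):
    """Two-level Richardson combination for a weak-order-1 scheme."""
    return 2.0 * mean_fine - mean_coarse


def weak_errors(n_coarse, a=-1.0, x0=1.0, T=1.0):
    exact = x0 * np.exp(a * T)
    coarse = em_mean(a, x0, T, n_coarse)
    fine = em_mean(a, x0, T, 2 * n_coarse)      # half the time step
    return abs(coarse - exact), abs(richardson_weak(coarse, fine) - exact)


for n in (10, 20, 40):
    plain, extrapolated = weak_errors(n)
    print(f"n={n:3d}  plain={plain:.2e}  richardson={extrapolated:.2e}")
```

The plain error halves when `n` doubles (first order). The extrapolated error drops by roughly 4× (second order).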
StochasticAllenCahn uses the Exponential Euler-Maruyama (EEM) scheme
(Lord, Powell & Shardlow, 2014, Chapters 7–10).
The ETD operator splitting follows exponax's reaction.AllenCahn convention
(
This avoids exponax's existing ETD infrastructure.
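The NumPy sketch below shows one exponential Euler-Maruyama step for a 1-D periodic field, written in Fourier space. The following choices are assumptions for illustration, not the package's implementation:

- the split $L_k = -\nu k^2 + \lambda$, $N(u) = -\lambda u^3$ (so the drift is $\nu\Delta u + \lambda(u - u^3)$);
- the noise spectrum $q_k = (1 + k^2)^{-\alpha}$;
- the per-step noise variance $\sigma^2 q_k\,\Delta t\,\varphi_1(2L_k\Delta t)$, taken from the exact stochastic convolution;
- the helper names.

No dealiasing is applied. The noise is drawn as real white noise and transformed with `rfft`, which keeps the DC and Nyquist modes consistent with a real field.

```python
import numpy as np


def phi1(z):
    """phi_1(z) = (exp(z) - 1) / z, with phi_1(0) = 1."""
    z = np.asarray(z, dtype=float)
    small = np.abs(z) < 1e-8
    safe = np.where(small, 1.0, z)
    return np.where(small, 1.0 + 0.5 * z, np.expm1(safe) / safe)


def eem_step(u, dt, nu, lam, sigma, alpha, length, rng):
    """One exponential Euler-Maruyama step with additive Q-Wiener noise (sketch)."""
    n = u.shape[-1]
    k = 2.0 * np.pi * np.fft.rfftfreq(n, d=length / n)   # angular wavenumbers
    L = -nu * k**2 + lam                                   # assumed linear part
    q = (1.0 + k**2) ** (-alpha)                           # assumed noise spectrum

    u_hat = np.fft.rfft(u)
    n_hat = np.fft.rfft(-lam * u**3)                       # assumed nonlinear part
    xi_hat = np.fft.rfft(rng.standard_normal(n))           # real white noise
    noise_hat = sigma * np.sqrt(q * dt * phi1(2.0 * L * dt)) * xi_hat

    # Exact linear propagation + phi_1-weighted nonlinearity (ETD1) + noise
    u_hat_next = np.exp(L * dt) * u_hat + dt * phi1(L * dt) * n_hat + noise_hat
    return np.fft.irfft(u_hat_next, n=n)


rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 128, endpoint=False)
u = 0.1 * np.sin(2.0 * np.pi * x)
for _ in range(100):
    u = eem_step(u, dt=5e-4, nu=0.01, lam=1.0, sigma=0.1, alpha=1.5, length=1.0, rng=rng)
```

With `sigma=0` the step reduces to first-order exponential time differencing, so the attractors $u = \pm 1$ are preserved exactly.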
2-D phase-field evolution: domains coarsen toward the ±1 attractors of the double-well potential.

The Jupyter notebook `validation/validate_stochastic_allen_cahn.ipynb` provides end-to-end validation against theoretical predictions:
| Section | Validated quantity |
|---|---|
| 1 | Deterministic limit: reaction.AllenCahn in |
| 2 | 1-D invariant measure: |
| 3 | 2-D invariant measure: structure-factor heat maps |
| 4 | Noise colour sweep |
| 5 | Additive vs multiplicative noise variance growth |
| 6 | IC sensitivity: GaussianRandomField, WhiteNoise, flat IC |
| 7 | 1-D spatio-temporal visualisation and animation |
| 8 | 2-D snapshots, animation, radial PSD and coarsening ( |
| 9 | 3-D volume rendering (vape4d optional) |
| 10 | Strong convergence: measured slope vs |
| 11 | Milstein vs EEM: weak-error slopes and per-step timing |
| 12 | Richardson weak extrapolation: bias reduction |
| 13 | Hybrid SSA scaffold (strang_split_step with OU sub-step) |
| 14 | Summary table: all measured vs expected quantities with ✓/✗ |
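Section 13 exercises `strang_split_step` with an Ornstein-Uhlenbeck sub-step. The splitting pattern itself (half step of one sub-problem, full step of the other, half step of the first) can be checked on a toy problem. The sketch below uses hypothetical helpers and a deterministic split of $u' = 1 - u$ into $u' = -u$ and $u' = 1$, both solved exactly. The global error at $T = 1$ should therefore drop by about 4× per halving of $\Delta t$ for Strang splitting, and by about 2× for first-order Lie splitting. In the hybrid setting, `flow_b` would be the stochastic (SSA or OU) sub-step.

```python
import numpy as np


def strang_step(u, dt, flow_a, flow_b):
    """One Strang step: half step of A, full step of B, half step of A."""
    u = flow_a(u, dt / 2)
    u = flow_b(u, dt)
    return flow_a(u, dt / 2)


def lie_step(u, dt, flow_a, flow_b):
    """First-order Lie splitting, for comparison."""
    return flow_b(flow_a(u, dt), dt)


def flow_a(u, t):
    # Exact flow of u' = -u
    return u * np.exp(-t)


def flow_b(u, t):
    # Exact flow of u' = 1
    return u + t


def global_error(step, n_steps, T=1.0, u0=0.0):
    """Error at time T against the exact solution of u' = 1 - u."""
    dt = T / n_steps
    u = u0
    for _ in range(n_steps):
        u = step(u, dt, flow_a, flow_b)
    exact = 1.0 + (u0 - 1.0) * np.exp(-T)
    return abs(u - exact)
```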
Run the notebook with double precision enabled:

```bash
export JAX_ENABLE_X64=1
jupyter notebook validation/validate_stochastic_allen_cahn.ipynb
```

Run the test suite, also in double precision:

```bash
# Fast tests only (~30 s on CPU)
JAX_ENABLE_X64=1 pytest tests/test_stochastic/ -m "not slow" -v

# Full suite including slow ensemble tests (~2 min on CPU)
JAX_ENABLE_X64=1 pytest tests/test_stochastic/ -v
```

The test suite covers the deterministic limit, 1-D/2-D invariant measures, strong convergence order, structure-factor grid convergence, mean/variance time series, Milstein sanity checks, and JAX compatibility (JIT, vmap, PyTree).
The stochastic/ subpackage is designed to accept future additions.
To add a new SPDE stepper (e.g. stochastic Kuramoto-Sivashinsky):
- Create `exponax_spde/stepper/stochastic/_stochastic_ks.py`.
- Subclass `BaseStepper` from `exponax_spde`.
- Implement `_build_linear_operator` and `_build_nonlinear_fun`.
- Use `TamedPolynomialNonlinearFun` if the nonlinearity grows super-linearly.
- Disable `step()` and require a `PRNGKey` in `__call__`.
- Add a `TestDeterministicLimit` class as a first sanity check.
- Export the new stepper from `exponax_spde/stepper/stochastic/__init__.py`.
See docs/extending.md for a step-by-step template.
- DC/Nyquist mode variance: `.real` after `ifft` silently halves the per-step variance of the DC and Nyquist modes.
- Non-standard Milstein prefactor: the Milstein correction carries a $\varphi_1(L_k\Delta t)\Delta t$ ETD factor; the measured weak-error slope is therefore closer to 1 than to 2.
- Noise above the dealiasing cutoff: `_noise_std` is not masked by the 2/3 filter, consistent with Q-Wiener conventions (Lord et al., 2014, Ch. 10).
- Strong convergence test: uses a same-key proxy rather than exact path coupling; the measured slope is a lower bound on the true strong order.
- `strang_split_step` JIT limitation: not JIT-compatible across steps when the SSA sub-step uses a Python-level RNG.
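The DC/Nyquist pitfall is easy to reproduce. The NumPy sketch below uses a hypothetical construction, not the package's noise generator. Complex Gaussian Fourier coefficients are mirrored as $\hat\xi_{N-k} = \overline{\hat\xi_k}$ for $0 < k < N/2$, but the DC and Nyquist coefficients are left complex. Taking `ifft(...).real` then discards their imaginary parts, so their measured variance is half that of the other modes.

```python
import numpy as np


def mirrored_noise(n, num_samples, rng):
    """Complex Gaussian coefficients with E|xi_k|^2 = 2, mirrored so that
    xi[n - k] = conj(xi[k]) for 0 < k < n/2. DC and Nyquist remain complex."""
    xi = rng.standard_normal((num_samples, n)) + 1j * rng.standard_normal((num_samples, n))
    xi[:, n // 2 + 1:] = np.conj(xi[:, 1:n // 2][:, ::-1])
    return xi


def per_mode_variance(n=16, num_samples=20_000, seed=0):
    """Average |fft(u)|^2 per mode after the ifft(...).real round trip."""
    rng = np.random.default_rng(seed)
    u = np.fft.ifft(mirrored_noise(n, num_samples, rng), axis=-1).real  # drops Im at DC/Nyquist
    return np.mean(np.abs(np.fft.fft(u, axis=-1)) ** 2, axis=0)


var = per_mode_variance()
n = var.size
ratio_dc = var[0] / var[1:n // 2].mean()
ratio_nyquist = var[n // 2] / var[1:n // 2].mean()
print(round(ratio_dc, 1), round(ratio_nyquist, 1))
```

Drawing real white noise in physical space and transforming it with `rfft` avoids this mismatch, because the DC and Nyquist coefficients of a real field are real by construction.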
- Allen, S. M., & Cahn, J. W. (1979). Acta Metallurgica, 27(6), 1085–1095.
- Lord, G. J., Powell, C. E., & Shardlow, T. (2014). An Introduction to Computational Stochastic PDEs. Cambridge University Press.
- Jentzen, A., & Kloeden, P. E. (2009a). Proceedings of the Royal Society A, 465(2102), 649–667.
- Hutzenthaler, M., & Jentzen, A. (2015). Memoirs of the American Mathematical Society, 236(1112).
- Cox, S. M., & Matthews, P. C. (2002). Journal of Computational Physics, 176(2), 430–455.
- Kassam, A.-K., & Trefethen, L. N. (2005). SIAM Journal on Scientific Computing, 26(4), 1214–1233.
- Strang, G. (1968). SIAM Journal on Numerical Analysis, 5(3), 506–517.
- Gillespie, D. T. (1977). The Journal of Physical Chemistry, 81(25), 2340–2361.
If you use this package in your research, please cite the upstream exponax paper
and this repository:
```bibtex
@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}

@software{vakis2025exponaxspde,
  author = {Vakis, Georgios},
  title  = {exponax-spde: Stochastic PDE solvers built on top of exponax},
  year   = {2025},
  url    = {https://github.com/GVourvachakis/exponax-spde},
}
```

MIT License. See LICENSE.txt.



