Unit 4: Learning Dynamics with Neural Differential Equations

Published

12/06/2026

Unit 3 §3.1 introduced ODEs as \dot{\mathbf{x}} = f(\mathbf{x}, t) and let OrdinaryDiffEq.jl integrate them. This unit takes the next step: make f itself learnable. Two routes — replace f entirely with a neural network (Neural ODEs), or splice a small network into known physics (UDEs). The first is the bridge from deep networks to continuous-depth models; the second sets up every inverse problem in the capstone.

We treat the solver as a black box throughout — adaptive step-size control, stiff vs non-stiff, implicit methods are deliberately out of scope. Pick the right solver, trust it, and focus on the dynamics.

4.1 ODEs as models of dynamics

Before we let neural networks loose on a vector field we need one more layer of ODE intuition: the qualitative behaviours that distinguish interesting dynamics from boring ones. The next two subsubsections lay this out (equilibria, limit cycles, chaos) and then put one example — Lotka–Volterra — under the microscope in Julia and Python so the cross-ecosystem reflex from Unit 3 §3.1 carries over.

The IVP setup from Unit 3 §3.1 carries over verbatim:

\dot{\mathbf{x}}(t) = f(\mathbf{x}(t), t),\qquad \mathbf{x}(0) = \mathbf{x}_0.

What’s new in this unit is the phenomenology — nonlinear ODEs admit qualitative behaviours linear systems can’t:

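  • Equilibria: fixed points \mathbf{x}^* where f(\mathbf{x}^*) = 0. Stability is read off from the eigenvalues of the Jacobian J = \partial f / \partial \mathbf{x} at \mathbf{x}^*: all real parts negative means stable, any positive real part means unstable.
  • Limit cycles: isolated closed orbits that attract (or repel) nearby trajectories. Linear systems can’t have these; the textbook example is the Van der Pol oscillator. (Lotka–Volterra, examined below, also has closed orbits, but they form a continuous family and none of them attracts its neighbours, so they are not limit cycles.)
  • Chaos: bounded, aperiodic motion with sensitive dependence on initial conditions. The Lorenz system from Unit 3 §3.4 (SINDy) lives here.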
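The eigenvalue test from the first bullet can be sketched numerically. The snippet below is a minimal illustration rather than part of the workshop scripts: it assumes the Lotka–Volterra field \dot{x} = \alpha x - \beta x y, \dot{y} = \delta x y - \gamma y of the worked example that follows, with illustrative parameter values, and the helper names lv_jacobian and classify are hypothetical.

```python
import numpy as np

# Illustrative parameter values (an assumption for this sketch).
alpha, beta, gamma, delta = 1.5, 1.0, 3.0, 1.0

def lv_jacobian(x, y):
    """Analytic Jacobian of f(x, y) = (alpha*x - beta*x*y, delta*x*y - gamma*y)."""
    return np.array([[alpha - beta * y, -beta * x],
                     [delta * y,        delta * x - gamma]])

def classify(J, tol=1e-12):
    """Coarse stability label from the real parts of the eigenvalues of J."""
    re = np.linalg.eigvals(J).real
    if np.all(re < -tol):
        return "stable"
    if np.any(re > tol):
        return "unstable"
    return "marginal (linearisation inconclusive)"

origin = (0.0, 0.0)
coexist = (gamma / delta, alpha / beta)   # non-trivial fixed point

eig_origin = np.linalg.eigvals(lv_jacobian(*origin))
eig_coexist = np.linalg.eigvals(lv_jacobian(*coexist))
label_origin = classify(lv_jacobian(*origin))
label_coexist = classify(lv_jacobian(*coexist))

print(origin, eig_origin, label_origin)
print(coexist, eig_coexist, label_coexist)
```

At the origin one eigenvalue is positive and one negative, so it is a saddle. At the coexistence point the eigenvalues are purely imaginary, which means the linear test alone cannot decide stability; the nonlinear analysis has to settle it.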

Worked example: Lotka–Volterra

A classical two-species predator–prey model:

\dot{x} = \alpha x - \beta x y, \qquad \dot{y} = \delta x y - \gamma y.

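For positive (\alpha, \beta, \gamma, \delta) the system has a non-trivial fixed point (\gamma/\delta,\, \alpha/\beta) surrounded by a continuous family of closed orbits, one per initial condition. Strictly speaking this is a neutrally stable centre rather than a limit cycle (no orbit attracts its neighbours, and a generic perturbation of the equations destroys the closed orbits), but it remains a favourite testbed for SciML methods (we’ll meet it again in §4.3).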

using OrdinaryDiffEq, Plots

function lv!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[1]*u[2]
    du[2] = δ*u[1]*u[2] - γ*u[2]
end
p     = (1.5, 1.0, 3.0, 1.0)
tspan = (0.0, 12.0)

plt = plot(xlabel="prey x", ylabel="predator y",
           title="Lotka–Volterra phase portrait",
           aspect_ratio = :equal, legend = :outerright)
for u0 in [[1.0, 1.0], [1.5, 1.0], [2.0, 1.0]]
    sol = solve(ODEProblem(lv!, copy(u0), tspan, p), Tsit5();
                saveat = 0.01)
    plot!(plt, sol[1, :], sol[2, :], label = "x(0)=$(u0[1])")
end
plt

Each initial condition lands on a different closed orbit — a first integral (a conserved quantity along trajectories; V(x, y) = \delta x - \gamma \log x + \beta y - \alpha \log y for this LV system) exists, but it is invisible to a method that only sees data. Recovering invariants like this from observations alone is exactly what Hamiltonian / Lagrangian networks (Unit 3 §3.3) and the SINDy machinery (Unit 3 §3.4) try to do.

The same model in Python with scipy.integrate.solve_ivp: same equations, same trajectories, different ecosystem. The script prints the end state of each orbit instead of plotting it:

units/unit_04/scripts/lotka_volterra_python.py
import numpy as np
from scipy.integrate import solve_ivp

alpha, beta, gamma, delta = 1.5, 1.0, 3.0, 1.0
def lv(t, u):
    x, y = u
    return [alpha * x - beta * x * y,
            delta * x * y - gamma * y]

t_eval = np.linspace(0, 12, 1201)
trajs = []
for u0 in ([1.0, 1.0], [1.5, 1.0], [2.0, 1.0]):
    sol = solve_ivp(lv, (0, 12), u0, t_eval=t_eval, rtol=1e-8)
    trajs.append(sol.y)
    print(f"x(0)={u0[0]:.1f}  x(12)={sol.y[0, -1]:.3f}  "
          f"y(12)={sol.y[1, -1]:.3f}")

Available at scripts/lotka_volterra_python.py.

✏️ Section exercise: test the invariant

The text claims V(x, y) = \delta x - \gamma \log x + \beta y - \alpha \log y is conserved along Lotka–Volterra trajectories. Verify it numerically: solve from (x_0, y_0) = (1, 1) with the parameters above and plot V(x(t), y(t)) - V(x_0, y_0) over t \in [0, 50] at solver tolerances reltol = 1e-3 and reltol = 1e-10. Then answer two questions: (a) is the drift you see physics or numerics, and how do the two tolerance runs prove it? (b) If you fit a neural network to this trajectory and integrated it for 500 time units, what would you expect V to do, and why is that the motivation for §3.3’s Hamiltonian architectures?

💡 Hint

Solve the same ODEProblem twice, changing only reltol/abstol, and plot V.(sol.u) .- V(u0...) for both. The logic: structural non-conservation wouldn’t care about solver tolerance. For part (b), recall what Solution 3.3’s perturbed field did to the energy — same mechanism, different invariant.

4.2 From ResNets to Neural ODEs

This section walks the bridge between the discrete-layer MLPs of Unit 2 and the continuous-time vector fields we’ll use in the rest of the workshop. The subsubsections cover (a) the layer-as-Euler-step intuition, (b) what a Neural ODE looks like in Lux.jl, (c) the adjoint trick that makes training through an ODE solver feasible, and (d) an illustrative training skeleton. All four together form one of the two patterns (Unit 9) the capstone inverse problem uses.

The Euler-step analogy

A residual network layer is

h_{\ell+1} = h_\ell + \Delta t\, f_\theta(h_\ell, \ell).

Compare to one explicit Euler step on the ODE \dot{h} = f_\theta(h, t):

h(t + \Delta t) \approx h(t) + \Delta t\, f_\theta(h(t), t).

Identical. Taking \Delta t \to 0 and depth L \to \infty gives a Neural ODE (Chen et al., 2018): the network is a vector field, and the “forward pass” is an ODE solve from input h(0) = \mathbf{x} to output h(T) = \text{prediction}.
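The limit can be checked numerically. The sketch below is an illustration under stated assumptions, not the construction from Chen et al.: it uses a small tanh MLP with fixed random weights as a stand-in for f_\theta, shares that field across all layers (so there is no dependence on the layer index), and sets \Delta t = T / L. The names f_theta and resnet_forward are hypothetical.

```python
import numpy as np
from scipy.integrate import solve_ivp

rng = np.random.default_rng(0)
# Hypothetical fixed weights of a 2 -> 16 -> 2 tanh MLP standing in for f_theta.
W1, b1 = 0.5 * rng.normal(size=(16, 2)), rng.normal(size=16)
W2, b2 = rng.normal(size=(2, 16)) / 4.0, np.zeros(2)

def f_theta(h):
    """Autonomous vector field shared by every residual layer."""
    return W2 @ np.tanh(W1 @ h + b1) + b2

def resnet_forward(h0, depth, T=1.0):
    """Stack `depth` residual layers with step dt = T / depth (explicit Euler)."""
    h = np.array(h0, dtype=float)
    dt = T / depth
    for _ in range(depth):
        h = h + dt * f_theta(h)
    return h

h0 = [1.0, 0.0]
# Reference: the continuous-depth model, solved to tight tolerance.
ref = solve_ivp(lambda t, h: f_theta(h), (0.0, 1.0), h0,
                rtol=1e-10, atol=1e-12).y[:, -1]

depths = [4, 16, 64, 256]
errors = [np.linalg.norm(resnet_forward(h0, d) - ref) for d in depths]
for d, err in zip(depths, errors):
    print(f"depth={d:4d}  |h_L - h(T)| = {err:.2e}")
```

Explicit Euler is first-order accurate, so once the step is small the gap to the ODE solution should shrink roughly in proportion to 1/L as the stack gets deeper.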

Continuous-depth models in Lux.jl style

The vector field is just an MLP. Wrap it in an ODEProblem and hand it to the solver:

using Lux, OrdinaryDiffEq, Random

rng    = Random.MersenneTwister(42)
net    = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 2))
ps, st = Lux.setup(rng, net)

# Lift the network into a vector field f_θ(u, p, t)
function nn_rhs!(du, u, p, t)
    y, _ = net(u, p, st)
    du .= y
end

prob = ODEProblem(nn_rhs!, [1.0f0, 0.0f0], (0.0f0, 1.0f0), ps)
sol  = solve(prob, Tsit5(); saveat = 0.05)
size(sol)   # (state, time) — the integrated trajectory
(2, 21)

That’s a forward pass of an untrained Neural ODE. Training fits \theta so the integrated trajectory matches data — and that training is what the workshop’s later units use.

Adjoint sensitivities (sketch)

Calling Zygote.gradient naively through an ODE solver stores every intermediate state. This is discretise-then-differentiate, or “direct backprop through the solver”. Memory cost scales with the number of solver steps, which is fine for short trajectories and prohibitive for long ones.

The adjoint method (originally due to Pontryagin in optimal control; rediscovered for Neural ODEs by Chen et al. 2018) replaces this by differentiate-then-discretise: derive an auxiliary ODE for the gradient — the adjoint \boldsymbol{\lambda}(t) = \partial \mathcal{L} / \partial \mathbf{x}(t) — and solve it backward in time alongside the original equation. Memory cost becomes constant in trajectory length; the price is some numerical-accuracy bookkeeping (the backward solve has to be at least as accurate as the forward one).

A third, middle-ground option, checkpointing, stores only a sparse subset of states on the forward pass and re-integrates between them on the backward pass. With a binomial checkpoint schedule, memory grows logarithmically in the number of solver steps, at the cost of a logarithmic factor of extra compute.
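To make the adjoint recipe concrete, here is a minimal sketch on a problem small enough to check by hand. Everything in it is an assumption chosen for illustration: the scalar field f(x, \theta) = -\theta x, the loss \mathcal{L} = \tfrac12 x(T)^2, and the parameter values. The adjoint obeys \dot{\lambda} = -\lambda\, \partial f / \partial x with \lambda(T) = x(T), and the gradient is d\mathcal{L}/d\theta = \int_0^T \lambda\, \partial f / \partial \theta \, dt, which for this problem has the closed form -T x_0^2 e^{-2\theta T}.

```python
import numpy as np
from scipy.integrate import solve_ivp

theta, x0, T = 0.7, 2.0, 3.0   # illustrative values (assumptions)

def f(x, th):
    return -th * x             # vector field

def df_dx(x, th):
    return -th                 # partial f / partial x

def df_dtheta(x, th):
    return -x                  # partial f / partial theta

# Forward solve: only the final state is needed afterwards.
fwd = solve_ivp(lambda t, z: [f(z[0], theta)], (0.0, T), [x0],
                rtol=1e-10, atol=1e-12)
xT = fwd.y[0, -1]

def augmented(t, z):
    """State z = [x, lam, g], integrated from t = T down to t = 0."""
    x, lam, g = z
    return [f(x, theta),                  # re-solve the state backward
            -lam * df_dx(x, theta),       # adjoint ODE
            -lam * df_dtheta(x, theta)]   # accumulates the parameter gradient

# Terminal conditions: lam(T) = dL/dx(T) = x(T), g(T) = 0.
bwd = solve_ivp(augmented, (T, 0.0), [xT, xT, 0.0],
                rtol=1e-10, atol=1e-12)
grad_adjoint = bwd.y[2, -1]

grad_exact = -T * x0**2 * np.exp(-2.0 * theta * T)
print(grad_adjoint, grad_exact)
```

No forward states are stored: the backward pass reconstructs x(t) by integrating the original equation in reverse. For a decaying system that reverse integration amplifies errors, which is why the backward tolerance matters.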

In Julia that machinery lives in SciMLSensitivity.jl. In Python the equivalents are torchdiffeq (Chen’s reference implementation) and diffrax (a more recent JAX-based library with several adjoint methods built in). SciMLSensitivity.jl is left out of the workshop default Project.toml to keep precompilation light, but it is a one-line addition when you reach for it.

Training a Neural ODE (illustrative)

The skeleton — define the vector field, integrate, compute the trajectory-misfit loss, gradient-descend on \theta:

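using Lux, OrdinaryDiffEq, Zygote, Optimisers, SciMLSensitivity
using ComponentArrays, Random
# `SciMLSensitivity` extends `solve` with adjoint backends used by Zygote;
# `ComponentArrays` flattens the Lux parameters into one differentiable vector.
# Add both to Project.toml before running.
# Assumed to be supplied by the caller: `u0` (initial state), `ts` (save times)
# and `data` (observed states, size 2 × length(ts)).

rng     = Random.MersenneTwister(0)
f_θ     = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 2))
ps, st  = Lux.setup(rng, f_θ)
ps      = ComponentArray(ps)            # flat parameter vector

# In-place vector field: evaluate the network and write the result into du.
function rhs!(du, u, p, t)
    y, _ = f_θ(u, p, st)
    du .= y
end

function predict(ps, u0, ts)
    sol = solve(ODEProblem(rhs!, u0, (ts[1], ts[end]), ps),
                Tsit5(); saveat = ts,
                sensealg = InterpolatingAdjoint())
    Array(sol)                          # (state, time)
end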

loss(ps, u0, ts, ŷ) = sum(abs2, predict(ps, u0, ts) .- ŷ)

opt_state = Optimisers.setup(Optimisers.Adam(1.0f-2), ps)
for epoch in 1:1000
    g = first(Zygote.gradient(p -> loss(p, u0, ts, data), ps))
    opt_state, ps = Optimisers.update(opt_state, ps, g)
end

Three structural differences from the standard supervised-learning loop in Unit 2 §2.6:

  1. The forward pass is an ODE solve, not a feed-forward evaluation. Cost is dominated by the integrator.
  2. The loss is a trajectory misfit — a sum over time, not a single per-sample term.
  3. The gradient flows through the solver via sensealg. Pick the wrong one and training time blows up.

We’ll see this pattern again in Unit 9 — inverse problems with the column model use exactly the same loop, just with PDE residuals instead of trajectory data.

✏️ Section exercise: train a tiny Neural ODE on a spiral

Make the §4.2 skeleton run end-to-end on the classic toy: data from the linear spiral \dot{\mathbf{u}} = A\mathbf{u}, A = \begin{pmatrix} -0.1 & 2 \\ -2 & -0.1\end{pmatrix}, sampled at 30 points over t \in [0, 6] from \mathbf{u}_0 = (2, 0). Fit the 2 → 16 → 2 tanh network with the adjoint-backed training loop (SciMLSensitivity + Adam, ~1 000 iterations), then extrapolate the trained model to t = 12 and plot it against the true spiral. Where does the learned dynamics stay faithful, and where does it drift? Bonus: repeat with only the first half of the spiral as training data and watch the extrapolation get worse — the data-efficiency argument for UDEs in one picture.

💡 Hint

You’ll need two packages not in the workshop Project.toml: SciMLSensitivity (adjoints) and ComponentArrays (flat parameter vector Zygote can differentiate). Pass sensealg = InterpolatingAdjoint() to solve and train on sum(abs2, pred .- data). For extrapolation, re-solve with the trained ps over the longer tspan — don’t retrain.

4.3 Universal Differential Equations

A Neural ODE replaces all of f with a network. A Universal Differential Equation (UDE) only replaces the bit you don’t know, keeping the trusted physics intact. That’s almost always the right move in scientific applications — most domains have some equation that’s known to be correct (mass conservation, a governing PDE in part of the domain, a well-established empirical law) and some term that’s empirical or unknown. UDEs let you write the known bits with a pencil and learn the rest with a small network — the position taken throughout Units 8–10 of the capstone.

Known physics + learned closure

A UDE (Rackauckas et al. 2020) is an ODE whose right-hand side splits into a known term and a learned term:

\dot{\mathbf{x}} = \underbrace{f_{\text{phys}}(\mathbf{x}, t)}_{\text{known}} \;+\; \underbrace{N_\theta(\mathbf{x}, t)}_{\text{learned}}.

You write down whatever physics you trust (f_{\text{phys}}) and let a small neural network absorb the rest. Train on observed \mathbf{x}(t). The result is a model you can simulate, analyse, and interpret — the NN is small and you can probe it. This is the position taken throughout the AIMS capstone in Unit 9: the column PDE structure is hard-coded; the unknown wind-stress envelope is the learned piece.

Worked example: pendulum with unknown friction

The damped pendulum

\ddot{q} + b\,\dot{q} + \tfrac{g}{L}\sin q = 0

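is a sandbox UDE. Hold the gravitational restoring term \tfrac{g}{L}\sin q as known physics and replace the friction term b\,\dot{q} with an unknown function g_\theta(\dot{q}) (the subscript marks the learned function; it is unrelated to the gravitational acceleration g):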

\ddot{q} + g_\theta(\dot{q}) + \tfrac{g}{L}\sin q = 0.

Train g_\theta against trajectory data; recover its functional shape (often nearly linear for small \dot{q}, with corrections elsewhere). The output is interpretable: g_\theta is a 1D function you can plot.

# Hybrid pendulum: gravity known, friction learned.
using Lux, OrdinaryDiffEq, Zygote, Optimisers, SciMLSensitivity, Random

friction = Lux.Chain(Lux.Dense(1 => 8, tanh), Lux.Dense(8 => 1))
ps, st   = Lux.setup(Random.default_rng(), friction)

const g_over_L = 9.81   # g / L in s⁻², i.e. a pendulum of length L = 1 m

function ude_pendulum!(du, u, p, t)
    q, q̇   = u
    fric, _ = friction([q̇], p, st)      # learned friction g_θ(q̇)
    du[1]   = q̇                          # dq/dt = q̇
    du[2]   = -fric[1] - g_over_L * sin(q)
end

# … usual training loop: solve, misfit, adjoint, Adam.

A clean sandbox for inverse problems before we hit the capstone column in Unit 9.
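Training needs trajectory data to fit against. A minimal sketch of generating synthetic observations in Python follows; it assumes a "true" linear friction with a hypothetical coefficient b_true = 0.3, an arbitrary initial angle of 1 rad, and 1% additive Gaussian noise, none of which come from the workshop material. The energy check at the end confirms that the friction term only removes energy.

```python
import numpy as np
from scipy.integrate import solve_ivp

g_over_L = 9.81   # same constant as the Julia block above
b_true = 0.3      # hypothetical linear friction coefficient (assumption)

def pendulum(t, u):
    """Damped pendulum in first-order form, u = (q, qdot)."""
    q, qdot = u
    return [qdot, -b_true * qdot - g_over_L * np.sin(q)]

t_obs = np.linspace(0.0, 10.0, 201)
sol = solve_ivp(pendulum, (0.0, 10.0), [1.0, 0.0], t_eval=t_obs,
                rtol=1e-9, atol=1e-9)
q, qdot = sol.y

rng = np.random.default_rng(0)
data = sol.y + 0.01 * rng.normal(size=sol.y.shape)   # noisy observations

# Mechanical energy per unit m*L^2, computed from the noise-free solution.
energy = 0.5 * qdot**2 + g_over_L * (1.0 - np.cos(q))
print(f"E(0) = {energy[0]:.3f}, E(10) = {energy[-1]:.3f}")
```

A trained g_\theta can then be compared directly against the straight line b_true * qdot over the range of velocities the data visits.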

A domain example: Crown-of-Thorns starfish on the Great Barrier Reef

The Great Barrier Reef has lost roughly half its hard-coral cover since 1985. The peer-reviewed AIMS attribution of that loss (De’ath, Fabricius, Sweatman & Puotinen 2012, PNAS) puts the breakdown at 48% from cyclones, 42% from Crown-of-Thorns starfish (COTS) predation, and 10% from coral bleaching — and notes that without COTS, coral cover would have increased by ~0.9% per year despite the cyclones and bleaching. COTS is, by some distance, the biggest controllable biological driver of reef decline that AIMS studies.

A reef carries ~1–10 adult COTS per hectare in its “background” state; active outbreaks are declared at >30 adults/ha (GBRMPA dashboard). A single female releases 10^6–10^8 eggs per spawning event, so outbreaks can flip on quickly when conditions favour larval survival. The operational response, vessel-based diver injection with bile salts or vinegar, culled 73 881 COTS over 11 710 ha across 234 target reefs in 2024–25, delivering up to a 6× COTS reduction and +44% coral cover where the control was timely. AIMS provides the population-state input that prioritises which reefs to target.

The modelling question we’ll borrow: given monitored time-series of coral cover C(t) and COTS density S(t) on a reef, can a UDE separate the (mostly) known population dynamics from the genuinely uncertain density-dependent COTS mortality?

Known physics: logistic coral + Holling-II grazing

A standard predator–prey model for COTS / coral (Morello et al. 2014, Marine Ecology Progress Series; historically Antonelli et al. 1990, Ecological Modelling) takes the form

\begin{aligned} \dot C \;&=\; r_C\, C\,\Bigl(1 - \frac{C}{K}\Bigr) \;-\; \frac{a\, S\, C}{1 + a\,h\,C} \\ \dot S \;&=\; e\,\frac{a\, S\, C}{1 + a\,h\,C} \;-\; m(S, \ldots)\, S \end{aligned}

with r_C the coral intrinsic growth rate, K the carrying capacity (% cover), a the COTS attack rate, h the handling time, e the conversion efficiency, and m(\cdot) the COTS mortality rate. The two structural terms — logistic coral growth and Holling type-II grazing — are well established. The mortality m(S, \ldots) is exactly where the modelling literature is honest about not knowing the right functional form (Babcock et al. 2016, PLOS ONE flags it as the dominant source of structural uncertainty), and it’s the natural place to put a learnable closure.
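The Holling type-II term is worth a quick numerical look, because its two regimes are easy to miss in the formula. The sketch below evaluates the per-starfish grazing rate a C / (1 + a h C) for illustrative values of a and h (assumptions, rough orders of magnitude only); the function name per_capita_grazing is hypothetical, and coral cover above 100% is unphysical and included only to expose the saturation limit 1/h.

```python
import numpy as np

a, h = 0.005, 0.10   # attack rate and handling time (illustrative values)

def per_capita_grazing(C):
    """Holling type-II functional response: coral consumed per starfish."""
    return a * C / (1.0 + a * h * C)

# The last two entries are far outside the physical range on purpose.
C = np.array([1.0, 10.0, 80.0, 1e4, 1e7])
rate = per_capita_grazing(C)
for c, r in zip(C, rate):
    print(f"C = {c:>10.0f}   grazing = {r:.4f}   linear a*C = {a * c:.4f}")
```

For small a h C the response is nearly linear in C, and for large C it flattens at 1/h. With these particular values a h C stays below 0.05 for any cover up to 100%, so handling time is only a few-percent correction in the physical range.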

Learned closure: COTS mortality as a small MLP

Replace the unknown mortality with a small network of the state and one or two environmental drivers — a sea-surface-temperature anomaly T_a(t) from satellite data is a reasonable starting choice (warmer water suppresses larval metamorphic success; Lang et al. 2022):

\dot S \;=\; e\,\frac{a\, S\, C}{1 + a\,h\,C} \;-\; \underbrace{N_\theta\!\bigl(S, C, T_a(t)\bigr)}_{\text{learned mortality}} \cdot S.

Train against synthetic COTS / coral trajectories (or, eventually, the AIMS Long-Term Monitoring Program’s manta-tow records). The learned N_\theta has biological meaning — plot it against S at fixed C to see whether it recovers a density-dependent Allee-like effect; plot against T_a at fixed (S, C) to see whether warming suppresses mortality (or doesn’t).

# Hybrid COTS-coral model:
#   - logistic coral growth and Holling-II grazing as known physics,
#   - density- and temperature-dependent COTS mortality as a learned NN.
using Lux, OrdinaryDiffEq, Zygote, Optimisers, SciMLSensitivity, Random

mortality_net = Lux.Chain(
    Lux.Dense(3 => 16, tanh),
    Lux.Dense(16 => 16, tanh),
    Lux.Dense(16 => 1),
)
ps, st = Lux.setup(Random.MersenneTwister(0), mortality_net)

# Known parameters (Morello et al. 2014, rough orders of magnitude)
const r_C, K = 0.30, 80.0       # coral growth rate (1/yr), carrying capacity (% cover)
const a, h   = 0.005, 0.10       # attack rate, handling time
const e      = 0.40              # conversion efficiency

# Temperature anomaly driver (read from observations or model out)
T_anom(t) = 0.3 * sin(2π * t / 1.0)   # placeholder annual cycle

function cots_coral!(du, u, p, t)
    C, S = u                                 # coral cover (%), COTS density (per ha)
    graze = a * S * C / (1 + a * h * C)
    m_S, _ = mortality_net([S, C, T_anom(t)], p, st)
    du[1] = r_C * C * (1 - C / K) - graze
    du[2] = e * graze - m_S[1] * S
end

# … usual training loop: solve, trajectory-misfit against AIMS LTMP data,
#   adjoint sensitivities, Adam → L-BFGS.

The pedagogical pay-off: the UDE separates what we trust (the logistic-grazing skeleton, with parameters fit to decades of monitoring) from what we don’t (the closure that says why COTS adults die at the rates they do). A successful fit gives an ecologist a plottable, biologically-interpretable mortality surface — exactly the kind of artefact that motivates UDEs over black-box neural ODEs in real applications.

Lotka–Volterra inverse variant

Same setup, dynamics from §4.1. Two flavours:

  • Parameter inverse — assume the LV form and recover the scalars (\alpha, \beta, \gamma, \delta). Cheap.
  • Functional inverse — drop one term (say the predator–prey interaction -\beta x y) and learn it as a small network. Now you’re recovering a function, not a number — exactly the capstone problem in miniature.

The functional inverse is where neural-network capacity pays off: classical parameter estimation can’t recover an unknown functional shape, but a UDE with a 2-layer MLP can. Code placeholder — units/unit_04/scripts/ude_lv.jl (TODO).
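The parameter inverse is cheap enough to sketch in a few lines of Python. The example below is one possible setup, not the workshop's script: it assumes noise-free synthetic data from the §4.1 parameter values, a hypothetical starting guess about 10% off, and scipy.optimize.least_squares with a finite-difference Jacobian in place of adjoint sensitivities. The solver tolerances are deliberately tight so that finite differences through the adaptive integrator stay clean.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

theta_true = np.array([1.5, 1.0, 3.0, 1.0])   # (alpha, beta, gamma, delta)
u0 = [1.0, 1.0]
t_obs = np.linspace(0.0, 8.0, 60)

def simulate(theta):
    """Integrate Lotka-Volterra; return the states at t_obs, shape (2, 60)."""
    al, be, ga, de = theta
    def rhs(t, u):
        return [al * u[0] - be * u[0] * u[1],
                de * u[0] * u[1] - ga * u[1]]
    sol = solve_ivp(rhs, (0.0, 8.0), u0, t_eval=t_obs,
                    method="DOP853", rtol=1e-10, atol=1e-10)
    return sol.y

data = simulate(theta_true)   # noise-free synthetic observations

def residuals(theta):
    """Flattened trajectory misfit; least_squares minimises its squared norm."""
    return (simulate(theta) - data).ravel()

theta_guess = np.array([1.4, 1.1, 2.8, 0.9])   # hypothetical starting point
fit = least_squares(residuals, theta_guess, bounds=(0.0, 10.0),
                    diff_step=1e-6)
theta_hat = fit.x
print("recovered:", np.round(theta_hat, 4))
```

Four scalars and a good initial guess make this an easy problem. The functional inverse has no such fixed form to lean on, which is where the network comes in.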

✏️ Section exercise: the functional inverse, in miniature

Fill in the ude_lv.jl placeholder yourself. Generate Lotka–Volterra data (\alpha, \beta, \gamma, \delta as in §4.1, t \in [0, 8], 60 samples, 2% noise), then build the UDE that keeps \dot x = \alpha x - \beta x y and \dot y = -\gamma y + N_\theta(x, y); that is, drop the +\delta x y predator-growth term and learn it with a 2 → 8 → 1 network. Train with Adam (a few hundred iterations is enough), then plot the learned N_\theta(x, y) against the true surface \delta x y over the visited region of the phase plane. Two checks: does N_\theta match \delta x y inside the training orbit’s range, and what does it do outside it? That contrast (good interpolation, unconstrained extrapolation) is the honest summary of every UDE result you’ll ever publish.

💡 Hint

Write the UDE rhs with the known LV terms inline and first(nn([x, y], p, st))[1] in place of the dropped term. Training is identical to the §4.2 exercise’s loop. For the comparison, evaluate N_θ and δ·x·y on a grid with two comprehensions and heatmap both, then overlay the training orbit with plot! to see exactly where the agreement ends.

4.4 What comes next

This unit established the trajectory-loss family of physics-aware training: integrate the ODE forward, compare to data, backprop through the solver via the adjoint. The remaining units add the rest of the toolkit:

  • Unit 5 introduces the other family: residual-loss training. Instead of integrating, you evaluate the equation residual at random collocation points and train the network to satisfy it. No solver, no adjoint — just autodiff.
  • Unit 6 provides the PDE background the residual-loss approach assumes (function spaces, weak solutions, well-posedness).
  • Unit 7 covers what goes wrong in practice with the residual approach (the so-called “PINN failure modes”) and the modern fixes.
  • Units 8–10 combine both approaches on the capstone column problem: a UDE forward model, then a hybrid residual + trajectory inverse for the unknown driver.

The trajectory-loss skeleton in §4.2 and the residual-loss skeleton in Unit 5 are the two recipes every SciML / PINN paper builds on; understanding when to reach for each is the practical takeaway of these four units.