Unit 5: PINNs on Basic Models

Published

12/06/2026

Our first hands-on PINN, in the simplest possible settings. We do it twice:

  1. From first principles — a one-page implementation in Lux + Zygote that you can read end-to-end. No PINN library, no symbolic machinery; just a network, a residual, and an optimiser. This is the inline code in §5.3.
  2. With NeuralPDE.jl — the same problem written declaratively against a ModelingToolkit PDESystem and discretised in a single line. Concise, scalable, the right tool for the AIMS capstone — but harder to debug if you don’t know the first-principles version first. This is the sidecar in §5.4.

Before the implementations, §5.2 is a focused autodiff deep-dive — the forward/reverse/forward-over-forward composition that PINNs depend on, in both Julia and JAX. We then push NeuralPDE.jl from the ODE up to the simplest PDE (1D diffusion, §5.5) and call out the failure modes that Unit 7 finally fixes.

5.1 The PINN idea

Residual loss from a differential equation

For an ODE \dot{x} = f(x, t), define a network x_\theta(t) — input time, output state — and the residual

r_\theta(t) \;=\; \dot{x}_\theta(t) - f(x_\theta(t), t),

with \dot{x}_\theta computed by autodiff through the network. The residual loss is

\mathcal{L}_{\text{ODE}}(\theta) \;=\; \frac{1}{N_r}\sum_{i=1}^{N_r} \bigl|r_\theta(t_i)\bigr|^2

evaluated at scattered collocation points \{t_i\}. Minimising it pushes x_\theta toward a solution of the ODE.
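To make the residual loss concrete before any network is involved, here is a minimal sketch that evaluates \mathcal{L}_{\text{ODE}} for two hand-written candidate functions on \dot{x} = -x. The helper name residual_loss, the step h, and the 64-point grid are illustrative assumptions, and a central difference stands in for the autodiff derivative.

```python
import math

def residual_loss(x, f, ts, h=1e-4):
    """Mean squared ODE residual of a candidate x(t) at collocation points ts."""
    total = 0.0
    for t in ts:
        # Central difference stands in for autodiff through a network.
        x_dot = (x(t + h) - x(t - h)) / (2.0 * h)
        r = x_dot - f(x(t), t)          # r(t) = x'(t) - f(x(t), t)
        total += r * r
    return total / len(ts)

f = lambda x, t: -x                      # right-hand side of x' = -x
ts = [4.0 * i / 63 for i in range(64)]   # 64 uniform collocation points on [0, 4]

loss_exact = residual_loss(lambda t: math.exp(-t), f, ts)        # a true solution
loss_wrong = residual_loss(lambda t: math.exp(-2.0 * t), f, ts)  # wrong decay rate
print("exact candidate:", loss_exact)
print("wrong candidate:", loss_wrong)
```

The true solution drives the loss to roughly the discretisation floor, while the wrong decay rate leaves a residual of order one near t = 0.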

Initial conditions are what pin it down

A residual alone admits infinitely many solutions: the ODE’s solutions form a one-parameter family indexed by x_0. The IC term

\mathcal{L}_{\text{IC}}(\theta) \;=\; \bigl|x_\theta(0) - x_0\bigr|^2

(with weight \lambda_{\text{IC}}) selects the trajectory we want. The same idea applies to boundary conditions on PDEs. The weights \lambda_{\text{IC}}, \lambda_{\text{BC}}, \lambda_{\text{PDE}} are the main hyperparameters, and tuning them well is a major theme of Unit 7.
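Written out, the weighted objective these \lambda's refer to would look like the following. This is a sketch of the usual convention, not a formula quoted from a later unit; \mathcal{L}_{\text{BC}} is an assumed name for the boundary term, which is absent for a pure ODE.

```latex
\mathcal{L}(\theta) \;=\; \lambda_{\text{PDE}}\,\mathcal{L}_{\text{ODE}}(\theta)
  \;+\; \lambda_{\text{IC}}\,\mathcal{L}_{\text{IC}}(\theta)
  \;+\; \lambda_{\text{BC}}\,\mathcal{L}_{\text{BC}}(\theta)
```

The §5.3 listing is the special case \lambda_{\text{PDE}} = 1, \lambda_{\text{IC}} = 100 with no boundary term.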

Collocation points

Where do you put the \{t_i\}? Options:

  • Uniform grid — simple, biased toward regular regions.
  • Random uniform — eliminates grid bias.
  • Latin hypercube — better space coverage in higher dimensions.
  • Adaptive — sample more densely where the residual is high.

For a first pass on a smooth ODE, a uniform grid of a few dozen to a few hundred points is plenty.
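The four options above can be sketched in a few lines for the 1D case. All function names here are illustrative assumptions, not part of any PINN library. In one dimension the Latin hypercube reduces to stratified sampling, and the adaptive variant simply resamples a candidate pool in proportion to a supplied squared residual.

```python
import math
import random

def uniform_grid(a, b, n):
    """n equally spaced points including both endpoints."""
    return [a + (b - a) * i / (n - 1) for i in range(n)]

def random_uniform(a, b, n, rng):
    """n independent uniform draws on [a, b]."""
    return [rng.uniform(a, b) for _ in range(n)]

def latin_hypercube_1d(a, b, n, rng):
    """One random point in each of n equal-width strata, in shuffled order."""
    width = (b - a) / n
    pts = [a + (i + rng.random()) * width for i in range(n)]
    rng.shuffle(pts)
    return pts

def adaptive_resample(candidates, residual_sq, n, rng):
    """Draw n points from candidates with probability proportional to r^2."""
    weights = [residual_sq(t) for t in candidates]
    return rng.choices(candidates, weights=weights, k=n)

rng = random.Random(0)
grid = uniform_grid(0.0, 4.0, 64)
rand_pts = random_uniform(0.0, 4.0, 64, rng)
lhs = latin_hypercube_1d(0.0, 4.0, 8, rng)

# Hypothetical residual profile that is large near t = 0 and decays quickly.
candidates = uniform_grid(0.0, 4.0, 200)
adaptive = adaptive_resample(candidates, lambda t: math.exp(-4.0 * t), 64, rng)
print("mean of adaptive samples:", sum(adaptive) / len(adaptive))
```

With the made-up residual profile, the adaptive samples cluster near t = 0, which is the intended behaviour: effort goes where the residual is largest.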

Note✏️ Section exercise — count the solutions

Two short pencil-and-paper checks of the §5.1 picture:

  1. For \dot x = -x without any IC term in the loss, write down three different functions that drive \mathcal{L}_{\text{ODE}} to exactly zero. What is the full family of zero-residual solutions, and which single member does adding \mathcal{L}_{\text{IC}} = |x_\theta(0) - 1|^2 select?
  2. Suppose you set \lambda_{\text{IC}} = 10^{-6} and train to a total loss of 10^{-8}. Roughly how far can x_\theta(0) sit from 1 while still “succeeding”? What does this tell you about reading a small total loss as evidence of a correct solution?

💡 Hint

For part 1, the zero-residual set is exactly the ODE’s solution family: write its general solution and you’re done. For part 2, bound the single term: \lambda_{\text{IC}}\,|x_\theta(0)-1|^2 \le \mathcal{L}_{\text{total}}, then solve for |x_\theta(0)-1|. No code required for either part.

5.2 Autodiff for PINNs: forward, reverse, and second derivatives

Unit 2 §2.5 introduced forward vs reverse mode at the level of “scalar loss gradient over parameters”. PINNs need more than that. The residual r_\theta(x, t) = \partial_t u_\theta - \alpha\,\partial_{xx} u_\theta contains derivatives of the network output with respect to the inputs (x, t), and the training gradient needs the parameter-gradient of that residual loss: a derivative of a derivative. Getting the AD composition right is what makes a PINN train in seconds rather than minutes (or not at all). This section walks through the three AD operations PINNs actually need, side-by-side in Julia (ForwardDiff / Zygote / Enzyme) and JAX (jax.grad / jax.jacfwd / jax.hessian).

Three derivatives, three modes

For a network u_\theta(\mathbf{x}, t): \mathbb{R}^{d+1} \to \mathbb{R} with P parameters, three derivative computations recur:

| What | Shape | Right mode |
|---|---|---|
| \partial u_\theta / \partial x_i at one (\mathbf{x}, t) | \mathbb{R}^{d+1} \to \mathbb{R} | forward (cheap inputs) |
| \partial^2 u_\theta / \partial x_i^2 at one (\mathbf{x}, t) | scalar | forward-over-forward |
| \partial \mathcal{L} / \partial \theta over all \theta | \mathbb{R}^P \to \mathbb{R} | reverse (cheap output) |

The composition that trains a PINN is therefore reverse-over-(forward-over-forward): forward for the \partial_x, forward again for the \partial_{xx}, reverse on the outside for \partial_\theta \mathcal{L}. Julia’s Lux + ForwardDiff + Zygote and JAX’s flax + jax.jacfwd + jax.grad can express this stack directly. PyTorch PINN code more often nests torch.autograd.grad calls with create_graph=True, which is reverse mode at every level and therefore falls under the caveat in “Why nesting order matters” below.

Forward mode: spatial derivatives of the network

In Julia, ForwardDiff.derivative is the workhorse for scalar inputs:

using Lux, Random, ForwardDiff

rng = Random.MersenneTwister(0)
u = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
ps, st = Lux.setup(rng, u)

u_θ(x, t) = first(u([x, t], ps, st)[1])

# ∂u/∂x via forward mode
∂x(x, t) = ForwardDiff.derivative(ξ -> u_θ(ξ, t), x)
@info "∂u/∂x at (0.3, 0.5) = $(∂x(0.3, 0.5))"

The JAX equivalent uses jax.jacfwd, the forward-mode Jacobian. For this scalar-output network jax.grad would return the same number, but via a reverse-mode pass, so jacfwd is the closer match to ForwardDiff.derivative:

import jax, jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, xt):
        h = nn.tanh(nn.Dense(16)(xt))
        return nn.Dense(1)(h).squeeze()

mlp = MLP()
key = jax.random.PRNGKey(0)
params = mlp.init(key, jnp.zeros(2))

u_theta = lambda x, t: mlp.apply(params, jnp.array([x, t]))
du_dx   = jax.jacfwd(u_theta, argnums=0)        # ∂u/∂x, forward mode
print("∂u/∂x at (0.3, 0.5) =", du_dx(0.3, 0.5))

Second derivatives: the PDE residual

For the heat equation we need \partial^2_x u_\theta. The cleanest composition is forward-over-forward: differentiate once with ForwardDiff, then differentiate the result again with ForwardDiff. The dual-number machinery handles the nesting automatically:

# ∂²u/∂x² via forward-over-forward
∂xx(x, t) = ForwardDiff.derivative(
                ξ -> ForwardDiff.derivative(η -> u_θ(η, t), ξ),
                x,
            )
@info "∂²u/∂x² at (0.3, 0.5) = $(∂xx(0.3, 0.5))"

# Residual of the heat equation ∂_t u = α ∂_xx u
α = 0.1
∂t(x, t) = ForwardDiff.derivative(τ -> u_θ(x, τ), t)
r(x, t)  = ∂t(x, t) - α * ∂xx(x, t)
@info "heat-eq residual at (0.3, 0.5) = $(r(0.3, 0.5))"

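In JAX the analogous pattern nests jax.jacfwd twice. (jax.hessian is jacfwd ∘ jacrev under the hood, i.e. forward-over-reverse: also a valid way to get second derivatives, but not the forward-over-forward composition used here.)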

import jax

# ∂²u/∂x² at fixed t
d2u_dx2 = jax.jacfwd(jax.jacfwd(u_theta, argnums=0), argnums=0)
du_dt   = jax.jacfwd(u_theta, argnums=1)

alpha = 0.1
def residual(x, t):
    return du_dt(x, t) - alpha * d2u_dx2(x, t)

print("heat-eq residual at (0.3, 0.5) =", residual(0.3, 0.5))

For higher-dimensional inputs (say a 3-D wave equation needing \partial^2_x + \partial^2_y + \partial^2_z), apply jax.hessian to u_theta as a function of the spatial vector alone, with t held fixed. It returns the full 3 \times 3 Hessian at a point; take its trace to get the Laplacian.
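The “dual-number machinery” behind the forward-over-forward nesting can be made concrete with a toy implementation. The sketch below is a teaching aid in plain Python, not how ForwardDiff or JAX are implemented: the Dual class, the helper names, and the test function u(x, t) = \sin(2x)\,e^{-t} are all assumptions chosen for illustration. ForwardDiff tags its dual numbers to keep nesting levels apart; this sketch has no tags, so it is only safe for simple nestings like the one shown.

```python
import math

class Dual:
    """a + b*eps with eps*eps == 0. The parts may themselves be Duals (nesting)."""
    def __init__(self, val, eps):
        self.val, self.eps = val, eps

    def __add__(self, other):
        o = lift(other)
        return Dual(self.val + o.val, self.eps + o.eps)
    __radd__ = __add__

    def __mul__(self, other):
        o = lift(other)
        # Product rule: (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.val * o.val, self.val * o.eps + self.eps * o.val)
    __rmul__ = __mul__

def lift(z):
    """Promote a plain number to a Dual with zero perturbation."""
    return z if isinstance(z, Dual) else Dual(z, 0.0)

def sin(z):
    if isinstance(z, Dual):
        return Dual(sin(z.val), cos(z.val) * z.eps)
    return math.sin(z)

def cos(z):
    if isinstance(z, Dual):
        return Dual(cos(z.val), -1.0 * sin(z.val) * z.eps)
    return math.cos(z)

def exp(z):
    if isinstance(z, Dual):
        return Dual(exp(z.val), exp(z.val) * z.eps)
    return math.exp(z)

def derivative(f, x):
    """Forward-mode derivative: seed a unit perturbation, read it back out."""
    out = f(Dual(x, 1.0))
    return out.eps if isinstance(out, Dual) else 0.0

def u(x, t):
    return sin(2.0 * x) * exp(-1.0 * t)

x0, t0, alpha = 0.3, 0.5, 0.1
# Forward-over-forward: the outer call differentiates the inner derivative.
d2u_dx2 = derivative(lambda xi: derivative(lambda eta: u(eta, t0), xi), x0)
du_dt = derivative(lambda tau: u(x0, tau), t0)
residual = du_dt - alpha * d2u_dx2
print("d2u/dx2 =", d2u_dx2, " du/dt =", du_dt, " residual =", residual)
```

The inner call returns a Dual whose perturbation part is itself the first derivative carried at the outer level, which is why nesting needs no extra bookkeeping here.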

The outer gradient: reverse mode over parameters

The residual at a collocation point is a scalar; the loss sums many of them, then we want \partial_\theta \mathcal{L} — a single output, P inputs. Classic reverse-mode territory: Zygote.gradient in Julia, jax.grad in JAX. The two AD modes happily compose.

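using Zygote

function pde_loss(ps, st, xs, ts)
    # Network closure over the current parameters
    u_θ(x, t) = first(u([x, t], ps, st)[1])
    ŝ = 0.0
    for (x, t) in zip(xs, ts)
        ∂t  = ForwardDiff.derivative(τ -> u_θ(x, τ), t)
        ∂xx = ForwardDiff.derivative(
                  ξ -> ForwardDiff.derivative(η -> u_θ(η, t), ξ), x)
        ŝ += (∂t - α * ∂xx)^2
    end
    ŝ / length(xs)
end

# Collocation points (illustrative values on [0, 1] × [0, 1])
xs, ts = rand(64), rand(64)

# Reverse-mode autodiff over θ, with forward-over-forward inside.
# Caveat from §5.3: depending on package versions, Zygote can drop the
# parameter gradient through a ForwardDiff closure, so check `grad`
# against a finite difference before trusting it.
grad = first(Zygote.gradient(p -> pde_loss(p, st, xs, ts), ps))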

The same shape in JAX is breathtakingly compact:

def loss(params, xs, ts):
    def r_single(x, t):
        u   = lambda x, t: mlp.apply(params, jnp.array([x, t]))
        d2  = jax.jacfwd(jax.jacfwd(u, argnums=0), argnums=0)
        dt  = jax.jacfwd(u, argnums=1)
        return (dt(x, t) - alpha * d2(x, t)) ** 2
    return jnp.mean(jax.vmap(r_single)(xs, ts))

grad_fn = jax.grad(loss)             # reverse-mode over params

Why nesting order matters

A practical pitfall: don’t use reverse mode for the inner derivatives. Reverse-mode AD records a tape proportional to the size of the network’s intermediate activations; nesting two reverse passes through the same network multiplies the tape size. On a modest 5-layer MLP this can turn a 100 ms step into a 20-second step. Forward mode’s cost scales with the input dimension, so for PDEs in 1D / 2D / 3D — exactly the regime PINNs operate in — forward-over-forward is the right choice for the spatial pieces.

The general rule (true in both Julia and JAX):

Inner derivatives over the few-dimensional input → forward mode. Outer derivative over the many-dimensional parameter vector → reverse mode.

That single rule of thumb is what makes the Lux + ForwardDiff + Zygote stack work, and it carries over unchanged to JAX, where jax.jacfwd supplies the inner derivatives and jax.grad the outer one; libraries such as NeuralPDE.jl assemble the residual derivatives for you. Get the composition right and PINN training scales; get it wrong and it doesn’t.
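A rough count shows why the rule points the way it does. The sketch below uses the textbook cost model (one sweep per input direction for forward mode, one sweep per scalar output for reverse mode) and ignores constant factors and memory, so treat the numbers as orders of magnitude only. The function names are illustrative.

```python
def mlp_param_count(widths):
    """Weights plus biases of a dense MLP with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))

def sweeps(mode, n_inputs, n_outputs):
    """Sweeps for a full Jacobian: one per input (forward) or one per output (reverse)."""
    return n_inputs if mode == "forward" else n_outputs

P = mlp_param_count([2, 32, 32, 1])   # the 2 -> 32 -> 32 -> 1 network of §5.5
d_plus_1 = 2                          # inputs (x, t)

inner_forward = sweeps("forward", d_plus_1, 1)   # derivatives w.r.t. (x, t)
outer_forward = sweeps("forward", P, 1)          # parameter gradient, forward mode
outer_reverse = sweeps("reverse", P, 1)          # parameter gradient, reverse mode
print(f"P = {P}: inner forward {inner_forward} sweeps, "
      f"outer forward {outer_forward} vs outer reverse {outer_reverse}")
```

Forward mode is cheap exactly where the input count is tiny, and reverse mode is cheap exactly where the output is a single scalar loss.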

Connecting back to Unit 1

The forward solver for the shallow-water equations in Unit 1 computed \partial^2 \eta / \partial t^2 by finite differences, on a 100×190 grid, in \sim 12\,\text{s}. The PINN above computes the same kind of second derivative, but at arbitrary (x, y, t) points, on demand, and exactly up to rounding error. Note that “exact” here refers to the derivative of the network u_\theta, not to the true solution: the network itself still carries approximation error. That is the headline trade-off PINNs offer: you give up the structured-grid efficiency of FD/FV/FEM and you get a meshless, autodiff-exact, easily-differentiable surrogate in return. The rest of Units 5–7 is mostly about how to realise that trade in practice.

Note✏️ Section exercise — trust, then verify, the autodiff stack

Take a function whose derivatives you know exactly: u(x, t) = \sin(3x)\,e^{-2t}, so \partial_x u = 3\cos(3x)e^{-2t}, \partial_{xx} u = -9\sin(3x)e^{-2t}, and \partial_t u = -2\sin(3x)e^{-2t}. Implement u as plain Julia code and compute all three derivatives at (x, t) = (0.7, 0.4) with nested ForwardDiff.derivative, exactly as the section does for the network. Verify each against the closed form to machine precision. Then confirm the residual of the heat equation \partial_t u - \alpha\,\partial_{xx} u vanishes when \alpha = 2/9 — i.e. your u is a heat-equation solution for that diffusivity. If you have a Python environment handy, repeat the whole check in JAX with jax.jacfwd and confirm the same numbers come out.

💡 Hint

Define u(x, t) = sin(3x) * exp(-2t) as an ordinary function — the point is to validate the derivative pipeline before any network is involved. Nest ForwardDiff.derivative for \partial_{xx}, and find α from the eigenmode decay rate \alpha k^2 with k = 3. In JAX, jax.jacfwd composes the same way.

5.3 A PINN from first principles

We solve \dot{x}(t) = -x(t) with x(0) = 1, the simplest non-trivial ODE, whose exact solution is x(t) = e^{-t}. Our PINN will be a 1 → 16 → 1 MLP with \tanh activations. The three pieces of machinery follow §5.2, with one pragmatic substitution for the inner derivative:

  • Lux for the model and its parameter container;
  • a symmetric central difference to compute \dot{x}_\theta(t) at collocation points — see the comment in the listing for why the demo prefers it over AD-inside-AD with current package versions;
  • Zygote.gradient to differentiate the loss with respect to \theta (many-input scalar-output → reverse mode).

using Lux, Random, Zygote, Optimisers, Plots

rng = Random.MersenneTwister(0)
model = Lux.Chain(
    Lux.Dense(1 => 16, tanh),
    Lux.Dense(16 => 1),
)
ps, st = Lux.setup(rng, model)

# Float32 throughout — Lux initialises Float32 weights; Float64 input
# triggers a "mixed-precision matmul fallback" warning from LuxLib and
# kicks Octavian.jl out of the matmul path. Keep ps, st, t all Float32.
const F = Float32

# scalar t → scalar x_θ(t)
function x_θ(t::F, ps, st)
    y, _ = model([t], ps, st)
    y[1]
end

# Time derivative of the network at a collocation point, by a symmetric
# central difference. Why not AD-inside-AD? Both pure compositions are
# version-sensitive in the current Lux/Zygote/ForwardDiff stack:
# nesting Zygote inside Zygote (reverse-over-reverse) hits Zygote's
# array-mutation limitation in Lux's pullbacks, and ForwardDiff inside
# Zygote silently DROPS the parameter gradient of the derivative term
# (Zygote warns it "cannot track gradients with respect to f" for a
# parameter-closure). The central difference sidesteps both: Zygote
# differentiates straight through two plain network evaluations, so the
# training gradient is exact for the discretised residual — at the cost
# of an O(h²) bias in ẋ_θ that is far below the training error here.
# §5.2's clean AD compositions remain the right mental model; this is
# the robust portable implementation as of mid-2026.
const h = F(1e-2)
dxdt(t::F, ps, st) = (x_θ(t + h, ps, st) - x_θ(t - h, ps, st)) / (2h)

ts = collect(range(F(0), F(4); length = 64))   # collocation points

function loss(ps)
    L_res = sum(t -> (dxdt(t, ps, st) + x_θ(t, ps, st))^2, ts) / length(ts)
    L_ic  = (x_θ(F(0), ps, st) - one(F))^2
    L_res + F(100) * L_ic
end

opt_state = Optimisers.setup(Optimisers.Adam(F(2e-2)), ps)
for epoch in 1:1500
    g = first(Zygote.gradient(loss, ps))
    opt_state, ps = Optimisers.update(opt_state, ps, g)
end

# compare to the exact solution
tgrid  = range(F(0), F(4); length = 200)
exact  = exp.(-collect(tgrid))
learnt = [x_θ(t, ps, st) for t in tgrid]

plt = plot(tgrid, exact;  label = "exact e^(-t)",
           xlabel = "t", ylabel = "x", lw = 2)
plot!(plt, tgrid, learnt; label = "PINN", lw = 2, ls = :dash)

Notice what this is, and what it isn’t:

  • There is no time-stepping. The PINN doesn’t march forward in time; it fits a function on the whole interval at once.
  • There is no training data — only the equation and the IC. The collocation points are unlabelled.
  • The loss is differentiable end-to-end: \dot{x}_\theta here is a central difference of two network evaluations, so Zygote.gradient flows through the whole residual exactly. If you can write the residual, you can train.

That last point is the whole game. Whatever differential equation you can encode as a residual r(t, ps), the same loop trains a PINN for it. The difficulty in §§5.4-5.5 is scaling the bookkeeping: multiple BCs, multiple spatial dimensions, multiple loss terms. That’s what NeuralPDE.jl automates.
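As a side check on the central-difference choice in the listing above, the O(h²) bias it mentions can be measured directly on the known solution e^{-t}. This is a sketch in double precision with an illustrative evaluation point t = 1; the listing itself runs in Float32, where rounding adds its own contribution that this check does not model.

```python
import math

def central_diff(f, t, h):
    """Symmetric central difference, the same formula as dxdt in the listing."""
    return (f(t + h) - f(t - h)) / (2.0 * h)

f = lambda t: math.exp(-t)
exact = -math.exp(-1.0)                       # d/dt e^{-t} at t = 1

err_h = abs(central_diff(f, 1.0, 1e-2) - exact)
err_half = abs(central_diff(f, 1.0, 5e-3) - exact)

# Leading Taylor term of the bias: h^2 / 6 * |f'''(t)|, with |f'''(1)| = e^{-1}.
predicted = (1e-2) ** 2 / 6.0 * math.exp(-1.0)

print("error at h = 1e-2 :", err_h)
print("Taylor prediction :", predicted)
print("ratio err(h)/err(h/2):", err_h / err_half)
```

Halving h cuts the error by a factor close to four, which is the signature of a second-order scheme, and the measured bias at h = 10^{-2} is of order 10^{-6}.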

Note✏️ Section exercise — your second PINN: the oscillator

Upgrade the §5.3 script from first-order to second-order physics. Solve \ddot x + \omega^2 x = 0 with \omega = 2, x(0) = 1, \dot x(0) = 0 on t \in [0, 4] (exact solution \cos(2t) — about 1¼ periods). You’ll need three changes: a second time derivative of the network in the residual, a second IC term for \dot x_\theta(0), and probably a wider network (1 → 32 → 32 → 1) with more Adam iterations. Plot the PINN against \cos(2t). Then stretch the domain to t \in [0, 12], retrain, and describe what happens to the late-time fit — your first encounter with the failure mode that §5.6 names.

💡 Hint

Three local changes to the §5.3 script: a second derivative via the central second difference d2(t) = (x_θ(t+h) - 2x_θ(t) + x_θ(t-h)) / h^2 (same trick as §5.3’s dxdt, one order up), an extra IC penalty on dxdt(0), and a 1 → 32 → 32 → 1 network. Why two ICs? Count the parameters of the general solution A\cos\omega t + B\sin\omega t. Expect to need ~4 000 Adam iterations.

5.4 The same ODE with NeuralPDE.jl

NeuralPDE.jl turns a ModelingToolkit.PDESystem (equation + domains + BCs) and a PhysicsInformedNN(chain, training_strategy) discretisation into an OptimizationProblem that any Optimization.jl optimiser can solve. The same \dot{x} = -x problem becomes:

# A PINN for ẋ = -x, x(0) = 1 using NeuralPDE.jl.
#
# Same problem as the hand-rolled version in unit_05.qmd §5.3, but
# written declaratively against a ModelingToolkit PDESystem.
#
# Run via ./build.sh execute 5 (writes output to ../output/pinn_neuralpde_ode.md).

using NeuralPDE, Lux, ModelingToolkit
using Optimization, OptimizationOptimJL
using DomainSets: ClosedInterval
using Random

@parameters t
@variables x(..)
Dt = Differential(t)

eq      = Dt(x(t)) ~ -x(t)
bcs     = [x(0.0) ~ 1.0]
domains = [t ∈ ClosedInterval(0.0, 4.0)]

@named pde_system = PDESystem(eq, bcs, domains, [t], [x(t)])

chain = Lux.Chain(Lux.Dense(1 => 16, tanh), Lux.Dense(16 => 1))
disc  = PhysicsInformedNN(chain, GridTraining(0.05))

prob = discretize(pde_system, disc)
@info "solving with LBFGS..."
sol  = solve(prob, LBFGS(); maxiters = 200)

# Pull the trained function out of the solution and score it against
# the exact e^{-t} at a held-out grid.
phi    = disc.phi
params = sol.u
tgrid  = collect(range(0.0, 4.0; length = 200))
learnt = [phi([t], params)[1] for t in tgrid]
exact  = exp.(-tgrid)
err    = maximum(abs, learnt .- exact)

println("retcode         : $(sol.retcode)")
println("final loss      : $(round(sol.objective; sigdigits = 4))")
println("max |PINN-exact|: $(round(err; sigdigits = 3))")

Captured output from ./build.sh execute 5:

[ Info: solving with LBFGS...
retcode         : MaxIters
final loss      : 1.391e-7
max |PINN-exact|: 0.000114

What the package buys you: a one-line residual derivation (no hand-rolled dxdt), built-in collocation strategies, and a clean path from Optimization.solve to Adam, LBFGS, or any other adapter in the SciML stack. What you pay: heavier precompile and a lot of symbolic machinery between your problem and the gradient that ends up training the network.

Note✏️ Section exercise — change the equation, not the plumbing

The declarative payoff of NeuralPDE.jl is that a different ODE is a one-line edit. Swap the equation in the sidecar script for the logistic equation \dot x = x(1 - x) with x(0) = 0.1 on t \in [0, 8] (exact solution x(t) = 1 / (1 + 9e^{-t})), re-discretise, and re-solve. Compare the PINN to the closed form. Then look at what you did not have to touch — no hand-rolled dxdt, no loss assembly — and note the one thing you did have to reconsider (hint: the solution now saturates near 1; does the default network/strategy still fit it well at the plateau?).

💡 Hint

Only three lines of the sidecar script change: eq, bcs, and domains — the symbolic layer regenerates everything downstream. The exact solution for grading is the logistic curve 1/(1+9e^{-t}). For the ‘reconsider’ question, think about where on [0, 8] the solution actually does anything.

5.5 The simplest PDE: 1D diffusion

The 1D heat equation \partial_t u = \alpha\,\partial_x^2 u on (x, t) \in [0, 1]\times[0, 1] with Dirichlet boundaries u(0, t) = u(1, t) = 0 and a Gaussian initial bump u(x, 0) = \exp\!\bigl(-200(x - 0.5)^2\bigr). Parabolic, smoothing, well-posed — and the first proper PDE PINN benchmark.

The implementation is structurally identical to §5.4, with one extra spatial coordinate and one extra BC pair:

# 1D heat equation PINN with NeuralPDE.jl.
#
#   ∂t u = α ∂xx u       on (x,t) ∈ [0,1] × [0,1]
#   u(x, 0) = exp(-200 (x - 0.5)^2)
#   u(0, t) = u(1, t) = 0
#
# Compares the trained network to a quick second-order finite-
# difference reference solve.
#
# Run via ./build.sh execute 5; writes output/pinn_neuralpde_heat.md
# plus output/pinn_neuralpde_heat.png.

using NeuralPDE, Lux, ModelingToolkit
using Optimization, OptimizationOptimJL
using OrdinaryDiffEq, OrdinaryDiffEqBDF
using DomainSets: ClosedInterval
using LinearAlgebra, Plots

@parameters x t
@variables u(..)
Dt  = Differential(t)
Dxx = Differential(x)^2

α   = 0.01
eq  = Dt(u(x, t)) ~ α * Dxx(u(x, t))

ic_bump(x_) = exp(-200 * (x_ - 0.5)^2)
bcs = [
    u(x, 0.0) ~ ic_bump(x),
    u(0.0, t) ~ 0.0,
    u(1.0, t) ~ 0.0,
]
domains = [x ∈ ClosedInterval(0.0, 1.0), t ∈ ClosedInterval(0.0, 1.0)]

@named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)])

chain = Lux.Chain(
    Lux.Dense(2 => 32, tanh),
    Lux.Dense(32 => 32, tanh),
    Lux.Dense(32 => 1),
)
disc = PhysicsInformedNN(chain, GridTraining([0.02, 0.02]))
prob = discretize(pde_system, disc)

@info "solving heat equation PINN..."
sol = solve(prob, LBFGS(); maxiters = 1500)
println("retcode    : $(sol.retcode)")
println("final loss : $(round(sol.objective; sigdigits = 4))")

# ── finite-difference reference (centred diffs, BDF in time) ───────────
Nx = 101
xg = range(0.0, 1.0; length = Nx)
Δx = step(xg)
u0 = ic_bump.(xg)
u0[1] = 0.0; u0[end] = 0.0          # enforce BCs at t=0

function rhs!(du, u, p, t)
    du[1] = 0.0
    du[end] = 0.0
    @inbounds for i in 2:length(u)-1
        du[i] = α * (u[i+1] - 2u[i] + u[i-1]) / Δx^2
    end
end
ref_prob = ODEProblem(rhs!, u0, (0.0, 1.0))
ref_sol  = solve(ref_prob, FBDF(); saveat = [0.0, 0.3, 0.7, 1.0])

# ── evaluate PINN at the same grid+times ───────────────────────────────
phi    = disc.phi
params = sol.u
pinn_at(t) = [phi([xi, t], params)[1] for xi in xg]

max_err = let m = 0.0
    for (i, t) in enumerate(ref_sol.t)
        ref = ref_sol.u[i]
        pin = pinn_at(t)
        m = max(m, maximum(abs, pin .- ref))
    end
    m
end
println("max |PINN-FD|: $(round(max_err; sigdigits = 3))")

# ── plot ───────────────────────────────────────────────────────────────
plt = plot(xlabel = "x", ylabel = "u(x, t)",
           title  = "1D diffusion: PINN vs. FD reference", legend = :topright)
colors = [:black, :red, :blue, :green]
for (i, t) in enumerate(ref_sol.t)
    plot!(plt, xg, ref_sol.u[i]; lw = 2,  ls = :solid, color = colors[i],
          label = "FD t=$(round(t; digits=2))")
    plot!(plt, xg, pinn_at(t);    lw = 2,  ls = :dash,  color = colors[i],
          label = "PINN t=$(round(t; digits=2))")
end
outpath = joinpath(@__DIR__, "..", "output", "pinn_neuralpde_heat.png")
savefig(plt, outpath)
println("saved $outpath")

Captured output from ./build.sh execute 5:

[ Info: solving heat equation PINN...
retcode    : MaxIters
final loss : 0.0003209
max |PINN-FD|: 0.0276
saved /Users/uqjnazar/git/PIML/Julia_PINN_training_2026/units/unit_05/scripts/../output/pinn_neuralpde_heat.png

PINN solution at four times (t = 0, 0.3, 0.7, 1) vs. a finite-difference reference. The PINN matches the smoothing trend; the visible discrepancy is the mismatch at the Gaussian peak.

Note✏️ Section exercise — sharpen the bump until it breaks

The heat-equation PINN above fits a Gaussian initial bump of width parameter 200. Make the problem harder one notch at a time: rerun the script with the IC sharpness at 200, 800, 3200 (the bump’s width halves each step), keeping everything else fixed. For each run record the final loss and plot the PINN against the FD reference at t = 0 — the IC is where the damage shows. At what sharpness does the PINN visibly fail to represent the initial condition, and does throwing 4× more collocation points at it fix the problem? Keep your three loss numbers; §5.6 explains the trend and Unit 7 sells the cure.

💡 Hint

The sharpness is the 200 inside ic_bump, which the IC line of the bcs vector calls. Sweep it, record sol.objective (or the final printed loss), and compare the t=0 slice against exp.(-S .* (xg .- 0.5).^2). For the collocation test, halve the GridTraining spacings in each dimension (4× the points).

5.6 Failure modes already visible

Even on benign problems three pathologies show up that motivate everything in Unit 7:

  • Spectral bias. MLPs with smooth activations preferentially fit low frequencies first (Rahaman et al., 2019). The Gaussian initial bump in §5.5 — locally high curvature — is the slowest-converging feature of the loss.
  • Loss imbalance. With residual, IC, and BC contributing terms of very different magnitude, the optimiser zeroes the easiest one first. The PINN can land on a solution that satisfies the equation inside the domain but ignores the boundary, or fits the IC and the PDE while drifting at the BC.
  • Causal violation. A time-dependent PINN can fit the residual at every t simultaneously, including futures whose dependence on the IC hasn’t propagated forward. The result is locally consistent at each (x, t) point but globally inconsistent as a trajectory.

These three are the meat of Unit 7, which answers them with four tools: adaptive loss weighting, Fourier feature embeddings, hard BC enforcement, and causal training. With those tools we can train PINNs on the capstone column in Unit 9 / Unit 10; without them we can’t.
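Two bookkeeping diagnostics help tell these pathologies apart in a training log: each loss term's share of the total, and the residual binned by time. The sketch below is a minimal version under stated assumptions. The helper names are illustrative, and every number in the example is made up to exercise the functions; none is a measurement from the runs in this unit.

```python
def term_shares(terms, weights):
    """Fraction of the weighted total loss contributed by each named term."""
    weighted = {k: weights[k] * v for k, v in terms.items()}
    total = sum(weighted.values())
    return {k: w / total for k, w in weighted.items()}

def residual_by_time(ts, residuals, t_max, n_bins):
    """Mean squared residual in n_bins equal-width time slabs on [0, t_max]."""
    sums, counts = [0.0] * n_bins, [0] * n_bins
    for t, r in zip(ts, residuals):
        b = min(int(n_bins * t / t_max), n_bins - 1)   # clamp t == t_max into last bin
        sums[b] += r * r
        counts[b] += 1
    return [s / c if c else float("nan") for s, c in zip(sums, counts)]

# Made-up loss values: a BC term that barely registers in the total.
shares = term_shares({"pde": 1e-6, "ic": 1e-7, "bc": 1e-9},
                     {"pde": 1.0, "ic": 1.0, "bc": 1.0})
print({k: round(v, 4) for k, v in shares.items()})

# Made-up residuals that grow with t, the pattern a causality check looks for.
slabs = residual_by_time([0.1, 0.4, 0.6, 0.9], [0.1, 0.1, 1.0, 1.0], 1.0, 2)
print(slabs)
```

A term whose share is negligible is one the optimiser has little incentive to improve, and a residual that grows from early to late slabs suggests the late-time fit is not anchored to the initial condition.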

Note✏️ Section exercise — diagnose before you medicate

Match symptom to disease. For each observation below — all from real PINN training logs — name which of the three §5.6 pathologies (spectral bias, loss imbalance, causal violation) is the prime suspect, and which Unit 7 fix you’d reach for first:

  1. The total loss is 10^{-6} but the solution at the domain boundary is visibly wrong; the BC term contributes 10^{-9} of the total with \lambda_{\text{BC}} = 1.
  2. A wave-equation PINN reproduces the first quarter-period beautifully, then the predicted field “freezes” — late times look like a smeared copy of early times.
  3. A PINN fits \sin(2\pi x) initial data in 500 iterations but needs 50 000 for \sin(16\pi x), with the same network.
  4. Doubling \lambda_{\text{IC}} fixes the initial condition but the interior residual gets 100× worse.

💡 Hint

Three diagnostics separate the three diseases: per-term loss magnitudes (imbalance), residual-binned-by-t (causality), and whether difficulty scales with the target’s frequency (spectral bias). Each numbered symptom answers exactly one of those probes — start by deciding which probe the symptom is reporting.