using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie
rng = Xoshiro(42)
19 DeepONet
- DeepONet — the foundational neural operator based on the universal approximation theorem for operators, using a branch-trunk architecture (Lu, Jin, et al., 2021).
- Universal approximation of operators — the mathematical theorem guaranteeing that a two-sub-network architecture can approximate any continuous nonlinear operator (Chen & Chen, 1995).
- DeepXDE — a library for PINNs and DeepONets that formalized many practical training strategies (Lu, Meng, et al., 2021).
- Neural operator survey — a comprehensive mathematical treatment of neural operators and their approximation properties (Kovachki et al., 2023).
- Neural operators for science — review of neural operators accelerating scientific simulations (Azizzadenesheli et al., 2024).
In the PINN chapter we trained a network to solve one instance of a PDE. Change the boundary conditions, initial conditions, or source term, and you must retrain from scratch. DeepONet (Deep Operator Network) solves a fundamentally more ambitious problem: it learns the solution operator — the mapping from input functions to output functions — so that after training, it can instantly predict solutions for any new input without retraining.
19.1 The operator learning problem
In many PDE problems, the solution \(u\) depends on an input function \(a\) (a forcing term, initial condition, boundary condition, or material property):
\[ \mathcal{N}[u; a] = 0 \quad \implies \quad u = \mathcal{G}(a) \]
where \(a \in \mathcal{A}\) is an input function, \(u \in \mathcal{U}\) is the corresponding output function, and \(\mathcal{G}: \mathcal{A} \to \mathcal{U}\) is the solution operator. Writing \(\mathcal{G}(a)(y)\) means: first apply the operator to the whole input function \(a\), then evaluate the resulting output function at the query location \(y\). A traditional solver computes \(\mathcal{G}(a)\) one \(a\) at a time. DeepONet learns \(\mathcal{G}_\theta \approx \mathcal{G}\) from data — pairs \(\{(a_i, u_i)\}_{i=1}^{N}\) — so that inference on new inputs is a single forward pass.
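In practice, both \(a_i\) and \(u_i\) are stored as finite vectors of point evaluations on fixed grids. A minimal sketch of one training pair for the antiderivative operator \(\mathcal{G}(a)(y) = \int_0^y a(s)\,ds\) (the grid sizes and function choices here are illustrative, not the ones used later in this chapter):

```julia
# One training pair (a, u) for G(a)(y) = ∫₀ʸ a(s) ds, stored as point values
sensors = range(0, 1, length = 5)   # where the input function is sampled
queries = range(0, 1, length = 5)   # where the output function is evaluated
a(x) = 2x                           # one input function
u(y) = y^2                          # its image under G (exact antiderivative)
a_vec = Float32.(a.(sensors))       # finite-dimensional stand-in for the function a
u_vec = Float32.(u.(queries))       # finite-dimensional stand-in for the function u
```

The network never sees the functions themselves, only these vectors of samples.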
19.2 The universal approximation theorem for operators
Chen & Chen (1995) proved that a network with two sub-networks can approximate any continuous nonlinear operator to arbitrary accuracy. This theorem provides the mathematical foundation for DeepONet.
Theorem (informal): For any continuous operator \(\mathcal{G}: \mathcal{A} \to \mathcal{U}\) and any \(\epsilon > 0\), there exist neural networks (branch and trunk) such that:
\[ \bigl|\mathcal{G}(a)(y) - \mathcal{G}_\theta(a)(y)\bigr| < \epsilon \]
for all input functions \(a\) in a compact subset of \(\mathcal{A}\) and all query locations \(y\) in a compact domain. (Compactness of both sets is part of the theorem's hypotheses.)
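Concretely, Chen & Chen's construction is a sum of products of two sub-networks: one acting on the sampled values of \(a\), the other on \(y\). In the notation used below (branch outputs \(b_k\), trunk outputs \(\tau_k\)), the approximant has the form

\[ \mathcal{G}_\theta(a)(y) = \sum_{k=1}^{p} b_k\bigl(a(x_1), \ldots, a(x_m)\bigr) \, \tau_k(y), \]

where \(x_1, \ldots, x_m\) are fixed sensor locations. The trainable bias \(b_0\) used in DeepONet is a practical addition on top of this theoretical form.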
19.3 Architecture: branch and trunk
DeepONet decomposes the operator approximation into two learnable components:
Input function a(x)                  Query location y
evaluated at m sensors               (where to evaluate the output)

[a(x₁), a(x₂), ..., a(xₘ)]           (y₁, y₂, ..., y_d)
         │                                   │
         ▼                                   ▼
   ┌───────────┐                       ┌───────────┐
   │  Branch   │                       │   Trunk   │
   │  Network  │                       │  Network  │
   └─────┬─────┘                       └─────┬─────┘
         │                                   │
 [b₁, b₂, ..., bₚ]                   [τ₁, τ₂, ..., τₚ]
         │                                   │
         └───────────────┬───────────────────┘
                         │  dot product
                         ▼
         G_θ(a)(y) = Σₖ bₖ · τₖ + bias
\[ \mathcal{G}_\theta(a)(y) = \sum_{k=1}^{p} b_k(a) \cdot \tau_k(y) + b_0 \]
Branch network — takes the input function \(a\) sampled at \(m\) fixed sensor locations \(\{x_1, \ldots, x_m\}\) and outputs \(p\) coefficients \(\{b_1, \ldots, b_p\}\). Think of these as the “expansion coefficients” of the output in a learned basis.
Trunk network — takes the query location \(y\) (where we want to evaluate the output) and produces \(p\) basis functions \(\{\tau_1, \ldots, \tau_p\}\).
The final output is the inner product of the branch and trunk outputs plus a trainable bias. This split is elegant: the branch encodes what the input function looks like, while the trunk supplies a learned basis over the query locations. In other words, the branch depends only on the input function \(a\), and the trunk depends only on the query location \(y\).
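The combination step itself is just a dot product per query point, independent of how the branch and trunk are implemented. A self-contained sketch with illustrative dimensions and made-up branch/trunk outputs:

```julia
p = 4                              # latent dimension shared by branch and trunk
b = Float32[0.5, -1.0, 2.0, 0.25]  # branch output for one input function a
T = ones(Float32, p, 3)            # trunk outputs at 3 query points, shape (p, n_q)
b0 = 0.1f0                         # trainable scalar bias
# G_θ(a)(yⱼ) = Σₖ bₖ τₖ(yⱼ) + b₀ for each query point yⱼ, as one matrix product
G_vals = vec(b' * T) .+ b0
```

Evaluating the operator at more query points only adds columns to `T`; the branch output `b` is computed once per input function.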
19.4 Variants
- Unstacked DeepONet — single branch, single trunk (described above). The most common variant.
- Stacked DeepONet — multiple independent branch-trunk pairs summed together.
- POD-DeepONet — the trunk basis functions are replaced by Proper Orthogonal Decomposition (POD) modes computed from the training data, making the trunk data-driven rather than learned.
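For POD-DeepONet, the trunk basis comes from a singular value decomposition of the centered matrix of training output snapshots rather than from a trained network. A minimal self-contained sketch with synthetic snapshot data (sizes and variable names are illustrative):

```julia
using LinearAlgebra, Statistics, Random
rng_pod = Xoshiro(0)
U_snap = randn(rng_pod, 100, 50)         # 50 training outputs on 100 grid points
U_c = U_snap .- mean(U_snap, dims = 2)   # center across samples
F = svd(U_c)
p_pod = 8
pod_modes = F.U[:, 1:p_pod]              # first p left singular vectors = trunk basis
```

The branch network is then trained to predict the coefficients of the output in this fixed, orthonormal basis.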
19.5 Code example: learning the antiderivative operator
We train a DeepONet to learn the antiderivative operator:
\[ \mathcal{G}(a)(y) = \int_0^y a(s)\, ds \]
Given a function \(a(x)\), the operator returns its antiderivative evaluated at any query point \(y\).
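Before training anything, it helps to verify the trapezoidal integration we will use for data generation against a case with a known antiderivative. A quick self-contained sanity check, separate from the training pipeline (function and variable names here are local to this check):

```julia
# For g(s) = cos(s), the exact antiderivative is G(g)(y) = sin(y)
g(s) = cos(s)
yq = 0.7
xs = range(0, yq, length = 1001)
gv = g.(xs)
# Composite trapezoidal rule on a uniform grid of spacing step(xs)
trap = sum(0.5 .* (gv[1:end-1] .+ gv[2:end]) .* step(xs))
```

With 1000 intervals the trapezoidal error is O(h²), far below the accuracy we need for training targets.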
# Data generation: random functions and their antiderivatives
n_sensors = 50 # fixed sensor locations for branch input
n_query = 50 # query points per sample
n_train = 2000
n_test = 200
x_sensors = Float32.(range(0, 1, length = n_sensors))
y_query = Float32.(range(0, 1, length = n_query))
dx_sensor = x_sensors[2] - x_sensors[1]
function random_function(rng)
    # Random sum of sinusoids
    n_modes = rand(rng, 2:5)
    coeffs = randn(rng, Float32, n_modes)
    freqs = Float32.(rand(rng, 1:6, n_modes))
    phases = 2π .* rand(rng, Float32, n_modes)
    function f(x)
        val = 0.0f0
        for j in 1:n_modes
            val += coeffs[j] * sin(2π * freqs[j] * x + phases[j])
        end
        return val
    end
    return f
end
function generate_data(rng, n_samples)
    A = zeros(Float32, n_sensors, n_samples)    # branch inputs
    Y = zeros(Float32, 1, n_query * n_samples)  # query locations
    G = zeros(Float32, 1, n_query * n_samples)  # true outputs
    for i in 1:n_samples
        f = random_function(rng)
        # Evaluate at the fixed sensor locations
        A[:, i] = f.(x_sensors)
        # Compute the antiderivative at each query point via the trapezoidal rule
        for (j, yj) in enumerate(y_query)
            # Integrate from 0 to yj
            n_int = max(2, round(Int, yj / dx_sensor) + 1)
            x_int = Float32.(range(0, yj, length = n_int))
            f_int = f.(x_int)
            integral = sum(0.5f0 * (f_int[1:end-1] .+ f_int[2:end]) .* diff(x_int))
            idx = (i - 1) * n_query + j
            Y[1, idx] = yj
            G[1, idx] = integral
        end
    end
    return A, Y, G
end
A_train, Y_train, G_train = generate_data(rng, n_train)
A_test, Y_test, G_test = generate_data(Xoshiro(123), n_test)
# Build DeepONet: Branch + Trunk
p = 64 # output dimension of both branch and trunk
branch_net = Chain(
    Dense(n_sensors => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => p)
)
trunk_net = Chain(
    Dense(1 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => p)
)
ps_branch, st_branch = Lux.setup(rng, branch_net)
ps_trunk, st_trunk = Lux.setup(rng, trunk_net)
bias = zeros(Float32, 1)
ps_all = (branch = ps_branch, trunk = ps_trunk, bias = bias)
st_all = (branch = st_branch, trunk = st_trunk)
# Forward pass: DeepONet output
function deeponet_forward(A, Y, ps, st, n_samples, n_q)
    # Branch: process all input functions
    b_out, st_b = branch_net(A, ps.branch, st.branch)    # (p, n_samples)
    # Trunk: process all query locations
    t_out, st_t = trunk_net(Y, ps.trunk, st.trunk)       # (p, n_samples * n_q)
    # Tile the branch output so each sample's coefficients line up with its
    # n_q query points: `repeat` with `inner` repeats each column n_q times
    # consecutively, matching the query layout. Unlike writing into a
    # preallocated array, this is differentiable under Zygote.
    b_tiled = repeat(b_out, inner = (1, n_q))            # (p, n_samples * n_q)
    # Dot product along the p dimension, plus bias
    output = sum(b_tiled .* t_out, dims = 1) .+ ps.bias  # (1, n_samples * n_q)
    st_new = (branch = st_b, trunk = st_t)
    return output, st_new
end
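The tiling is the only subtle reshape in the forward pass, so it is worth checking how `repeat` with `inner` lays out columns. A quick self-contained check with a tiny stand-in branch output:

```julia
B = [1 2; 3 4]                     # stand-in branch output: p = 2, n_samples = 2
tiled = repeat(B, inner = (1, 3))  # each sample's column repeated 3 times in a row
# Column layout now matches the query layout: all of sample 1's queries first,
# then all of sample 2's
```

Note that `repeat(B, 1, 3)` (outer repetition) would interleave the samples as [b₁ b₂ b₁ b₂ ...], which does not match how the query points are stacked.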
# Loss function
function deeponet_loss(ps, st)
    pred, st_new = deeponet_forward(A_train, Y_train, ps, st, n_train, n_query)
    loss = mean((pred .- G_train) .^ 2)
    return loss, st_new
end
# Training loop
opt = Adam(0.001f0)
opt_state = Optimisers.setup(opt, ps_all)
for epoch in 1:1000
    # Needed so assignments update the globals when this runs as a script
    global ps_all, st_all, opt_state
    (loss_val, st_new), grads = Zygote.withgradient(ps_all) do ps
        deeponet_loss(ps, st_all)
    end
    opt_state, ps_all = Optimisers.update(opt_state, ps_all, grads[1])
    st_all = st_new
    if epoch == 1 || epoch % 200 == 0
        # Evaluate on the held-out test set
        pred_test, _ = deeponet_forward(A_test, Y_test, ps_all, st_all, n_test, n_query)
        test_mse = mean((pred_test .- G_test) .^ 2)
        @printf "Epoch %4d  Train MSE = %.6f  Test MSE = %.6f\n" epoch loss_val test_mse
    end
end
# Visualize predictions on test samples
fig = Figure(size = (700, 500))
n_show = 4
for i in 1:n_show
    a_sensor_vals = A_test[:, i]   # sensor values of the i-th test input function
    idx_range = (i-1)*n_query+1 : i*n_query
    true_vals = vec(G_test[:, idx_range])
    pred_vals_all, _ = deeponet_forward(
        reshape(a_sensor_vals, :, 1),
        reshape(Y_test[:, idx_range], 1, :),
        ps_all, st_all, 1, n_query
    )
    pred_vals = vec(pred_vals_all)
    row = (i - 1) ÷ 2 + 1
    col = (i - 1) % 2 + 1
    ax = Axis(fig[row, col], xlabel = "y", ylabel = "∫₀ʸ a(s) ds",
        title = "Test sample $i")
    lines!(ax, y_query, true_vals, color = :black, linewidth = 2, label = "Exact")
    lines!(ax, y_query, pred_vals, color = :steelblue, linewidth = 2,
        linestyle = :dash, label = "DeepONet")
    if i == 1
        axislegend(ax, position = :lt, labelsize = 10)
    end
end
Label(fig[0, :], "DeepONet: learning the antiderivative operator", fontsize = 14)
fig
19.6 Key properties of DeepONet
19.6.1 Strengths
- Mathematical foundation — grounded in the universal approximation theorem for operators, providing theoretical guarantees on expressiveness.
- Flexible geometry — unlike FNO (which requires regular grids for the FFT), DeepONet works with arbitrary input sensor locations and query points. This is crucial for geoscience data, which is rarely on regular grids.
- Modular design — the branch and trunk are independent networks that can be customized separately. For example, the branch could be a CNN if the input is an image, or an RNN if the input is a time series.
- Point-wise evaluation — the trunk network produces output at individual query points, making it efficient for problems where you only need the solution at specific locations.
19.6.2 Limitations
- Fixed sensor locations — the branch network requires the input function to be evaluated at the same fixed sensor locations for all samples. This can be restrictive if data comes from different measurement configurations.
- Data hungry — DeepONet typically requires more training data than FNO for the same accuracy on problems with regular grids, because it doesn’t exploit spatial structure the way spectral methods do.
- Scaling — for high-dimensional output functions, the number of trunk basis functions \(p\) may need to be large, increasing the computational cost.
19.7 DeepONet vs FNO
| Feature | DeepONet | FNO |
|---|---|---|
| Architecture | Branch + Trunk | Spectral convolution layers |
| Grid requirements | Arbitrary (irregular OK) | Regular grid (for FFT) |
| Theoretical basis | Universal approximation theorem | Kernel integral operators |
| Resolution invariance | Through trunk network | Through Fourier representation |
| Best for | Irregular data, point queries | Regular-grid data, global patterns |
| Data efficiency | Moderate | Higher (exploits spatial structure) |
19.8 Geoscience applications
DeepONet is particularly well-suited for geoscience problems because field data is rarely on regular grids:
- Subsurface flow surrogates — DeepONet can learn the mapping from permeability fields to pressure/saturation solutions, enabling real-time uncertainty quantification for reservoir management.
- Seismic wave modeling — given a velocity model as input, DeepONet predicts the wavefield at arbitrary receiver locations, replacing expensive wave-equation solves during inversion.
- Well log prediction — the branch network ingests spatially irregular borehole measurements while the trunk evaluates predictions at arbitrary subsurface locations.
- Climate downscaling — DeepONet maps coarse-resolution climate model output to fine-resolution fields, handling the irregular observation network naturally.
A powerful extension of DeepONet incorporates PDE constraints into the training loss, dramatically reducing the amount of training data needed. We explore this in the next chapter.