19  DeepONet

Tip — Key references
  • DeepONet — the foundational neural operator based on the universal approximation theorem for operators, using a branch-trunk architecture (Lu, Jin, et al., 2021).
  • Universal approximation of operators — the mathematical theorem guaranteeing that a two-sub-network architecture can approximate any continuous nonlinear operator (Chen & Chen, 1995).
  • DeepXDE — a library for PINNs and DeepONets that formalized many practical training strategies (Lu, Meng, et al., 2021).
  • Neural operator survey — a comprehensive mathematical treatment of neural operators and their approximation properties (Kovachki et al., 2023).
  • Neural operators for science — review of neural operators accelerating scientific simulations (Azizzadenesheli et al., 2024).

In the PINN chapter we trained a network to solve one instance of a PDE. Change the boundary conditions, initial conditions, or source term, and you must retrain from scratch. DeepONet (Deep Operator Network) solves a fundamentally more ambitious problem: it learns the solution operator — the mapping from input functions to output functions — so that after training, it can instantly predict solutions for any new input without retraining.

19.1 The operator learning problem

In many PDE problems, the solution \(u\) depends on an input function \(a\) (a forcing term, initial condition, boundary condition, or material property):

\[ \mathcal{N}[u; a] = 0 \quad \implies \quad u = \mathcal{G}(a) \]

where \(a \in \mathcal{A}\) is an input function, \(u \in \mathcal{U}\) is the corresponding output function, and \(\mathcal{G}: \mathcal{A} \to \mathcal{U}\) is the solution operator. Writing \(\mathcal{G}(a)(y)\) means: first apply the operator to the whole input function \(a\), then evaluate the resulting output function at the query location \(y\). A traditional solver computes \(\mathcal{G}(a)\) one \(a\) at a time. DeepONet learns \(\mathcal{G}_\theta \approx \mathcal{G}\) from data — pairs \(\{(a_i, u_i)\}_{i=1}^{N}\) — so that inference on new inputs is a single forward pass.
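Concretely, each training pair discretizes both functions: \(a_i\) is sampled at \(m\) fixed sensor locations and \(u_i = \mathcal{G}(a_i)\) is recorded at query locations. A minimal sketch of this data layout, using the pointwise squaring operator \(\mathcal{G}(a)(y) = a(y)^2\) purely as an illustration (it is not the operator studied later in this chapter):

```julia
# One discretized training pair for operator learning.
# Input function a(x) = sin(2πx), sampled at m fixed sensors;
# output G(a)(y) = a(y)^2, recorded at the query points.
m = 50
x_sensors = range(0, 1, length = m)
a_at_sensors = sin.(2π .* x_sensors)      # branch input: a vector in R^m

y_query = range(0, 1, length = 20)
u_at_query = sin.(2π .* y_query) .^ 2     # supervision targets G(a)(y)

length(a_at_sensors), length(u_at_query)  # (50, 20)
```

A full dataset stacks \(N\) such pairs, each generated from a different random input function.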

19.2 The universal approximation theorem for operators

Chen & Chen (1995) proved that a network with two sub-networks can approximate any continuous nonlinear operator to arbitrary accuracy. This theorem provides the mathematical foundation for DeepONet.

Theorem (informal): For any continuous operator \(\mathcal{G}: \mathcal{A} \to \mathcal{U}\) and any \(\epsilon > 0\), there exist neural networks (branch and trunk) such that:

\[ \bigl|\mathcal{G}(a)(y) - \mathcal{G}_\theta(a)(y)\bigr| < \epsilon \]

for all input functions \(a \in \mathcal{A}\) and all query locations \(y\).
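Chen & Chen's construction already contains the branch-trunk split. In the notation of this chapter (following the restatement in Lu, Jin, et al., 2021), the approximating expression is a finite sum of products,

\[ \mathcal{G}(a)(y) \approx \sum_{k=1}^{p} \underbrace{\sum_{i=1}^{n} c_i^k \,\sigma\!\left(\sum_{j=1}^{m} \xi_{ij}^{k}\, a(x_j) + \theta_i^{k}\right)}_{\text{depends only on } a(x_1), \ldots, a(x_m)} \; \underbrace{\sigma\!\left(w_k \cdot y + \zeta_k\right)}_{\text{depends only on } y} \]

where \(\sigma\) is a continuous, non-polynomial activation and \(c_i^k, \xi_{ij}^k, \theta_i^k, w_k, \zeta_k\) are parameters. The first factor is a shallow network over the sensor values, the second a shallow network over the query location; DeepONet simply deepens both factors.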

19.3 Architecture: branch and trunk

DeepONet decomposes the operator approximation into two learnable components:

Input function a(x)                    Query location y
evaluated at sensors                   (where to evaluate output)
[a(x₁), a(x₂), ..., a(xₘ)]          (y₁, y₂, ..., y_d)
         │                                     │
         ▼                                     ▼
   ┌───────────┐                        ┌───────────┐
   │  Branch   │                        │  Trunk    │
   │  Network  │                        │  Network  │
   └─────┬─────┘                        └─────┬─────┘
         │                                     │
    [b₁, b₂, ..., bₚ]                  [τ₁, τ₂, ..., τₚ]
         │                                     │
         └──────────────┬──────────────────────┘
                        │  dot product
                        ▼
                  G_θ(a)(y) = Σₖ bₖ · τₖ + b₀

\[ \mathcal{G}_\theta(a)(y) = \sum_{k=1}^{p} b_k(a) \cdot \tau_k(y) + b_0 \]

  • Branch network — takes the input function \(a\) sampled at \(m\) fixed sensor locations \(\{x_1, \ldots, x_m\}\) and outputs \(p\) coefficients \(\{b_1, \ldots, b_p\}\). Think of these as the “expansion coefficients” of the output in a learned basis.

  • Trunk network — takes the query location \(y\) (where we want to evaluate the output) and produces \(p\) basis functions \(\{\tau_1, \ldots, \tau_p\}\).

The final output is the inner product of the branch and trunk outputs. The division of labor is clean: the branch compresses the input function \(a\) into \(p\) expansion coefficients, while the trunk learns \(p\) basis functions over the output domain, evaluated at the query location \(y\). Crucially, the branch never sees \(y\) and the trunk never sees \(a\).
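The combination step is nothing more than an inner product over the latent dimension \(p\). A minimal shape sketch with random stand-in vectors (no trained networks involved):

```julia
p = 64
b = randn(Float32, p)        # branch output for one input function a
τ = randn(Float32, p)        # trunk output for one query location y
b0 = 0.0f0                   # trainable scalar bias

Gθ = sum(b .* τ) + b0        # scalar prediction G_θ(a)(y)

# Batched over q query points: trunk outputs form a p×q matrix,
# and the same branch vector is reused for every query.
q = 10
T = randn(Float32, p, q)
G_batch = vec(b' * T) .+ b0  # length-q vector of predictions
size(G_batch)                # (10,)
```

This reuse of one branch evaluation across many query points is why point-wise evaluation is cheap.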

19.4 Variants

  • Unstacked DeepONet — single branch, single trunk (described above). The most common variant.
  • Stacked DeepONet — \(p\) independent branch networks, each producing a single coefficient \(b_k\), sharing one trunk network; the original formulation in Lu, Jin, et al. (2021), which is more expensive to train.
  • POD-DeepONet — the trunk basis functions are replaced by Proper Orthogonal Decomposition (POD) modes computed from the training data, making the trunk data-driven rather than learned.
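For POD-DeepONet, the fixed trunk basis can be obtained from an SVD of the training outputs. A minimal sketch, assuming all training outputs live on a common query grid and using random stand-in data:

```julia
using LinearAlgebra, Statistics

# Snapshot matrix: each column is one training output u_i sampled
# on the same n_query grid (random stand-in data here).
n_query, n_samples = 100, 500
U_snap = randn(Float32, n_query, n_samples)
U_snap .-= mean(U_snap, dims = 2)   # center the snapshots (a common POD convention)

F = svd(U_snap)
p = 32
pod_basis = F.U[:, 1:p]             # the p dominant POD modes on the grid

# Only the branch is trained; the prediction on the grid is then
#   u ≈ pod_basis * b(a)   for branch coefficients b(a) ∈ R^p.
size(pod_basis)                     # (100, 32)
```

Because the modes are orthonormal, the branch coefficients of a training sample are simply `pod_basis' * u`, which gives the branch a well-posed regression target.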

19.5 Code example: learning the antiderivative operator

We train a DeepONet to learn the antiderivative operator:

\[ \mathcal{G}(a)(y) = \int_0^y a(s)\, ds \]

Given a function \(a(x)\), the operator returns its antiderivative evaluated at any query point \(y\).

using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie

rng = Xoshiro(42)
# Data generation: random functions and their antiderivatives
n_sensors = 50        # fixed sensor locations for branch input
n_query = 50          # query points per sample
n_train = 2000
n_test = 200

x_sensors = Float32.(range(0, 1, length = n_sensors))
y_query = Float32.(range(0, 1, length = n_query))
dx_sensor = x_sensors[2] - x_sensors[1]

function random_function(rng)
    # Random sum of sinusoids
    n_modes = rand(rng, 2:5)
    coeffs = randn(rng, Float32, n_modes)
    freqs = Float32.(rand(rng, 1:6, n_modes))
    phases = 2π .* rand(rng, Float32, n_modes)
    function f(x)
        val = 0.0f0
        for j in 1:n_modes
            val += coeffs[j] * sin(2π * freqs[j] * x + phases[j])
        end
        return val
    end
    return f
end

function generate_data(rng, n_samples)
    A = zeros(Float32, n_sensors, n_samples)        # branch inputs
    Y = zeros(Float32, 1, n_query * n_samples)       # query locations
    G = zeros(Float32, 1, n_query * n_samples)       # true outputs

    for i in 1:n_samples
        f = random_function(rng)
        # Evaluate at sensors
        a_vals = f.(x_sensors)
        A[:, i] = a_vals

        # Compute antiderivative at query points via trapezoidal rule
        for (j, yj) in enumerate(y_query)
            # Integrate from 0 to yj
            n_int = max(2, round(Int, yj / dx_sensor) + 1)
            x_int = Float32.(range(0, yj, length = n_int))
            f_int = f.(x_int)
            integral = sum(0.5f0 * (f_int[1:end-1] .+ f_int[2:end]) .* diff(x_int))

            idx = (i - 1) * n_query + j
            Y[1, idx] = yj
            G[1, idx] = integral
        end
    end
    return A, Y, G
end

A_train, Y_train, G_train = generate_data(rng, n_train)
A_test, Y_test, G_test = generate_data(Xoshiro(123), n_test)
# Build DeepONet: Branch + Trunk
p = 64   # output dimension of both branch and trunk

branch_net = Chain(
    Dense(n_sensors => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => p)
)

trunk_net = Chain(
    Dense(1 => 128, relu),
    Dense(128 => 128, relu),
    Dense(128 => p)
)

ps_branch, st_branch = Lux.setup(rng, branch_net)
ps_trunk, st_trunk = Lux.setup(rng, trunk_net)
bias = zeros(Float32, 1)

ps_all = (branch = ps_branch, trunk = ps_trunk, bias = bias)
st_all = (branch = st_branch, trunk = st_trunk)
# Forward pass: DeepONet output
function deeponet_forward(A, Y, ps, st, n_samples, n_q)
    # Branch: process all input functions
    b_out, st_b = branch_net(A, ps.branch, st.branch)  # (p, n_samples)

    # Trunk: process all query locations
    t_out, st_t = trunk_net(Y, ps.trunk, st.trunk)      # (p, n_samples * n_q)

    # Pair each query's trunk output with the branch coefficients of the
    # sample it belongs to: `repeat` with `inner` repeats every sample's
    # column n_q times consecutively, matching the layout of Y. Unlike
    # writing into a preallocated array, this is Zygote-differentiable.
    b_tiled = repeat(b_out, inner = (1, n_q))   # (p, n_samples * n_q)

    # Dot product + bias
    output = sum(b_tiled .* t_out, dims = 1) .+ ps.bias  # (1, n_samples * n_q)

    st_new = (branch = st_b, trunk = st_t)
    return output, st_new
end

# Loss function
function deeponet_loss(ps, st)
    pred, st_new = deeponet_forward(A_train, Y_train, ps, st, n_train, n_query)
    loss = mean((pred .- G_train) .^ 2)
    return loss, st_new
end
# Training loop
opt = Adam(0.001f0)
opt_state = Optimisers.setup(opt, ps_all)

for epoch in 1:1000
    (loss_val, st_new), grads = Zygote.withgradient(ps_all) do ps
        deeponet_loss(ps, st_all)
    end
    # `global` so the assignments update the top-level bindings when this
    # file is run as a script (Julia's soft-scope rule for top-level loops)
    global opt_state, ps_all = Optimisers.update(opt_state, ps_all, grads[1])
    global st_all = st_new

    if epoch == 1 || epoch % 200 == 0
        # Test error
        pred_test, _ = deeponet_forward(A_test, Y_test, ps_all, st_all, n_test, n_query)
        test_mse = mean((pred_test .- G_test) .^ 2)
        @printf "Epoch %4d  Train MSE = %.6f  Test MSE = %.6f\n" epoch loss_val test_mse
    end
end
# Visualize predictions on test samples
fig = Figure(size = (700, 500))
n_show = 4

for i in 1:n_show
    f_test = A_test[:, i]
    idx_range = (i-1)*n_query+1 : i*n_query

    true_vals = vec(G_test[:, idx_range])
    pred_vals_all, _ = deeponet_forward(
        reshape(f_test, :, 1),
        reshape(Y_test[:, idx_range], 1, :),
        ps_all, st_all, 1, n_query
    )
    pred_vals = vec(pred_vals_all)

    row = (i - 1) ÷ 2 + 1
    col = (i - 1) % 2 + 1
    ax = Axis(fig[row, col], xlabel = "y", ylabel = "∫₀ʸ a(s) ds",
              title = "Test sample $i")
    lines!(ax, y_query, true_vals, color = :black, linewidth = 2, label = "Exact")
    lines!(ax, y_query, pred_vals, color = :steelblue, linewidth = 2,
           linestyle = :dash, label = "DeepONet")
    if i == 1
        axislegend(ax, position = :lt, labelsize = 10)
    end
end

Label(fig[0, :], "DeepONet: learning the antiderivative operator", fontsize = 14)
fig
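Plain MSE mixes test samples of different magnitudes; a per-sample relative \(L^2\) error is a common metric in the operator-learning literature. A self-contained sketch on stand-in arrays (the helper `rel_l2` is not from a library):

```julia
using LinearAlgebra, Statistics

# Mean relative L2 error across test samples.
# `pred` and `truth` hold each sample's n_query values in a column.
function rel_l2(pred::AbstractMatrix, truth::AbstractMatrix)
    errs = [norm(pred[:, i] - truth[:, i]) / norm(truth[:, i])
            for i in axes(truth, 2)]
    return mean(errs)
end

truth = randn(Float32, 50, 8)
pred = truth .+ 0.01f0 .* randn(Float32, 50, 8)
rel_l2(pred, truth)    # roughly the 1% noise level
```

To use this with the example above, reshape `G_test` and the predictions into `(n_query, n_test)` matrices, one test sample per column.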

19.6 Key properties of DeepONet

19.6.1 Strengths

  • Mathematical foundation — grounded in the universal approximation theorem for operators, providing theoretical guarantees on expressiveness.
  • Flexible geometry — unlike FNO (which requires regular grids for the FFT), DeepONet works with arbitrary input sensor locations and query points. This is crucial for geoscience data, which is rarely on regular grids.
  • Modular design — the branch and trunk are independent networks that can be customized separately. For example, the branch could be a CNN if the input is an image, or an RNN if the input is a time series.
  • Point-wise evaluation — the trunk network produces output at individual query points, making it efficient for problems where you only need the solution at specific locations.

19.6.2 Limitations

  • Fixed sensor locations — the branch network requires the input function to be evaluated at the same fixed sensor locations for all samples. This can be restrictive if data comes from different measurement configurations.
  • Data hungry — DeepONet typically requires more training data than FNO for the same accuracy on problems with regular grids, because it doesn’t exploit spatial structure the way spectral methods do.
  • Scaling — for high-dimensional output functions, the number of trunk basis functions \(p\) may need to be large, increasing the computational cost.

19.7 DeepONet vs FNO

Feature                 DeepONet                           FNO
───────────────────────────────────────────────────────────────────────────────
Architecture            Branch + trunk                     Spectral convolution layers
Grid requirements       Arbitrary (irregular OK)           Regular grid (for FFT)
Theoretical basis       Universal approximation theorem    Kernel integral operators
Resolution invariance   Through trunk network              Through Fourier representation
Best for                Irregular data, point queries      Regular-grid data, global patterns
Data efficiency         Moderate                           Higher (exploits spatial structure)

19.8 Geoscience applications

DeepONet is particularly well-suited for geoscience problems because field data is rarely on regular grids:

  • Subsurface flow surrogates — DeepONet can learn the mapping from permeability fields to pressure/saturation solutions, enabling real-time uncertainty quantification for reservoir management.
  • Seismic wave modeling — given a velocity model as input, DeepONet predicts the wavefield at arbitrary receiver locations, replacing expensive wave-equation solves during inversion.
  • Well log prediction — the branch network ingests spatially irregular borehole measurements while the trunk evaluates predictions at arbitrary subsurface locations.
  • Climate downscaling — DeepONet maps coarse-resolution climate model output to fine-resolution fields, handling the irregular observation network naturally.
Note — Next: Physics-Informed DeepONet

A powerful extension of DeepONet incorporates PDE constraints into the training loss, dramatically reducing the amount of training data needed. We explore this in the next chapter.