21 Fourier Neural Operator

using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie, FFTW
rng = Xoshiro(42)
- Fourier Neural Operator (FNO) — learning solution operators in Fourier space, enabling resolution-invariant surrogate models for PDEs (Li et al., 2021).
- Neural operator survey — a comprehensive mathematical treatment of neural operators and their approximation properties (Kovachki et al., 2023).
- FourCastNet — global weather prediction using adaptive Fourier neural operators (Pathak et al., 2022).
- U-FNO — enhanced FNO with U-Net-style skip connections for multiphase flow in subsurface reservoirs (Wen et al., 2022).
- FNO on general geometries — extending FNO to handle irregular domains via learned deformations (Li et al., 2023).
- Neural operators for science — review of neural operators accelerating scientific simulations (Azizzadenesheli et al., 2024).
DeepONet learned operators through a branch-trunk decomposition. The Fourier Neural Operator (FNO) (Li et al., 2021) takes a fundamentally different approach: it learns the operator directly in Fourier space, where global spatial patterns are captured efficiently and the learned operator is inherently resolution-invariant.
FNO has become arguably the most successful neural operator architecture, powering applications from weather forecasting to fluid dynamics to seismic inversion.
21.1 From convolution to spectral convolution
A standard convolution applies a local kernel:
\[ (K * v)(x) = \int K(x - y)\, v(y)\, dy \]
This has a local receptive field — each output point depends only on nearby inputs. For PDEs, where information propagates globally (e.g., waves, diffusion), stacking many convolutional layers is needed to capture long-range interactions.
By the convolution theorem, convolution in physical space is multiplication in Fourier space:
\[ \mathcal{F}[K * v] = \mathcal{F}[K] \cdot \mathcal{F}[v] \]
The FNO leverages this: instead of learning a spatial kernel, it learns a spectral filter — a weight tensor that multiplies the Fourier coefficients directly. This gives each layer a global receptive field while keeping the computation efficient through the FFT.
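This Fourier-space view is easy to verify numerically. The sketch below checks the convolution theorem on a periodic grid with FFTW; the helper names `circular_conv` and `spectral_conv` are illustrative and not part of the FNO code later in this chapter:

```julia
using FFTW

# Circular (periodic) convolution by direct summation: (k * v)[i] = Σⱼ k[i-j] v[j]
function circular_conv(k::Vector{Float64}, v::Vector{Float64})
    n = length(v)
    return [sum(k[mod1(i - j + 1, n)] * v[j] for j in 1:n) for i in 1:n]
end

# The same convolution via the convolution theorem: multiply in Fourier space
spectral_conv(k, v) = real.(ifft(fft(k) .* fft(v)))

n = 64
k = exp.(-range(0, 5, length = n))      # an arbitrary kernel
v = sin.(2π .* (0:(n - 1)) ./ n)        # an arbitrary periodic signal

maximum(abs.(circular_conv(k, v) .- spectral_conv(k, v)))   # ≈ 0 up to round-off
```

The direct sum costs \(O(n^2)\); the FFT route costs \(O(n \log n)\), which is exactly the efficiency the FNO exploits.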
21.2 FNO architecture
An FNO consists of three stages:
21.2.1 1. Lifting layer
A point-wise linear layer that maps the input from its native dimension to a higher-dimensional representation:
\[ v^{(0)}(x) = P\, a(x) + \mathbf{q} \]
where \(P \in \mathbb{R}^{d_v \times d_a}\) and \(a(x)\) is the input function at location \(x\).
21.2.2 2. Fourier layers (repeated \(L\) times)
Each Fourier layer applies:
\[ v^{(l+1)}(x) = \sigma\!\Bigl(\underbrace{W^{(l)}\, v^{(l)}(x)}_{\text{local linear}} + \underbrace{\mathcal{F}^{-1}\!\bigl[R^{(l)} \cdot \mathcal{F}[v^{(l)}]\bigr](x)}_{\text{spectral convolution}}\Bigr) \]
The spectral convolution keeps only the lowest \(k_{\max}\) Fourier modes and sets the rest to zero. The learned weight tensor \(R^{(l)} \in \mathbb{C}^{d_v \times d_v \times k_{\max}}\) acts as a frequency-dependent linear transformation. Here the superscript \((l)\) plays the same role as elsewhere in the book: it indexes layer depth.
Input v^(l)(x)
      │
      ├──────────────────────────────────┐
      │                                  │
      ▼                                  ▼
     FFT                        Local linear W^(l)
      │                                  │
      ▼                                  │
Truncate to k_max modes                  │
      │                                  │
      ▼                                  │
Multiply by R^(l)                        │
      │                                  │
      ▼                                  │
 Inverse FFT                             │
      │                                  │
      └──────────┬───────────────────────┘
                 │ add
                 ▼
           Activation σ
                 │
                 ▼
          Output v^(l+1)(x)
21.2.3 3. Projection layer
A point-wise linear layer that maps back to the output dimension:
\[ u(x) = Q\, v^{(L)}(x) + r \]
21.2.4 Why this works
- Global receptive field — each Fourier layer sees the entire spatial domain, not just a local neighborhood.
- Resolution invariance — the spectral representation is independent of the grid resolution. A model trained on a \(64 \times 64\) grid can be evaluated on a \(256 \times 256\) grid by zero-padding the Fourier modes.
- Efficiency — the FFT costs \(O(n \log n)\), making the spectral convolution as fast as spatial convolution for typical grid sizes.
- Mode truncation as regularization — keeping only \(k_{\max}\) modes implicitly smooths the learned operator, preventing overfitting to high-frequency noise.
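The last point can be seen directly: keeping the lowest \(k_{\max}\) modes is a sharp low-pass filter that removes high-frequency noise while preserving smooth structure. A minimal sketch (the helper `truncate_modes` is illustrative; it keeps rfft indices 1 through \(k_{\max}\), the same convention the spectral convolution layer uses later in this chapter):

```julia
using FFTW, Random

# Keep only the lowest k_max Fourier modes of a real signal (sharp low-pass)
function truncate_modes(v::Vector{Float64}, k_max::Int)
    v_ft = rfft(v)
    v_ft[(k_max + 1):end] .= 0    # zero all higher frequencies
    return irfft(v_ft, length(v))
end

rng_demo = Xoshiro(0)
n = 128
x = (0:(n - 1)) ./ n
clean = sin.(2π .* x) .+ 0.5 .* cos.(6π .* x)   # frequencies 1 and 3 only
noisy = clean .+ 0.3 .* randn(rng_demo, n)

smoothed = truncate_modes(noisy, 8)
# `smoothed` is much closer to `clean` than `noisy` is: truncation keeps the
# low-frequency signal and discards most of the broadband noise
```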
21.3 Code example: FNO for the 1D advection equation
We train a simple FNO to learn the solution operator for the 1D advection equation:
\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0, \quad u(x, 0) = u_0(x) \]
The exact solution is a rightward shift: \(u(x, t) = u_0(x - ct)\). The operator maps \(u_0 \mapsto u(\cdot, T)\).
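In Fourier space this operator is particularly simple: by the shift theorem, advection multiplies mode \(k\) by \(e^{-2\pi i k cT}\) — precisely the kind of frequency-wise linear map the FNO's \(R\) tensor can represent. A quick numerical check (illustrative, separate from the training code; `_demo` names avoid clashing with the chapter's variables):

```julia
using FFTW

nx_demo = 64
cT = 0.5                                   # c * T for c = 1, T = 0.5
x_demo = (0:(nx_demo - 1)) ./ nx_demo
u0_demo = sin.(2π .* x_demo) .+ 0.25 .* cos.(6π .* x_demo)

# Exact advected profile u(x, T) = u0(x - cT)
uT_exact = sin.(2π .* (x_demo .- cT)) .+ 0.25 .* cos.(6π .* (x_demo .- cT))

# Shift theorem: multiply Fourier mode k by exp(-2πi k cT)
u0_ft = rfft(u0_demo)
ks = 0:(length(u0_ft) - 1)
uT_spectral = irfft(u0_ft .* exp.(-2π * im .* ks .* cT), nx_demo)

maximum(abs.(uT_spectral .- uT_exact))   # ≈ 0 up to round-off
```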
# Generate training data: advection with periodic BCs
nx = 64
c = 1.0f0
T_final = 0.5f0
dx = 1.0f0 / nx
x_grid = Float32.(range(0, 1 - dx, length = nx))

function generate_pair(rng)
    # Random initial condition: sum of sinusoids
    n_modes = rand(rng, 2:5)
    u0 = zeros(Float32, nx)
    for _ in 1:n_modes
        k = rand(rng, 1:6)
        a = 0.5f0 * randn(rng, Float32)
        ϕ = 2π * rand(rng, Float32)
        u0 .+= a .* sin.(2π .* k .* x_grid .+ ϕ)
    end
    # Exact solution at t = T: rightward shift by c*T (periodic),
    # so uT[i] = u0[i - shift]
    shift = round(Int, c * T_final / dx)
    uT = circshift(u0, shift)
    return u0, uT
end

n_train = 500
n_test = 100
U0_train = zeros(Float32, nx, n_train)
UT_train = zeros(Float32, nx, n_train)
for i in 1:n_train
    u0, uT = generate_pair(rng)
    U0_train[:, i] = u0
    UT_train[:, i] = uT
end
U0_test = zeros(Float32, nx, n_test)
UT_test = zeros(Float32, nx, n_test)
for i in 1:n_test
    u0, uT = generate_pair(Xoshiro(9000 + i))
    U0_test[:, i] = u0
    UT_test[:, i] = uT
end

# Spectral convolution layer (custom Lux layer)
struct SpectralConv <: Lux.AbstractLuxLayer
    in_channels::Int
    out_channels::Int
    modes::Int
end

function Lux.initialparameters(rng::AbstractRNG, l::SpectralConv)
    scale = 1.0f0 / (l.in_channels * l.out_channels)
    R_real = scale .* randn(rng, Float32, l.out_channels, l.in_channels, l.modes)
    R_imag = scale .* randn(rng, Float32, l.out_channels, l.in_channels, l.modes)
    return (R_real = R_real, R_imag = R_imag)
end

Lux.initialstates(::AbstractRNG, ::SpectralConv) = NamedTuple()

function (l::SpectralConv)(x, ps, st)
    # x: (nx, channels, batch)
    nx_in = size(x, 1)
    batch = size(x, 3)
    # FFT along spatial dimension
    x_ft = rfft(x, 1)                            # (nx÷2+1, channels, batch)
    # Learned weights for the first `modes` frequencies
    R = complex.(ps.R_real, ps.R_imag)           # (out_ch, in_ch, modes)
    # Apply the frequency-wise linear maps without in-place mutation,
    # which Zygote cannot differentiate through
    out_low = reduce(vcat,
        [reshape(R[:, :, k] * x_ft[k, :, :], 1, l.out_channels, batch)
         for k in 1:l.modes])                    # (modes, out_ch, batch)
    # Zero the remaining (truncated) frequencies
    out_ft = vcat(out_low,
        zeros(ComplexF32, size(x_ft, 1) - l.modes, l.out_channels, batch))
    # Inverse FFT
    out = irfft(out_ft, nx_in, 1)
    return out, st
end

# Build FNO: lift → spectral layers → project
modes = 12
width = 16
lift = Dense(1 => width)
spectral1 = SpectralConv(width, width, modes)
skip1 = Dense(width => width)
spectral2 = SpectralConv(width, width, modes)
skip2 = Dense(width => width)
spectral3 = SpectralConv(width, width, modes)
skip3 = Dense(width => width)
project = Dense(width => 1)
ps_lift, st_lift = Lux.setup(rng, lift)
ps_s1, st_s1 = Lux.setup(rng, spectral1)
ps_sk1, st_sk1 = Lux.setup(rng, skip1)
ps_s2, st_s2 = Lux.setup(rng, spectral2)
ps_sk2, st_sk2 = Lux.setup(rng, skip2)
ps_s3, st_s3 = Lux.setup(rng, spectral3)
ps_sk3, st_sk3 = Lux.setup(rng, skip3)
ps_proj, st_proj = Lux.setup(rng, project)
ps_all = (lift = ps_lift,
          s1 = ps_s1, sk1 = ps_sk1,
          s2 = ps_s2, sk2 = ps_sk2,
          s3 = ps_s3, sk3 = ps_sk3,
          proj = ps_proj)
st_all = (lift = st_lift,
          s1 = st_s1, sk1 = st_sk1,
          s2 = st_s2, sk2 = st_sk2,
          s3 = st_s3, sk3 = st_sk3,
          proj = st_proj)
# Point-wise layers act on channels: (nx, ch, batch) ⇄ (ch, nx*batch).
# A plain reshape would scramble the spatial and channel axes, so permute first.
to_channels(v) = reshape(permutedims(v, (2, 1, 3)), size(v, 2), :)
from_channels(v, nx, batch) = permutedims(reshape(v, size(v, 1), nx, batch), (2, 1, 3))

function fno_forward(x, ps, st)
    # x: (nx, 1, batch)
    nx_in, _, batch = size(x)
    # Lift to higher dimension
    v, st_l = lift(to_channels(x), ps.lift, st.lift)
    v = from_channels(v, nx_in, batch)
    # Fourier layer 1
    v_spec, st_s1 = spectral1(v, ps.s1, st.s1)
    v_local, st_sk1 = skip1(to_channels(v), ps.sk1, st.sk1)
    v = relu.(v_spec .+ from_channels(v_local, nx_in, batch))
    # Fourier layer 2
    v_spec2, st_s2 = spectral2(v, ps.s2, st.s2)
    v_local2, st_sk2 = skip2(to_channels(v), ps.sk2, st.sk2)
    v = relu.(v_spec2 .+ from_channels(v_local2, nx_in, batch))
    # Fourier layer 3
    v_spec3, st_s3 = spectral3(v, ps.s3, st.s3)
    v_local3, st_sk3 = skip3(to_channels(v), ps.sk3, st.sk3)
    v = relu.(v_spec3 .+ from_channels(v_local3, nx_in, batch))
    # Project back to one output channel
    out, st_p = project(to_channels(v), ps.proj, st.proj)
    out = reshape(out, nx_in, 1, batch)
    st_new = (lift = st_l,
              s1 = st_s1, sk1 = st_sk1,
              s2 = st_s2, sk2 = st_sk2,
              s3 = st_s3, sk3 = st_sk3,
              proj = st_p)
    return out, st_new
end

# Training loop
opt = Adam(0.001f0)
opt_state = Optimisers.setup(opt, ps_all)
X_train = reshape(U0_train, nx, 1, n_train)
Y_train = reshape(UT_train, nx, 1, n_train)
X_test = reshape(U0_test, nx, 1, n_test)
Y_test = reshape(UT_test, nx, 1, n_test)
for epoch in 1:500
    (loss, st_new), grads = Zygote.withgradient(ps_all) do ps
        pred, st_ = fno_forward(X_train, ps, st_all)
        l = mean((pred .- Y_train) .^ 2)
        (l, st_)
    end
    opt_state, ps_all = Optimisers.update(opt_state, ps_all, grads[1])
    st_all = st_new
    if epoch == 1 || epoch % 100 == 0
        pred_test, _ = fno_forward(X_test, ps_all, st_all)
        test_mse = mean((pred_test .- Y_test) .^ 2)
        @printf "Epoch %3d  Train MSE = %.6f  Test MSE = %.6f\n" epoch loss test_mse
    end
end

# Test on new initial conditions
n_show = 4
fig = Figure(size = (700, 500))
for i in 1:n_show
    u0_test, uT_test = generate_pair(Xoshiro(5000 + i))
    x_in = reshape(u0_test, nx, 1, 1)
    pred, _ = fno_forward(x_in, ps_all, st_all)
    row = (i - 1) ÷ 2 + 1
    col = (i - 1) % 2 + 1
    ax = Axis(fig[row, col], title = "Test $i",
              xlabel = "x", ylabel = "u")
    lines!(ax, x_grid, u0_test, color = (:gray, 0.5), label = "u₀")
    lines!(ax, x_grid, uT_test, color = :black, linewidth = 2, label = "Exact u(T)")
    lines!(ax, x_grid, vec(pred), color = :steelblue, linewidth = 2,
           linestyle = :dash, label = "FNO prediction")
    if i == 1
        axislegend(ax, position = :rt, labelsize = 10)
    end
end
Label(fig[0, :], "FNO: advection equation — generalizing to new initial conditions",
      fontsize = 14)
fig

21.4 Resolution invariance
One of FNO’s most remarkable properties is resolution invariance (also called zero-shot super-resolution). Because the learned operator acts on Fourier modes rather than grid points, a model trained at one resolution can be evaluated at a different resolution:
- Train on a \(64\)-point grid.
- Evaluate on a \(256\)-point grid: the learned weights still act on the lowest \(k_{\max}\) Fourier modes, while the finer grid's additional high-frequency modes are simply left at zero — no retraining or interpolation needed.
This works because the Fourier coefficients have the same physical meaning regardless of discretization — they represent the same spatial frequencies. In practice, FNO maintains good accuracy when evaluating at \(2\)–\(4\times\) the training resolution, with degradation beyond that.
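Zero-padding in Fourier space is classical spectral interpolation, which this sketch demonstrates without any learned weights (the helper name `spectral_upsample` is illustrative): a band-limited signal sampled at 64 points is recovered exactly at 256 points.

```julia
using FFTW

# Upsample a real periodic signal by zero-padding its rfft coefficients
function spectral_upsample(v::Vector{Float64}, n_new::Int)
    n = length(v)
    v_ft = rfft(v)                                    # n÷2+1 coefficients
    padded = vcat(v_ft, zeros(ComplexF64, n_new ÷ 2 + 1 - length(v_ft)))
    # Rescale: the unnormalized inverse FFT divides by the (larger) output length
    return irfft(padded, n_new) .* (n_new / n)
end

n_coarse, n_fine = 64, 256
x_c = (0:(n_coarse - 1)) ./ n_coarse
x_f = (0:(n_fine - 1)) ./ n_fine
f(x) = sin(2π * x) + 0.3 * cos(8π * x)    # band-limited: modes 1 and 4 only

u_fine = spectral_upsample(f.(x_c), n_fine)
maximum(abs.(u_fine .- f.(x_f)))   # ≈ 0: exact, since f is band-limited
```

Real FNO inputs are not perfectly band-limited, which is why accuracy gradually degrades far beyond the training resolution.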
21.5 FNO variants
The original FNO has inspired many extensions:
| Variant | Key idea | Use case |
|---|---|---|
| FNO-2D/3D | Multi-dimensional spectral convolution | 2D/3D PDEs |
| U-FNO (Wen et al., 2022) | U-Net-style skip connections between Fourier layers | Multiphase subsurface flow |
| Geo-FNO (Li et al., 2023) | Learned input deformation for irregular domains | Geophysics, complex geometries |
| AFNO (Adaptive FNO) (Pathak et al., 2022) | Token mixing via adaptive Fourier layers | Global weather forecasting |
| Factorized FNO | Factorize multi-dimensional spectral weights | Reduced memory for 3D problems |
21.6 Comparison with other neural operators
| Feature | FNO | DeepONet | PI-DeepONet |
|---|---|---|---|
| Core mechanism | Spectral convolution | Branch-trunk decomposition | Branch-trunk + PDE loss |
| Grid requirement | Regular grid | Arbitrary | Arbitrary |
| Resolution invariance | Native (via Fourier) | Via trunk re-evaluation | Via trunk re-evaluation |
| Training data | Input-output pairs | Input-output pairs | Physics (+ optional data) |
| Global receptive field | Per layer | Per forward pass | Per forward pass |
| Best suited for | Regular-grid PDEs | Irregular sensors | Limited/no solver data |
21.7 Geoscience applications
FNO is transforming geoscience workflows where repeated forward modeling is the computational bottleneck:
- Weather forecasting — FourCastNet (Pathak et al., 2022) used adaptive Fourier neural operators for global weather prediction, producing 10-day forecasts in seconds. Pangu-Weather (Bi et al., 2023) achieved competitive medium-range forecasting with 3D neural architectures.
- Seismic inversion surrogates — FNO-based surrogates replace expensive wave-equation solves during seismic full-waveform inversion, achieving orders-of-magnitude speedup (Yin et al., 2023).
- Subsurface multiphase flow — U-FNO (Wen et al., 2022) was applied to CO₂ storage simulations, learning the mapping from injection scenarios to pressure and saturation fields for real-time reservoir management.
- Seismic wave simulation — FNO learns the mapping from velocity models to wavefields, enabling rapid scenario testing for seismic hazard assessment (Song & Alkhalifah, 2023).
- General geometries — Geo-FNO (Li et al., 2023) handles the irregular domains common in Earth science (topography, coastlines, geological boundaries) through learned coordinate deformations.
FNO excels when your data lives on regular grids and you need fast, resolution-invariant operator evaluation. If your data is on irregular sensors or you lack training data, consider DeepONet or PI-DeepONet instead.