8  Convolutional Neural Networks

Tip: Key references
  • LeNet — the first successful convolutional neural network for image recognition (LeCun et al., 1989).
  • AlexNet — deep CNN that won the ImageNet competition and launched the deep-learning era (Krizhevsky et al., 2012).
  • U-Net — encoder-decoder architecture with skip connections for dense prediction (Ronneberger et al., 2015).
  • ResNet — residual connections enabling very deep networks (100+ layers) (He et al., 2016).

A convolutional neural network (CNN) exploits the spatial structure in data — the fact that nearby pixels or grid cells tend to be related. Instead of connecting every input to every neuron, a CNN slides small learned filters across the data, detecting local patterns such as edges, textures, and shapes. This makes CNNs far more parameter-efficient than feedforward networks for image-like data.

8.1 The convolution operation

In a CNN, a filter (or kernel) is a small weight matrix, typically \(3 \times 3\) or \(5 \times 5\). The filter slides across the input and at each position computes a dot product between the filter weights and the local patch of input values. This produces a feature map — a new grid where each cell represents how strongly that local pattern was detected at that position.

For a 2D input \(\mathbf{X}\) and a filter \(\mathbf{K}\) of size \(k \times k\), the convolution at position \((i, j)\) is:

\[ (\mathbf{X} * \mathbf{K})_{i,j} = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} \mathbf{K}_{m,n} \cdot \mathbf{X}_{i+m,\, j+n} \]

A convolutional layer applies many such filters in parallel, each learning to detect a different pattern.
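To make the formula concrete, here is a minimal sketch of a "valid" convolution (no padding, stride 1) on a tiny input, implemented directly from the sum above; the names `conv2d_valid`, `X`, and `K` are illustrative, not from any library:

```julia
# Minimal sketch of the convolution formula above ("valid" convolution:
# no padding, stride 1). A k×k filter on an n×n input yields an
# (n−k+1)×(n−k+1) feature map.
X = Float32[1 2 3 4;
            5 6 7 8;
            9 10 11 12;
            13 14 15 16]
K = Float32[1 0;
            0 -1]   # responds to the difference along the diagonal

function conv2d_valid(X, K)
    k = size(K, 1)
    out = zeros(Float32, size(X, 1) - k + 1, size(X, 2) - k + 1)
    for i in axes(out, 1), j in axes(out, 2)
        # dot product of the filter with the local patch
        out[i, j] = sum(K .* X[i:i+k-1, j:j+k-1])
    end
    return out
end

conv2d_valid(X, K)   # → 3×3 matrix, every entry −5.0 (constant diagonal gradient)
```

Note that, following the machine-learning convention used in the equation above, the filter is not flipped; strictly speaking this is cross-correlation, but the distinction does not matter when the filter weights are learned.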

8.2 Pooling

After convolution, pooling layers reduce the spatial size of the feature maps, keeping only the most important information. The most common type is max pooling, which takes the maximum value in each small window (e.g., \(2 \times 2\)). Pooling reduces computation, provides some translation invariance, and increases the receptive field of deeper layers.
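A minimal sketch of \(2 \times 2\) max pooling with stride 2, halving each spatial dimension (the helper name `maxpool2x2` is illustrative):

```julia
# 2×2 max pooling with stride 2: each non-overlapping 2×2 window is
# replaced by its maximum, halving both spatial dimensions.
A = Float32[1 3 2 1;
            4 2 0 5;
            6 1 2 2;
            1 8 3 4]

function maxpool2x2(A)
    out = zeros(Float32, size(A, 1) ÷ 2, size(A, 2) ÷ 2)
    for i in axes(out, 1), j in axes(out, 2)
        out[i, j] = maximum(A[2i-1:2i, 2j-1:2j])
    end
    return out
end

maxpool2x2(A)   # → [4 5; 8 4]
```

Shifting the pattern in `A` by one pixel often leaves the pooled output unchanged, which is the source of the partial translation invariance mentioned above.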

8.3 A typical CNN architecture

A CNN usually alternates convolutional and pooling layers, progressively reducing spatial resolution while increasing the number of feature channels:

  1. Input — e.g., a \(28 \times 28\) single-channel image.
  2. Conv → ReLU → Pool — repeated 2–3 times.
  3. Flatten — reshape the 2D feature maps into a 1D vector.
  4. Dense layers — one or two fully connected layers for the final prediction.
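The spatial size at each stage follows a standard formula: for input size \(n\), kernel size \(k\), padding \(p\), and stride \(s\), the output size is \(\lfloor (n + 2p - k)/s \rfloor + 1\). A small sketch tracing the \(28 \times 28\) example through the pipeline (the helper `out_size` is my own, not part of any library):

```julia
# Spatial output size of a conv/pool layer: ⌊(n + 2p − k) / s⌋ + 1
out_size(n, k; pad = 0, stride = 1) = (n + 2pad - k) ÷ stride + 1

n = 28
n = out_size(n, 3; pad = 1)       # 3×3 conv, "same" padding → 28
n = out_size(n, 2; stride = 2)    # 2×2 max pool            → 14
n = out_size(n, 3; pad = 1)       # second conv             → 14
n = out_size(n, 2; stride = 2)    # second pool             → 7
```

After the second pool, flattening a \(7 \times 7\) map with \(c\) channels gives a vector of length \(49c\) for the dense layers.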

8.4 Code example: classifying simple seismic image patches

We use a standard 2D CNN to classify tiny synthetic seismic-style images into three classes: layered horizons, a faulted horizon pattern, and a dome-like structure. The three patterns are visually distinct, and the problem matches the usual CNN story: take an image as input, return one class label as output.

The problem formulation is simple: each input is one \(32 \times 32\) grayscale image patch, and the target is one of three structural classes. The network output is a vector of three class scores. After a softmax, those scores become class probabilities, and the largest probability gives the predicted class.

This is still a toy problem. Real seismic interpretation is not this clean. But for a first CNN example, it is useful because we can clearly see what the network is trying to separate and we can directly check whether the predictions match the visible pattern in the image.

using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie

rng = Xoshiro(42)

class_names = ["Layered", "Faulted", "Dome"]

function gaussian2d(x, y, μx, μy, σx, σy)
    exp.(-0.5f0 .* (((x .- μx) ./ σx) .^ 2 .+ ((y .- μy) ./ σy) .^ 2))
end

# Generate a small synthetic seismic-style image patch.
function make_seismic_patch(rng, n = 32)
    class_id = rand(rng, 1:3)
    x = Float32.(range(-1, 1, length = n))
    y = Float32.(range(-1, 1, length = n))
    xx = repeat(reshape(x, n, 1), 1, n)
    yy = repeat(reshape(y, 1, n), n, 1)

    image = zeros(Float32, n, n)

    if class_id == 1
        # Layered reflectors.
        image .= 0.50f0 .+ 0.22f0 .* sin.(7.0f0 .* Float32(pi) .* (yy .+ 0.04f0 .* sin.(2.0f0 .* Float32(pi) .* xx)))
        image .+= 0.015f0 .* randn(rng, Float32, n, n)
    elseif class_id == 2
        # Faulted reflectors with a visible offset.
        shifted_yy = yy .+ 0.22f0 .* (xx .> 0.08f0)
        image .= 0.50f0 .+ 0.22f0 .* sin.(7.0f0 .* Float32(pi) .* shifted_yy)
        image .-= 0.12f0 .* gaussian2d(xx, yy, 0.08f0, 0.0f0, 0.03f0, 0.85f0)
        image .+= 0.015f0 .* randn(rng, Float32, n, n)
    else
        # Dome-like reflector geometry.
        dome = yy .+ 0.55f0 .* exp.(-((xx ./ 0.42f0) .^ 2))
        image .= 0.50f0 .+ 0.22f0 .* sin.(7.0f0 .* Float32(pi) .* dome)
        image .+= 0.015f0 .* randn(rng, Float32, n, n)
    end

    image = Float32.(clamp.(image, 0.03f0, 0.98f0))
    patch = reshape(image, n, n, 1)
    y = zeros(Float32, 3)
    y[class_id] = 1.0f0
    return patch, y, class_id
end

# Create a labelled image dataset
n_samples = 900
patches = zeros(Float32, 32, 32, 1, n_samples)    # (height, width, channels, batch)
labels = zeros(Float32, 3, n_samples)

for i in 1:n_samples
    x, y, _ = make_seismic_patch(rng)
    patches[:, :, :, i] .= x
    labels[:, i] .= y
end

# Train/test split
idx = randperm(rng, n_samples)
n_train = Int(round(0.8 * n_samples))
tr = idx[1:n_train]
te = idx[n_train+1:end]

X_train, Y_train = patches[:, :, :, tr], labels[:, tr]
X_test,  Y_test  = patches[:, :, :, te], labels[:, te]

The input tensor has shape (height, width, channels, batch), and the label for each patch is a one-hot vector with three entries. So this is just standard three-class image classification with geoscience-flavored patterns.

# Build a small 2D CNN classifier
model = Chain(
    Conv((5, 5), 1 => 8, relu; pad = SamePad()),
    MaxPool((2, 2)),
    Conv((3, 3), 8 => 16, relu; pad = SamePad()),
    MaxPool((2, 2)),
    WrappedFunction(x -> reshape(x, :, size(x, 4))),
    Dense(16 * 8 * 8 => 24, relu),
    Dense(24 => 3)
)

ps, st = Lux.setup(rng, model)

function softmax_cols(x)
    x_shift = x .- maximum(x, dims = 1)
    ex = exp.(x_shift)
    ex ./ sum(ex, dims = 1)
end

function cross_entropy_loss(model, ps, st, data)
    x, y = data
    logits, st_new = model(x, ps, st)
    ŷ = softmax_cols(logits)
    ε = 1.0f-7
    loss = -mean(sum(y .* log.(ŷ .+ ε), dims = 1))
    return loss, st_new, ()
end

function predicted_classes(probabilities)
    [findmax(probabilities[:, i])[2] for i in axes(probabilities, 2)]
end

function true_classes(labels)
    [findmax(labels[:, i])[2] for i in axes(labels, 2)]
end
function train_model(model, ps, st, data; epochs = 240, lr = 0.0015f0)
    tstate = Training.TrainState(model, ps, st, Adam(lr))
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), cross_entropy_loss, data, tstate
        )
        if epoch == 1 || epoch % 60 == 0
            @printf "Epoch %3d  cross-entropy = %.4f\n" epoch loss
        end
    end
    return tstate
end

tstate = train_model(model, ps, st, (X_train, Y_train))

# Holdout evaluation
test_logits, _ = model(X_test, tstate.parameters, tstate.states)
test_prob = softmax_cols(test_logits)
test_loss = -mean(sum(Y_test .* log.(test_prob .+ 1.0f-7), dims = 1))
test_acc = mean(predicted_classes(test_prob) .== true_classes(Y_test))
@printf "Holdout cross-entropy: %.4f  Accuracy: %.3f\n" test_loss test_acc
Epoch   1  cross-entropy = 2.7561
Epoch  60  cross-entropy = 0.0423
Epoch 120  cross-entropy = 0.0014
Epoch 180  cross-entropy = 0.0003
Epoch 240  cross-entropy = 0.0001
Holdout cross-entropy: 0.0001  Accuracy: 1.000

The output above reports two things. The cross-entropy tells us how confident the network is on the correct class, while the accuracy tells us how often it predicts the right class on unseen patches. For a teaching example like this one, we want both numbers to show that the CNN has learned the visible image pattern instead of just memorizing the training set.

confusion = zeros(Int, 3, 3)
for (true_class, pred_class) in zip(true_classes(Y_test), predicted_classes(test_prob))
    confusion[true_class, pred_class] += 1
end

for i in 1:3
    class_acc = confusion[i, i] / sum(confusion[i, :])
    @printf "%s accuracy: %.3f\n" class_names[i] class_acc
end
Layered accuracy: 1.000
Faulted accuracy: 1.000
Dome accuracy: 1.000

Those per-class accuracies are useful because a single overall accuracy can hide one weak class. If one seismic pattern is consistently confused with another, it will show up here even when the mean score still looks good.

# Visualize one image patch from each class
function sample_with_label(seed, wanted_label)
    rng_local = Xoshiro(seed)
    while true
        x, y, class_id = make_seismic_patch(rng_local)
        class_id == wanted_label && return x, y, class_id
    end
end

patch_layered, _, _ = sample_with_label(99, 1)
patch_faulted, _, _ = sample_with_label(199, 2)
patch_dome, _, _ = sample_with_label(299, 3)

prob_layered, _ = model(reshape(patch_layered, 32, 32, 1, 1), tstate.parameters, tstate.states)
prob_faulted, _ = model(reshape(patch_faulted, 32, 32, 1, 1), tstate.parameters, tstate.states)
prob_dome, _ = model(reshape(patch_dome, 32, 32, 1, 1), tstate.parameters, tstate.states)

prob_layered = softmax_cols(prob_layered)
prob_faulted = softmax_cols(prob_faulted)
prob_dome = softmax_cols(prob_dome)

fig = Figure(size = (640, 760))

ax1 = Axis(fig[1, 1], title = "Layered patch  predicted: $(class_names[findmax(prob_layered[:, 1])[2]])")
image!(ax1, permutedims(dropdims(patch_layered, dims = 3), (2, 1)))
hidedecorations!(ax1)

ax2 = Axis(fig[2, 1], title = "Faulted patch  predicted: $(class_names[findmax(prob_faulted[:, 1])[2]])")
image!(ax2, permutedims(dropdims(patch_faulted, dims = 3), (2, 1)))
hidedecorations!(ax2)

ax3 = Axis(fig[3, 1], title = "Dome-like patch  predicted: $(class_names[findmax(prob_dome[:, 1])[2]])")
image!(ax3, permutedims(dropdims(patch_dome, dims = 3), (2, 1)))
hidedecorations!(ax3)

Colorbar(fig[1:3, 2], limits = (0, 1), colormap = :grays)
fig

In the plot, each panel is one input image and the title shows the predicted class. That gives a direct visual check that the network output matches the pattern a human reader would also identify.

8.5 Key CNN architectures

Several landmark architectures expanded the capabilities of CNNs:

  • ResNet (He et al., 2016) — introduces skip connections that add the input of a block to its output, enabling training of very deep networks (100+ layers) without vanishing gradients.
  • U-Net (Ronneberger et al., 2015) — an encoder-decoder architecture with skip connections at each resolution level, originally designed for biomedical image segmentation. Widely adopted in geoscience for dense prediction tasks.
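The residual idea is simple to express in Lux: a block computes \(y = x + f(x)\), and `SkipConnection` wires the identity path and the elementwise addition. The sketch below is illustrative only (a real ResNet block also includes batch normalization and a specific layer layout):

```julia
using Lux, Random

# A residual block: y = x + f(x). SkipConnection applies the wrapped
# layers to x and combines their output with x via the given function (+).
residual_block = SkipConnection(
    Chain(
        Conv((3, 3), 16 => 16, relu; pad = SamePad()),
        Conv((3, 3), 16 => 16; pad = SamePad()),
    ),
    +,
)

rng = Xoshiro(0)
ps, st = Lux.setup(rng, residual_block)
x = randn(rng, Float32, 8, 8, 16, 1)    # (height, width, channels, batch)
y, _ = residual_block(x, ps, st)
size(y)   # → (8, 8, 16, 1), same shape as the input
```

Because the identity path passes gradients through unchanged, stacking many such blocks does not suffer from the vanishing gradients that plague equally deep plain networks.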

8.6 Geoscience applications

CNNs are the dominant architecture for geoscience tasks involving gridded spatial data:

  • Seismic fault detection — Wu et al. (2019) trained a 3D CNN (FaultSeg3D) on synthetic seismic volumes to segment faults in 3D, demonstrating that CNNs can detect complex fault geometries directly from seismic data.
  • Earthquake detection — Perol et al. (2018) developed ConvQuake, a CNN that detects and locates earthquakes directly from raw seismic waveforms, outperforming traditional detection methods in noisy environments.
  • Seismic waveform classification and first-break picking — Yuan et al. (2020) used a CNN for waveform classification and first-break picking, showing how convolutional models can detect local seismic patterns directly from traces.
  • Remote sensing — land-use classification, mineral mapping, and change detection from satellite and airborne imagery are natural CNN applications, as the data is inherently image-like.
  • Seismic image interpretation — 2D CNNs can classify local image patterns such as layered structure, faults, or dome-like geometry from small patches, as demonstrated in the code example above.

The key insight is this: whenever your geoscience data lives on a regular grid, a CNN is likely a good starting point. The spatial weight sharing built into convolutions matches the physics of spatially correlated earth properties.