using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie
# Global configuration shared by every cell below.
rng = Xoshiro(42)   # fixed seed so data, initial weights and splits are reproducible
n = 64              # number of samples in each synthetic trace
# Label names, indexed by class id (1 = quiet, 2 = positive, 3 = negative arrival).
class_names = ["Quiet", "Positive arrival", "Negative arrival"]
# REPL output:
# 3-element Vector{String}:
#  "Quiet"
#  "Positive arrival"
#  "Negative arrival"
A convolutional neural network (CNN) exploits the spatial structure in data — the fact that nearby pixels or grid cells tend to be related. Instead of connecting every input to every neuron, a CNN slides small learned filters across the data, detecting local patterns such as edges, textures, and shapes. This makes CNNs far more parameter-efficient than feedforward networks for image-like data, and it gives them a property fully-connected networks lack: translation equivariance — the same pattern is recognised wherever it appears in the image, because the same filter is applied at every location.
In a CNN, a filter (or kernel) is a small set of weights: a short vector for 1D traces, or a small matrix such as \(3 \times 3\) or \(5 \times 5\) for 2D images. The filter slides across the input and at each position computes a dot product between the filter weights and the local patch of input values. This produces a feature map — a new grid where each cell represents how strongly that local pattern was detected at that position.
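For a 2D input \(\mathbf{X}\) and a filter \(\mathbf{K}\) of size \(k \times k\), the convolution at position \((i, j)\) is given below. Strictly, this unflipped form is a cross-correlation, but because the filter weights are learned, the distinction makes no practical difference: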
\[ (\mathbf{X} * \mathbf{K})_{i,j} = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} \mathbf{K}_{m,n} \cdot \mathbf{X}_{i+m,\, j+n} \]
A convolutional layer applies many such filters in parallel, each learning to detect a different pattern.
For a 1D trace, the same operation uses one spatial index instead of two:
\[ (\mathbf{x} * \mathbf{k})_i = \sum_{m=0}^{k-1} \mathbf{k}_m \cdot \mathbf{x}_{i+m} \]
After convolution, pooling layers reduce the spatial size of the feature maps, keeping only the most important information. The most common type is max pooling, which takes the maximum value in each small window (e.g., length 2 for traces or \(2 \times 2\) for images). Pooling reduces computation, provides some translation invariance, and increases the receptive field of deeper layers.
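CNNs come in two main shapes, depending on what the output should look like:
- **Classifiers** reduce the whole input to a single label, or a vector of class scores, by following the convolution and pooling layers with dense layers.
- **Dense predictors**, such as encoder–decoder networks like U-Net, output one value per input sample or pixel.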
Geophysics uses both templates. Dense predictors are common for faults, salt bodies, and horizons; classifiers are useful when the question is whether a short window contains a particular waveform or structure. The worked example below uses the classifier template because it is small enough to run comfortably on a CPU.
We train a compact 1D CNN to classify short synthetic seismic traces into three classes: quiet background, a positive-polarity arrival, and a negative-polarity arrival. The arrival position is random, so the network must recognise the waveform regardless of where it appears in the trace.
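This problem is a CNN’s natural habitat, for three reasons:
- The data are a regularly sampled 1D grid.
- The feature that separates the classes (the wavelet’s shape and polarity) is local.
- The arrival can appear anywhere in the trace, so a filter that is shared across all positions is exactly what is needed.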
using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie
# Global configuration shared by every cell below.
rng = Xoshiro(42)   # fixed seed so data, initial weights and splits are reproducible
n = 64              # number of samples in each synthetic trace
# Label names, indexed by class id (1 = quiet, 2 = positive, 3 = negative arrival).
class_names = ["Quiet", "Positive arrival", "Negative arrival"]
# REPL output:
# 3-element Vector{String}:
#  "Quiet"
#  "Positive arrival"
#  "Negative arrival"
"""
    make_trace_example(rng, n = n; class_id = rand(rng, 1:3))

Generate one labelled synthetic seismic trace of length `n`.

Every trace contains weak Gaussian noise plus a low-amplitude sinusoid (2.5 cycles over
the trace) with a random phase. Classes 2 and 3 additionally contain a Gaussian-shaped
arrival of positive (class 2) or negative (class 3) polarity centred at a random sample.
Class 1 is background only. The trace is demeaned and then clipped to `[-1, 1]`.

# Arguments
- `rng`: random-number generator; all randomness is drawn from it.
- `n`: trace length, defaulting to the global `n`. Arrival classes need `n >= 24`
  because the arrival centre is drawn from `12:n-12`.

# Keywords
- `class_id`: 1 = quiet, 2 = positive arrival, 3 = negative arrival; random if omitted.

# Returns
`(trace, y, class_id)`:
- `trace` is an `n × 1` `Float32` matrix (samples × channels).
- `y` is a length-3 one-hot `Float32` label vector.

# Throws
- `ArgumentError` if `class_id ∉ 1:3`, or if `n < 24` for an arrival class.
"""
function make_trace_example(rng, n = n; class_id = rand(rng, 1:3))
    # Validate before drawing any random numbers so a bad call does not advance `rng`.
    class_id in 1:3 ||
        throw(ArgumentError("class_id must be 1, 2, or 3; got $class_id"))
    class_id == 1 || n >= 24 ||
        throw(ArgumentError("arrival classes need n >= 24; got n = $n"))

    samples = Float32.(1:n)
    t = Float32.(range(0, 1, length = n))

    # Background: white noise plus a weak, randomly phased sinusoid.
    phase = rand(rng, Float32)
    trace = 0.03f0 .* randn(rng, Float32, n)
    trace .+= 0.04f0 .* sin.(2f0 * Float32(pi) .* (2.5f0 .* t .+ phase))

    if class_id != 1
        # Keep a 12-sample margin so the arrival stays away from the trace edges.
        center = rand(rng, 12:n-12)
        width = 2.6f0 + 0.8f0 * rand(rng, Float32)     # Gaussian width in samples
        polarity = class_id == 2 ? 1f0 : -1f0
        d = (samples .- center) ./ width
        wavelet = exp.(-0.5f0 .* d .^ 2)
        # Peak amplitude is drawn from [0.9, 1.0).
        trace .+= polarity .* (0.90f0 + 0.10f0 * rand(rng, Float32)) .* wavelet
    end

    trace .-= mean(trace)
    trace = clamp.(trace, -1f0, 1f0)

    y = zeros(Float32, 3)
    y[class_id] = 1f0
    return reshape(trace, n, 1), y, class_id
end
# Build a class-balanced synthetic dataset and split it into training and holdout sets.
n_samples = 180
X = zeros(Float32, n, 1, n_samples) # (samples, channels, batch)
Y = zeros(Float32, 3, n_samples)    # one-hot labels, one column per trace
for i in 1:n_samples
    # mod1 cycles the labels 1, 2, 3, 1, ... so each class gets n_samples ÷ 3 traces.
    trace, y, _ = make_trace_example(rng; class_id = mod1(i, 3))
    X[:, :, i] .= trace
    Y[:, i] .= y
end

# Shuffle once, then keep the first 80 % of the shuffled indices for training.
idx = randperm(rng, n_samples)
n_train = round(Int, 0.80 * n_samples)
tr = idx[1:n_train]
te = idx[(n_train + 1):end]
X_train, Y_train = X[:, :, tr], Y[:, tr]
X_test, Y_test = X[:, :, te], Y[:, te]
# REPL output: the (X_test, Y_test) tuple, i.e. a 64×1×36 Float32 array of traces
# and the matching 3×36 one-hot label matrix.
The input tensor has shape (samples, channels, batch), and the label for each trace is a one-hot vector with three entries.
# Small 1D CNN classifier for traces of length `n`.
# Shapes: input (n, 1, batch) → Conv (n, n_filters, batch) → MaxPool (n ÷ 2, n_filters, batch)
#         → flatten (n_filters * (n ÷ 2), batch) → Dense 16 → Dense 3 class logits.
n_filters = 8
# MaxPool((2,)) uses stride 2 and no padding, so it keeps floor(n / 2) samples
# (32 for n = 64). Deriving this from `n` keeps the Dense input size in sync
# if the trace length changes.
n_pooled = n ÷ 2
model = Chain(
    Conv((5,), 1 => n_filters, relu; pad = SamePad()),  # SamePad keeps the length at n
    MaxPool((2,)),
    WrappedFunction(x -> reshape(x, :, size(x, 3))),   # flatten features, keep batch dim
    Dense(n_filters * n_pooled => 16, relu),
    Dense(16 => 3),
)
ps, st = Lux.setup(rng, model)
"""
    softmax_cols(x)

Apply the softmax independently to each column of `x`, where dimension 1 holds the class
scores. Return an array of the same shape whose columns sum to one. Each column's maximum
is subtracted before exponentiating so that large logits cannot overflow.
"""
function softmax_cols(x)
    colmax = maximum(x; dims = 1)
    unnormalised = @. exp(x - colmax)
    return unnormalised ./ sum(unnormalised; dims = 1)
end
"""
    cross_entropy_loss(model, ps, st, data)

Mean categorical cross-entropy of `model` on `data = (x, y)`, in the objective form
expected by `Lux.Training.single_train_step!`.

`model(x, ps, st)` must return `(logits, st_new)`, with one column of class logits per
example. `y` holds the matching one-hot labels.

The log-probabilities are computed with a max-shifted log-sum-exp rather than
`log(softmax + ε)`. This keeps the loss exact and finite, and preserves useful gradients
even for confidently wrong predictions, instead of saturating near `-log(ε)`.

Returns `(loss, st_new, ())`: the scalar loss, the updated model state, and empty stats.
"""
function cross_entropy_loss(model, ps, st, data)
    x, y = data
    logits, st_new = model(x, ps, st)
    shifted = logits .- maximum(logits; dims = 1)
    logprob = shifted .- log.(sum(exp.(shifted); dims = 1))
    loss = -mean(sum(y .* logprob; dims = 1))
    return loss, st_new, ()
end
"""
    predicted_classes(probabilities)

Return, for each column of `probabilities`, the row index of its largest entry, i.e. the
predicted class of each example. Ties resolve to the first maximum.
"""
predicted_classes(probabilities) = [argmax(col) for col in eachcol(probabilities)]

"""
    true_classes(labels)

Return the class index encoded by each one-hot column of `labels`: the row holding the
largest value, normally the single `1`.
"""
true_classes(labels) = [argmax(col) for col in eachcol(labels)]
"""
    train_classifier(model, ps, st, data; epochs = 60, lr = 5f-3, log_every = 15)

Train `model` with the Adam optimiser (learning rate `lr`) on `data = (x, y)` for
`epochs` steps. `cross_entropy_loss` is the objective, and Zygote computes the gradients.

The whole `data` tuple is passed to every step, so each "epoch" is a single full-batch
update.

The loss is printed at epoch 1 and then every `log_every` epochs. A non-positive
`log_every` prints only epoch 1.

Returns the final `Training.TrainState`, which holds the trained parameters and states.
"""
function train_classifier(model, ps, st, data; epochs = 60, lr = 5f-3, log_every = 15)
    tstate = Training.TrainState(model, ps, st, Adam(lr))
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), cross_entropy_loss, data, tstate
        )
        if epoch == 1 || (log_every > 0 && epoch % log_every == 0)
            @printf "Epoch %2d cross-entropy = %.4f\n" epoch loss
        end
    end
    return tstate
end
# Train on the training split; the returned TrainState holds the fitted parameters.
tstate = train_classifier(model, ps, st, (X_train, Y_train))
# Printed during training:
#   Epoch 1 cross-entropy = 1.1431
#   Epoch 15 cross-entropy = 0.2116
#   Epoch 30 cross-entropy = 0.0077
#   Epoch 45 cross-entropy = 0.0013
#   Epoch 60 cross-entropy = 0.0007
# Displayed value:
#   TrainState(
#   Chain(
#   layer_1 = Conv((5,), 1 => 8, relu, pad=2), # 48 parameters
#   layer_2 = MaxPool((2,)),
#   layer_3 = WrappedFunction(#3),
#   layer_4 = Dense(256 => 16, relu), # 4_112 parameters
#   layer_5 = Dense(16 => 3), # 51 parameters
#   ),
#   number of parameters: 4211
#   number of states: 0
#   optimizer: Adam(eta=0.005, beta=(0.9, 0.999), epsilon=1.0e-8)
#   step: 60
#   )
# Evaluate the trained network on the held-out traces.
# testmode disables training-only behaviour in stateful layers (none in this model yet).
st_eval = Lux.testmode(tstate.states)
test_logits, _ = model(X_test, tstate.parameters, st_eval)
test_prob = softmax_cols(test_logits)
# Reuse the training objective so the holdout loss is exactly the metric that was optimised.
test_loss, _, _ = cross_entropy_loss(model, tstate.parameters, st_eval, (X_test, Y_test))

# Compute the class vectors once and reuse them for the overall and per-class scores.
pred_test = predicted_classes(test_prob)
true_test = true_classes(Y_test)
test_acc = mean(pred_test .== true_test)
@printf "Holdout cross-entropy: %.4f Accuracy: %.3f\n" test_loss test_acc

for (i, name) in enumerate(class_names)
    in_class = true_test .== i
    if any(in_class)
        class_acc = mean(pred_test[in_class] .== i)
        @printf "%s accuracy: %.3f\n" name class_acc
    else
        # A class missing from the holdout split would otherwise report NaN.
        @printf "%s accuracy: n/a (no test samples)\n" name
    end
end
# Output:
#   Holdout cross-entropy: 0.0008 Accuracy: 1.000
#   Quiet accuracy: 1.000
#   Positive arrival accuracy: 1.000
#   Negative arrival accuracy: 1.000
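Arrival positions are drawn at random for every trace, and the holdout traces were never seen during training. High holdout accuracy therefore indicates that the CNN learned a location-independent waveform detector rather than memorising a fixed arrival sample.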
"""
    sample_with_label(seed, wanted_label)

Draw traces with `make_trace_example` from a fresh `Xoshiro(seed)` stream until one of
class `wanted_label` appears, then return its `(trace, y, class_id)` tuple. Each class is
drawn with probability 1/3, so only a few draws are usually needed.

Throws `ArgumentError` if `wanted_label` is not 1, 2 or 3. Such a label can never be
produced, so the search would otherwise loop forever.
"""
function sample_with_label(seed, wanted_label)
    wanted_label in 1:3 ||
        throw(ArgumentError("wanted_label must be 1, 2, or 3; got $wanted_label"))
    rng_local = Xoshiro(seed)
    while true
        x, y, class_id = make_trace_example(rng_local)
        class_id == wanted_label && return x, y, class_id
    end
end
# Plot one example trace per class. Each title reads "true class -> predicted class".
examples = [sample_with_label(100 + 17i, i) for i in 1:3]
fig = Figure(size = (720, 620))
last_row = length(examples)   # only the bottom axis gets an x label
for (row, (trace, _, class_id)) in enumerate(examples)
    # Add a batch dimension: the model expects (samples, channels, batch).
    logits, _ = model(reshape(trace, n, 1, 1), tstate.parameters, tstate.states)
    prob = softmax_cols(logits)
    predicted = findmax(prob[:, 1])[2]
    ax = Axis(fig[row, 1],
        title = "$(class_names[class_id]) -> $(class_names[predicted])",
        xlabel = row == last_row ? "sample" : "",
        ylabel = "amplitude")
    lines!(ax, 1:n, trace[:, 1]; color = :black, linewidth = 1.6)
    ylims!(ax, -1.05, 1.05)
end
fig
The final plot shows one trace from each class. The title reports the true class followed by the predicted class, giving a direct visual check of the trained 1D CNN.
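Several landmark architectures expanded the capabilities of CNNs:
- **LeNet-5** (1998) established the convolution–pooling–dense template.
- **AlexNet** (2012) showed that deep CNNs trained on GPUs could win large-scale image classification.
- **VGG** (2014) built deep networks from stacks of small \(3 \times 3\) filters.
- **ResNet** (2015) introduced residual (skip) connections that made very deep networks trainable.
- **U-Net** (2015) popularised the encoder–decoder design for dense, per-pixel prediction.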
The key insight is: whenever your geoscience data lives on a regular grid and the relevant pattern can appear anywhere, a CNN is likely a good starting point. The spatial weight sharing built into convolutions matches the physics of spatially correlated earth properties and the way features such as faults, salt edges, or channel boundaries are defined locally rather than globally.