9  Recurrent Neural Networks

Tip: Key references
  • Simple RNN — the idea that a network can process sequences by maintaining a hidden state (Elman, 1990).
  • LSTM — Long Short-Term Memory, which mitigated the vanishing-gradient problem for sequences and enabled learning over hundreds of time steps (Hochreiter & Schmidhuber, 1997).
  • GRU — Gated Recurrent Unit, a simplified variant of LSTM with comparable performance (Cho et al., 2014).
  • ConvLSTM — combining convolutional and recurrent structures for spatiotemporal prediction (Shi et al., 2015).

Feedforward and convolutional networks process each input independently. But many geoscience datasets are sequential: seismograms, well-log curves, climate records, and satellite time series all have a natural ordering in time (or depth). A recurrent neural network (RNN) is designed for this: it processes a sequence one step at a time, maintaining a hidden state that carries information from earlier steps to later ones.

9.1 The simple RNN

At each time step \(t\), a simple RNN receives the current input \(\mathbf{x}_t\) and the previous hidden state \(\mathbf{h}_{t-1}\), and produces a new hidden state:

\[ \mathbf{h}_t = \tanh\!\bigl(W_h\, \mathbf{h}_{t-1} + W_x\, \mathbf{x}_t + \mathbf{b}\bigr) \]

Here the subscript \(t\) indexes time, while the subscripts on \(W_h\) and \(W_x\) identify the role of each matrix: hidden-to-hidden and input-to-hidden, respectively. This follows the convention introduced earlier in this part: superscripts are reserved for layer depth, and subscripts mark time or semantic roles.

The hidden state acts as the network’s memory. The output at each step can be read from \(\mathbf{h}_t\) directly or passed through an additional dense layer.
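The recurrence is short enough to write out by hand. The sketch below is our own illustration, not a library call; the function name `simple_rnn` and the dimensions are arbitrary. It runs the update over a ten-step sequence and returns the final hidden state:

```julia
using Random

# The recurrence h_t = tanh(W_h * h_{t-1} + W_x * x_t + b), written out directly.
# The loop carries h forward, so the final state depends on every input in order.
function simple_rnn(xs, W_h, W_x, b)
    h = zeros(Float32, size(W_h, 1))         # h_0 = 0
    for x in xs
        h = tanh.(W_h * h .+ W_x * x .+ b)   # h_t from h_{t-1} and x_t
    end
    return h                                 # final hidden state summarizes the sequence
end

rng = Xoshiro(0)
W_h = 0.3f0 .* randn(rng, Float32, 5, 5)     # hidden-to-hidden
W_x = randn(rng, Float32, 5, 2)              # input-to-hidden
b   = zeros(Float32, 5)
xs  = [randn(rng, Float32, 2) for _ in 1:10] # a 10-step sequence of 2-vectors
h_final = simple_rnn(xs, W_h, W_x, b)
```

Because each \(\mathbf{h}_t\) feeds into the next step, shuffling `xs` would change the result, which is exactly what makes the architecture order-aware.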

Problem: In practice, simple RNNs struggle to learn long-range dependencies because gradients either vanish (shrink to zero) or explode (grow unboundedly) when propagated backward through many time steps.
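This can be made concrete with a stylized backward pass, an illustration of our own rather than part of the chapter's pipeline: repeatedly applying the transposed Jacobian of \(\mathbf{h}_t = \tanh(W_h \mathbf{h}_{t-1})\) with small weights drives the gradient norm toward zero.

```julia
using LinearAlgebra, Random

# Each backward step through h_t = tanh(W_h * h_{t-1}) multiplies the gradient
# by the transposed Jacobian W_h' * diag(1 .- h_t.^2). With small weights the
# repeated product shrinks the gradient toward zero.
function gradient_norms(W, h0, T)
    h = h0
    g = ones(Float32, length(h0))       # gradient arriving at the last step
    norms = Float32[]
    for _ in 1:T
        h = tanh.(W * h)
        g = W' * (g .* (1 .- h .^ 2))   # one Jacobian-transpose application
        push!(norms, norm(g))
    end
    return norms
end

rng = Xoshiro(1)
W = 0.2f0 .* randn(rng, Float32, 8, 8)  # small recurrent weights
norms = gradient_norms(W, randn(rng, Float32, 8), 50)
println("gradient norm after 1 step:   ", norms[1])
println("gradient norm after 50 steps: ", norms[end])
```

With larger weights (spectral radius above one) the same product grows instead of shrinking, which is the exploding-gradient side of the problem.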

9.2 Long Short-Term Memory (LSTM)

The LSTM (Hochreiter & Schmidhuber, 1997) mitigates the vanishing-gradient problem by introducing a cell state \(\mathbf{c}_t\) alongside the hidden state \(\mathbf{h}_t\), controlled by three learned gates:

  • Forget gate \(\mathbf{f}_t\) — decides what information to discard from the cell state.
  • Input gate \(\mathbf{i}_t\) — decides what new information to store.
  • Output gate \(\mathbf{o}_t\) — decides what part of the cell state to expose.

As before, the subscript \(t\) denotes the time step. The different letter subscripts on the weight matrices indicate the gate they belong to.

\[ \begin{aligned} \mathbf{f}_t &= \sigma\bigl(W_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f\bigr) \\ \mathbf{i}_t &= \sigma\bigl(W_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i\bigr) \\ \tilde{\mathbf{c}}_t &= \tanh\bigl(W_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c\bigr) \\ \mathbf{c}_t &= \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t \\ \mathbf{o}_t &= \sigma\bigl(W_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o\bigr) \\ \mathbf{h}_t &= \mathbf{o}_t \odot \tanh(\mathbf{c}_t) \end{aligned} \]

The cell state can carry information unchanged through many time steps, and the gates learn to open and close during training, allowing the network to decide what to remember and what to forget.
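As a concrete check of the equations, here is a single LSTM step transcribed directly from them. This is a self-contained sketch of our own; `lstm_step` and the small dimensions are illustrative, not the Lux implementation used later in the chapter.

```julia
using Random

σ(x) = 1 / (1 + exp(-x))    # logistic sigmoid

# One LSTM step, transcribed directly from the equations above.
# Each W acts on the concatenation [h_{t-1}; x_t]; .* is the elementwise product ⊙.
function lstm_step(x, h, c, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o)
    z     = vcat(h, x)               # [h_{t-1}, x_t]
    f     = σ.(W_f * z .+ b_f)       # forget gate f_t
    i     = σ.(W_i * z .+ b_i)       # input gate i_t
    cand  = tanh.(W_c * z .+ b_c)    # candidate cell state c̃_t
    c_new = f .* c .+ i .* cand      # cell state c_t
    o     = σ.(W_o * z .+ b_o)       # output gate o_t
    h_new = o .* tanh.(c_new)        # hidden state h_t
    return h_new, c_new
end

rng = Xoshiro(0)
nh, nx = 4, 3
Ws = [randn(rng, Float32, nh, nh + nx) for _ in 1:4]   # W_f, W_i, W_c, W_o
bs = [zeros(Float32, nh) for _ in 1:4]                 # b_f, b_i, b_c, b_o
h, c = lstm_step(randn(rng, Float32, nx), zeros(Float32, nh), zeros(Float32, nh),
                 Ws..., bs...)
```

Note that \(\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)\) keeps every entry of the hidden state strictly inside \((-1, 1)\), while the cell state itself is unbounded; that unbounded additive pathway is what lets information survive many steps.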

9.3 Gated Recurrent Unit (GRU)

The GRU (Cho et al., 2014) is a simplification of the LSTM that merges the cell state and hidden state into a single state vector, using two gates instead of three (bias terms are omitted below for brevity):

\[ \begin{aligned} \mathbf{r}_t &= \sigma\bigl(W_r [\mathbf{h}_{t-1}, \mathbf{x}_t]\bigr) \\ \mathbf{z}_t &= \sigma\bigl(W_z [\mathbf{h}_{t-1}, \mathbf{x}_t]\bigr) \\ \tilde{\mathbf{h}}_t &= \tanh\bigl(W_h [\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t]\bigr) \\ \mathbf{h}_t &= (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t \end{aligned} \]

GRUs have fewer parameters than LSTMs and train faster, while producing similar results on many tasks.
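The same transcription exercise for the GRU (again an illustrative sketch of our own, not the Lux cell; biases are omitted to match the equations above):

```julia
using Random

σ(x) = 1 / (1 + exp(-x))    # logistic sigmoid

# One GRU step, transcribed directly from the equations above.
function gru_step(x, h, W_r, W_z, W_hh)
    u    = vcat(h, x)                     # [h_{t-1}, x_t]
    r    = σ.(W_r * u)                    # reset gate r_t
    z    = σ.(W_z * u)                    # update gate z_t
    cand = tanh.(W_hh * vcat(r .* h, x))  # candidate state h̃_t
    return (1 .- z) .* h .+ z .* cand     # h_t: blend of old state and candidate
end

rng = Xoshiro(0)
nh, nx = 4, 3
Ws = [randn(rng, Float32, nh, nh + nx) for _ in 1:3]  # W_r, W_z, W_h
h = gru_step(randn(rng, Float32, nx), zeros(Float32, nh), Ws...)
```

The parameter saving is visible directly: three weight matrices instead of the LSTM's four, and no separate cell state, so roughly three-quarters of the parameters for the same hidden size.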

9.4 Code example: predicting a synthetic geophysical time series

We generate a synthetic oscillating signal (simulating a geophysical measurement with periodic and trend components) and train an LSTM to predict the next value given the recent past.

using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie

rng = Xoshiro(42)

# Generate a synthetic time series: trend + oscillation + noise
t = Float32.(0:0.05:10)
signal = 0.3f0 .* t .+ sin.(2π .* 0.5f0 .* t) .+ 0.3f0 .* sin.(2π .* 1.3f0 .* t) .+
         0.15f0 .* randn(rng, Float32, length(t))

# Standardize before training so the recurrent model does not spend capacity on scale alone
μ_signal = mean(signal)
σ_signal = std(signal)
signal_scaled = (signal .- μ_signal) ./ σ_signal

# Create input/output pairs using a sliding window
window = 20
n_pairs = length(signal_scaled) - window
X_seq = zeros(Float32, 1, window, n_pairs)  # (features, time_steps, batch)
Y_seq = zeros(Float32, 1, n_pairs)          # (features, batch)

for i in 1:n_pairs
    X_seq[1, :, i] = signal_scaled[i:i+window-1]
    Y_seq[1, i]    = signal_scaled[i+window]
end

# Train/test split (chronological split to respect time ordering)
n_train = Int(round(0.8 * n_pairs))
X_train, Y_train = X_seq[:, :, 1:n_train], Y_seq[:, 1:n_train]
X_test,  Y_test  = X_seq[:, :, n_train+1:end], Y_seq[:, n_train+1:end]
# Build an LSTM model that reads the sequence, then maps the final hidden state to a prediction
model = Chain(
    Recurrence(LSTMCell(1 => 32)),
    Dense(32 => 1)
)

ps, st = Lux.setup(rng, model)

function mse_loss(model, ps, st, data)
    x, y = data
    ŷ, st_new = model(x, ps, st)
    loss = mean((ŷ .- y) .^ 2)
    return loss, st_new, ()
end
function train_model(model, ps, st, data; epochs = 600, lr = 0.003f0)
    tstate = Training.TrainState(model, ps, st, Adam(lr))
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), mse_loss, data, tstate
        )
        if epoch == 1 || epoch % 120 == 0
            @printf "Epoch %3d  MSE = %.6f\n" epoch loss
        end
    end
    return tstate
end

tstate = train_model(model, ps, st, (X_train, Y_train))

# Holdout evaluation
Y_test_pred, _ = model(X_test, tstate.parameters, tstate.states)
test_mse = mean(((σ_signal .* Y_test_pred .+ μ_signal) .- (σ_signal .* Y_test .+ μ_signal)) .^ 2)
@printf "Holdout test MSE = %.6f\n" test_mse
Epoch   1  MSE = 0.895788
Epoch 120  MSE = 0.024068
Epoch 240  MSE = 0.011021
Epoch 360  MSE = 0.006385
Epoch 480  MSE = 0.004051
Epoch 600  MSE = 0.003718
Holdout test MSE = 0.233183
# Predict and plot
Y_pred, _ = model(X_seq, tstate.parameters, tstate.states)
Y_pred = σ_signal .* Y_pred .+ μ_signal

fig = Figure(size = (700, 350))
ax = Axis(fig[1, 1], xlabel = "Time step", ylabel = "Value",
          title = "LSTM time-series prediction")
lines!(ax, window+1:length(signal), signal[window+1:end],
       color = :black, label = "True", linewidth = 2)
lines!(ax, window+1:length(signal), vec(Y_pred),
       color = :coral, label = "LSTM prediction", linestyle = :dash)
axislegend(ax, position = :lt)
fig
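The trained model predicts one step ahead. A common usage pattern is a rolling multi-step forecast that feeds predictions back in as inputs. The sketch below shows the mechanics; `rolling_forecast` is our own helper, and for self-containedness it is shown with a freshly initialized model, whereas in the pipeline above you would pass `tstate.parameters` and `tstate.states`.

```julia
using Lux, Random

# Roll the one-step predictor forward: append each prediction to the history
# and slide the window. Operates in the standardized space; rescale afterwards.
function rolling_forecast(model, ps, st, history, window, steps)
    preds = Float32[]
    for _ in 1:steps
        x = reshape(history[end-window+1:end], 1, window, 1)  # (features, time, batch)
        ŷ, _ = model(x, ps, st)
        push!(preds, ŷ[1])
        history = vcat(history, ŷ[1])   # prediction becomes part of the input
    end
    return preds
end

rng = Xoshiro(42)
model = Chain(Recurrence(LSTMCell(1 => 32)), Dense(32 => 1))
ps, st = Lux.setup(rng, model)
preds = rolling_forecast(model, ps, st, randn(rng, Float32, 20), 20, 10)
```

Errors compound over the horizon, so multi-step forecasts degrade faster than the one-step holdout MSE suggests.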

9.5 When to use RNNs

RNNs are the natural choice when:

  • Data has a sequential or temporal structure (time series, depth-indexed logs).
  • Order matters — shuffling the data would destroy information.
  • You need to capture dependencies between earlier and later parts of a sequence.

For very long sequences (thousands of steps), transformers (next chapter) often outperform RNNs because they can attend to any part of the sequence without passing information step by step.

9.6 Geoscience applications

Recurrent networks have been applied to a wide range of sequential geoscience problems:

  • Sequence modeling in seismology — recurrent and hybrid sequence models are widely used for waveform analysis and event interpretation. Zhu & Beroza (2019) is a closely related deep-learning benchmark for seismic arrival picking, though PhaseNet itself is primarily convolutional rather than recurrent.
  • Climate and weather forecasting — Ham et al. (2019) used a CNN-LSTM hybrid to forecast the El Niño–Southern Oscillation (ENSO) up to 18 months ahead, significantly outperforming physics-based dynamical models.
  • Precipitation nowcasting — Shi et al. (2015) introduced the ConvLSTM, combining convolutional and LSTM operations to predict radar echo sequences, a spatiotemporal forecasting task.
  • Machine learning in geoscience overview — Dramsch (2020) provides a comprehensive review of 70 years of machine learning in the geosciences, covering many recurrent-network applications in seismology, well-log analysis, and geophysical signal processing.