11  Graph Neural Networks

Key references
  • The GNN model — the foundational framework for neural networks that operate on graph-structured data (Scarselli et al., 2009).
  • Graph Convolutional Networks (GCN) — spectral-domain convolution on graphs using a first-order Chebyshev approximation (Kipf & Welling, 2017).
  • Message Passing Neural Networks (MPNN) — a unifying framework that describes most GNN variants as message-passing operations between nodes (Gilmer et al., 2017).

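All the neural networks we have seen so far assume a specific data structure: vectors (feedforward networks), grids (CNNs), or sequences (RNNs and transformers). Many real-world datasets, however, are naturally represented as graphs: collections of nodes connected by edges. In geoscience, examples include networks of monitoring stations, river networks linking gauge stations, the meshes used by global weather models, and the atomic structure of minerals and molecules.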

A graph neural network (GNN) operates directly on graph-structured data, learning representations that respect the connectivity structure.

11.1 Graphs: a brief reminder

A graph \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\) consists of:

  • Nodes \(\mathcal{V} = \{v_1, \ldots, v_n\}\), each with a feature vector \(\mathbf{x}_i\).
  • Edges \(\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}\), representing connections between nodes. Edges can also carry features.

The connectivity is described by an adjacency matrix \(A \in \{0, 1\}^{n \times n}\), where \(A_{ij} = 1\) if there is an edge from node \(i\) to node \(j\).
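As a minimal sketch, the snippet below builds \(A\) for a small, hypothetical undirected graph given as an edge list. Because an undirected edge connects \(i\) to \(j\) and \(j\) to \(i\), each edge is stored twice and \(A\) is symmetric. The neighbor sets \(\mathcal{N}(i)\) and the node degrees can then be read directly off the rows of \(A\). The edge list and the names `edges`, `neighbors`, and `deg` are illustrative only.

```julia
# Hypothetical undirected graph with 4 nodes, given as an edge list
edges = [(1, 2), (2, 3), (3, 4), (4, 1), (1, 3)]
n = 4

# Dense adjacency matrix: store each undirected edge in both directions
A = zeros(Int, n, n)
for (i, j) in edges
    A[i, j] = 1
    A[j, i] = 1
end

# Neighbor sets N(i): column indices of the nonzero entries in row i
neighbors = [findall(==(1), A[i, :]) for i in 1:n]
# neighbors == [[2, 3, 4], [1, 3], [1, 2, 4], [1, 3]]

# Node degrees: row sums of A
deg = vec(sum(A, dims = 2))
# deg == [3, 2, 3, 2]
```

For large graphs, most entries of \(A\) are zero, so a sparse representation (for example via Julia's `SparseArrays` standard library) is usually preferable to a dense matrix.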

11.2 Message passing

Most GNNs follow the message-passing paradigm (Gilmer et al., 2017). At each layer, every node:

  1. Gathers messages from its neighbors.
  2. Aggregates them (e.g., sum, mean, or max).
  3. Updates its own feature vector based on the aggregated message and its current state.

Formally, at layer \(k\):

\[ \mathbf{h}_i^{(k)} = \phi\!\left(\mathbf{h}_i^{(k-1)},\; \bigoplus_{j \in \mathcal{N}(i)} \psi\!\left(\mathbf{h}_i^{(k-1)}, \mathbf{h}_j^{(k-1)}, \mathbf{e}_{ij}\right)\right) \]

where the subscript \(i\) identifies the node and the superscript \((k)\) identifies the message-passing layer, matching the convention used throughout the book, where superscripts in parentheses track layer depth. The initial state of each node is its input feature vector, \(\mathbf{h}_i^{(0)} = \mathbf{x}_i\). The remaining symbols are:

  • \(\mathcal{N}(i)\) is the set of neighbors of node \(i\).
  • \(\mathbf{e}_{ij}\) is the feature vector of the edge between nodes \(i\) and \(j\), if edges carry features.
  • \(\psi\) is the message function (how to compute a message from a neighbor).
  • \(\bigoplus\) is the aggregation function (sum, mean, or max over all messages). It must be permutation-invariant, because a node's neighbors have no natural ordering.
  • \(\phi\) is the update function (how to combine the aggregated message with the node’s current state).
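To make the three roles concrete, here is a minimal sketch of one message-passing layer in plain Julia. The specific choices are assumptions made for illustration, not any library's definition: no edge features, a linear message function \(\psi(\mathbf{h}_i, \mathbf{h}_j) = W_m [\mathbf{h}_i; \mathbf{h}_j]\), sum aggregation, and an update \(\phi(\mathbf{h}_i, \mathbf{m}_i) = \mathrm{ReLU}(W_u [\mathbf{h}_i; \mathbf{m}_i])\). The function `message_passing_layer` and the weights `Wm`, `Wu` are hypothetical names. As in the rest of the chapter, node feature vectors are stored as rows of `H`.

```julia
using LinearAlgebra

relu(x) = max(x, zero(x))

# One message-passing layer (hypothetical helper, not a library function).
#   H         : n × d matrix, row i is h_i^(k-1)
#   neighbors : neighbors[i] lists the node indices in N(i)
#   Wm        : d_msg × 2d message weights, ψ(h_i, h_j) = Wm * [h_i; h_j]
#   Wu        : d_out × (d + d_msg) update weights, φ(h_i, m) = relu.(Wu * [h_i; m])
function message_passing_layer(H, neighbors, Wm, Wu)
    n = size(H, 1)
    H_new = zeros(eltype(H), n, size(Wu, 1))
    for i in 1:n
        h_i = H[i, :]
        m = zeros(eltype(H), size(Wm, 1))        # aggregated message
        for j in neighbors[i]
            m .+= Wm * vcat(h_i, H[j, :])        # message ψ, aggregated by ⊕ = sum
        end
        H_new[i, :] = relu.(Wu * vcat(h_i, m))   # update φ
    end
    return H_new
end

# Toy example: path graph 1 - 2 - 3 with 2-D features.
# Wm passes h_j through unchanged and Wu keeps only the aggregated message,
# so each output row is the sum of the neighbors' features.
H = [1.0 0.0; 0.0 1.0; 1.0 1.0]
neighbors = [[2], [1, 3], [2]]
Wm = hcat(zeros(2, 2), Matrix(1.0I, 2, 2))
Wu = hcat(zeros(2, 2), Matrix(1.0I, 2, 2))
H1 = message_passing_layer(H, neighbors, Wm, Wu)   # [0 1; 2 1; 0 1]
```

Swapping the sum for a mean or a max, or letting the message depend on \(\mathbf{e}_{ij}\), only changes the inner loop; the outer structure (message, aggregate, update) stays the same.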

After \(K\) layers of message passing, each node’s representation has been informed by nodes up to \(K\) hops away.
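One way to check the \(K\)-hop claim is through powers of the adjacency matrix with self-loops: the \((i, j)\) entry of \((A + I)^K\) is positive exactly when node \(j\) is at most \(K\) hops from node \(i\), which is the set of nodes whose input features can influence \(\mathbf{h}_i^{(K)}\) (assuming each update also uses the node's own state). The sketch below uses a hypothetical five-node path graph; `reach` is an illustrative helper, not a library function.

```julia
using LinearAlgebra

# Hypothetical path graph 1 - 2 - 3 - 4 - 5
n = 5
A = zeros(Int, n, n)
for i in 1:n-1
    A[i, i+1] = 1
    A[i+1, i] = 1
end

# Self-loops let a node keep its own state from one layer to the next
A_tilde = A + I

# Nodes that can influence node i after K layers:
# the nonzero entries of row i of (A + I)^K
reach(i, K) = findall(>(0), (A_tilde^K)[i, :])

reach(1, 1)   # [1, 2]
reach(1, 2)   # [1, 2, 3]
reach(3, 2)   # [1, 2, 3, 4, 5]
```

The receptive field grows by one hop per layer, which is why shallow GNNs only see local structure and deeper ones see progressively larger neighborhoods.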

11.3 Graph Convolutional Network (GCN)

The GCN (Kipf & Welling, 2017) is a popular and simple GNN variant. The layer-wise update rule is:

\[ H^{(k+1)} = \sigma\!\left(\tilde{D}^{-1/2}\, \tilde{A}\, \tilde{D}^{-1/2}\, H^{(k)}\, W^{(k)}\right) \]

where \(\tilde{A} = A + I\) is the adjacency matrix with self-loops, \(\tilde{D}\) is the diagonal degree matrix of \(\tilde{A}\) (\(\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\)), \(H^{(k)} \in \mathbb{R}^{n \times d_k}\) stacks the node representations at layer \(k\) row by row (row \(i\) is \(\mathbf{h}_i^{(k)}\), and \(H^{(0)}\) holds the input features \(\mathbf{x}_i\)), \(W^{(k)} \in \mathbb{R}^{d_k \times d_{k+1}}\) is the learnable weight matrix of layer \(k\), shared by all nodes, and \(\sigma\) is an activation function.

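In plain language: each node computes a normalized sum of its own features and its neighbors' features, in which neighbor \(j\) of node \(i\) is weighted by \(1/\sqrt{\tilde{d}_i \tilde{d}_j}\) (with \(\tilde{d}_i = \tilde{D}_{ii}\)), so that nodes with many neighbors do not produce disproportionately large sums. It then applies a linear transformation and a non-linearity. In message-passing terms, \(\psi\) scales the transformed neighbor features by this weight, \(\bigoplus\) is a sum over \(\mathcal{N}(i) \cup \{i\}\), and \(\phi\) is the activation \(\sigma\).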
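A small worked example makes the normalization concrete. For a hypothetical path graph 1 - 2 - 3, adding self-loops gives degrees \(\tilde{d} = (2, 3, 2)\), so the normalized matrix \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\) has entry \(1/2\) at \((1,1)\), \(1/\sqrt{6}\) at \((1,2)\), and \(0\) at \((1,3)\). The sketch also checks that the matrix form of one layer matches an explicit loop over \(\mathcal{N}(i) \cup \{i\}\). The helper `gcn_layer_loop`, the choice of ReLU for \(\sigma\), and the weights `W` are assumptions for illustration.

```julia
using LinearAlgebra

# Hypothetical path graph 1 - 2 - 3
A = Float64[0 1 0; 1 0 1; 0 1 0]
A_tilde = A + I                      # add self-loops
d = vec(sum(A_tilde, dims = 2))      # degrees with self-loops: [2.0, 3.0, 2.0]
D_inv_sqrt = Diagonal(1 ./ sqrt.(d))
A_norm = D_inv_sqrt * A_tilde * D_inv_sqrt
# A_norm[1, 1] ≈ 1/2, A_norm[1, 2] ≈ 1/√6, A_norm[1, 3] == 0

# The same GCN layer written as an explicit loop over N(i) ∪ {i}
function gcn_layer_loop(H, A_tilde, d, W)
    n = size(H, 1)
    out = zeros(n, size(W, 2))
    for i in 1:n, j in 1:n
        if A_tilde[i, j] != 0                       # j is i itself or a neighbor
            out[i, :] .+= (W' * H[j, :]) ./ sqrt(d[i] * d[j])   # weighted message, summed
        end
    end
    return max.(out, 0)                             # σ = ReLU
end

H = [1.0 0.0; 0.0 1.0; 1.0 1.0]      # one row per node
W = [1.0 -1.0; 0.5 2.0]              # arbitrary weights for illustration
H_matrix = max.(A_norm * H * W, 0)   # matrix form of one layer
H_loop = gcn_layer_loop(H, A_tilde, d, W)
H_matrix ≈ H_loop                    # true
```

The matrix form is what one uses in practice, since a single (sparse) matrix product replaces the double loop; the loop version is only there to show that both compute the same weighted sum.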

11.4 Code example: node classification on a synthetic graph

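We build a synthetic graph with two classes of 20 nodes each. Every node has a 2D feature vector drawn around a class-dependent mean, \((1, 0)\) or \((0, 1)\), with Gaussian noise, so the two classes overlap in feature space. Edges are much more likely within a class (probability 0.3) than between classes (probability 0.05), so the graph structure carries information about the labels. The GCN is trained on 70% of the nodes and must classify the remaining 30% by combining each node's noisy features with those of its neighbors.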

using Random, LinearAlgebra, Statistics, Printf, CairoMakie, Zygote

rng = Xoshiro(42)

# Create a synthetic graph: two clusters
n_nodes = 40
n_class1 = 20

# Node features: 2D, with class-dependent mean
features = zeros(Float32, n_nodes, 2)
labels = zeros(Int, n_nodes)
for i in 1:n_nodes
    if i <= n_class1
        features[i, :] = [1.0f0, 0.0f0] .+ 0.5f0 .* randn(rng, Float32, 2)
        labels[i] = 1
    else
        features[i, :] = [0.0f0, 1.0f0] .+ 0.5f0 .* randn(rng, Float32, 2)
        labels[i] = 2
    end
end

# Adjacency: higher probability of edges within the same class
A = zeros(Float32, n_nodes, n_nodes)
for i in 1:n_nodes
    for j in i+1:n_nodes
        p = labels[i] == labels[j] ? 0.3 : 0.05
        if rand(rng) < p
            A[i, j] = 1.0f0
            A[j, i] = 1.0f0
        end
    end
end
# Implement a simple 2-layer GCN manually
# Normalized adjacency with self-loops
A_hat = A + I(n_nodes)
D_hat = Diagonal(vec(sum(A_hat, dims = 2)))
D_inv_sqrt = Diagonal(1.0f0 ./ sqrt.(diag(D_hat)))
A_norm = D_inv_sqrt * A_hat * D_inv_sqrt

# Parameters
d_in, d_hidden, d_out = 2, 8, 2
W1 = 0.5f0 .* randn(rng, Float32, d_in, d_hidden)
W2 = 0.5f0 .* randn(rng, Float32, d_hidden, d_out)

# Forward pass
function gcn_forward(X, A_norm, W1, W2)
    H1 = max.(A_norm * X * W1, 0)       # GCN layer 1 + ReLU
    H2 = A_norm * H1 * W2               # GCN layer 2 (logits)
    return H2
end

# Softmax
function softmax_rows(X)
    eX = exp.(X .- maximum(X, dims = 2))
    return eX ./ sum(eX, dims = 2)
end

# One-hot targets
targets = zeros(Float32, n_nodes, 2)
for i in 1:n_nodes
    targets[i, labels[i]] = 1.0f0
end

# Node split for evaluation
perm = randperm(rng, n_nodes)
n_train = Int(round(0.7 * n_nodes))
train_idx = perm[1:n_train]
test_idx  = perm[n_train+1:end]
# Train the GCN with plain gradient descent
lr = 0.05f0

# Cross-entropy on the training nodes only; the forward pass still uses
# the full graph, so test nodes contribute features through their edges
function train_loss(W1_, W2_)
    logits = gcn_forward(features, A_norm, W1_, W2_)
    probs = softmax_rows(logits)
    p_tr = probs[train_idx, :]
    t_tr = targets[train_idx, :]
    return -mean(sum(t_tr .* log.(p_tr .+ 1.0f-7), dims = 2))
end

for epoch in 1:200
    loss = train_loss(W1, W2)
    grad_W1, grad_W2 = Zygote.gradient(train_loss, W1, W2)

    # In-place gradient-descent update
    W1 .-= lr .* grad_W1
    W2 .-= lr .* grad_W2

    # Report the loss (before the update) and the accuracy (after it)
    if epoch == 1 || epoch % 50 == 0
        logits = gcn_forward(features, A_norm, W1, W2)
        probs = softmax_rows(logits)
        preds = argmax.(eachrow(probs))
        train_acc = mean(preds[train_idx] .== labels[train_idx])
        test_acc = mean(preds[test_idx] .== labels[test_idx])
        @printf "Epoch %3d  Loss: %.4f  Train acc: %.1f%%  Test acc: %.1f%%\n" epoch loss 100*train_acc 100*test_acc
    end
end
Epoch   1  Loss: 0.8375  Train acc: 50.0%  Test acc: 50.0%
Epoch  50  Loss: 0.6482  Train acc: 50.0%  Test acc: 50.0%
Epoch 100  Loss: 0.5825  Train acc: 50.0%  Test acc: 50.0%
Epoch 150  Loss: 0.5477  Train acc: 50.0%  Test acc: 50.0%
Epoch 200  Loss: 0.5272  Train acc: 50.0%  Test acc: 50.0%
# Visualize the graph with predicted labels
logits = gcn_forward(features, A_norm, W1, W2)
preds = argmax.(eachrow(softmax_rows(logits)))

# Layout: use feature coordinates
fig = Figure(size = (500, 400))
ax = Axis(fig[1, 1], title = "GCN node classification",
          xlabel = "Feature 1", ylabel = "Feature 2")

# Draw edges
for i in 1:n_nodes
    for j in i+1:n_nodes
        if A[i, j] > 0
            lines!(ax, [features[i, 1], features[j, 1]],
                      [features[i, 2], features[j, 2]],
                   color = (:gray, 0.2), linewidth = 0.5)
        end
    end
end

# Draw nodes colored by predicted class
colors = [p == 1 ? :steelblue : :coral for p in preds]
markers = [l == p ? :circle : :xcross for (l, p) in zip(labels, preds)]
scatter!(ax, features[:, 1], features[:, 2],
         color = colors, marker = markers, markersize = 12)

fig

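In the plot, nodes are placed at their raw feature coordinates and colored by predicted class; circles are correctly classified nodes and crosses are misclassified ones. The idea is that the GCN combines graph connectivity with node features, so a node whose own features are ambiguous can still be classified correctly when most of its neighbors belong to one class. Note, however, that in the run printed above the training and test accuracy stay at 50% while the loss decreases. On these balanced splits, that is exactly the accuracy obtained by assigning every node to the same class, so after 200 epochs of plain gradient descent with a learning rate of 0.05 the decision boundary has likely not yet moved between the two clusters. Training for more epochs, raising the learning rate, or adding bias terms to the layers are reasonable things to try before interpreting the plot.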

11.5 When to use GNNs

GNNs are the right choice when:

  • Data lives on an irregular structure (not a regular grid or sequence).
  • Relationships between entities are important (sensor networks, molecular graphs).
  • You want to make predictions about nodes, edges, or the entire graph.
  • The graph structure itself carries information (e.g., which stations are nearby).

If your data is on a regular grid, a CNN is simpler and usually sufficient. If your data is sequential, use an RNN or transformer. GNNs fill the gap for irregular, relational data.

11.6 Geoscience applications

  • Weather forecasting: Lam et al. (2023) introduced GraphCast, a graph-neural-network-based model for medium-range global weather prediction. It represents the atmosphere as a graph on a multi-resolution mesh, enabling accurate 10-day forecasts at 0.25° resolution while running orders of magnitude faster than traditional numerical weather models.
  • Hydrological networks: Stanev et al. (2021) applied GNNs to hydrological coherence analysis using remotely sensed water-level data, leveraging the natural graph structure of river networks and gauge station connectivity.
  • Molecular and mineral property prediction: in geochemistry, GNNs can predict properties of minerals and fluids directly from graphs of their atomic structure, following the message-passing approach that Gilmer et al. (2017) demonstrated for predicting quantum-chemical properties of small molecules.