13  Generative Adversarial Networks

Tip: Key references
  • GANs — the original framework introducing the adversarial training paradigm: a generator network learns to produce realistic data by competing against a discriminator network (Goodfellow et al., 2014).
  • DCGAN — deep convolutional GAN, which established stable architectures and training practices for image generation (Radford et al., 2016).

This chapter on generative adversarial networks (GANs) sits inside a short block of chapters on generative models. The previous chapters focused mostly on architectures for prediction, classification, or representation learning on different data structures. In this part of the book, autoencoders, GANs, diffusion models, and flow matching are grouped together because they all learn data distributions and produce new samples, even though their training mechanisms differ substantially.

A generative adversarial network (GAN) consists of two neural networks that are trained simultaneously in competition:

  • Generator \(G\): maps random noise \(\mathbf{z}\) to synthetic samples meant to resemble the training data.
  • Discriminator \(D\): receives real and generated samples and predicts which are real.

Training alternates between improving the discriminator (so it better distinguishes real from fake) and improving the generator (so its fakes become more convincing). The result is a generator that can produce realistic, novel samples from the data distribution.

13.1 The adversarial objective

The GAN training objective is a minimax game:

\[ \min_G \max_D \; \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\!\bigl[\log D(\mathbf{x})\bigr] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}\!\bigl[\log\bigl(1 - D(G(\mathbf{z}))\bigr)\bigr] \]

At convergence, the generator produces samples that the discriminator cannot distinguish from real data. In practice, GANs can be difficult to train: the generator and discriminator must be balanced, and training can be unstable.
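To make the objective concrete, the two players' losses fit in a couple of lines of Julia. This is a minimal sketch (the function names are ours), assuming the discriminator outputs probabilities in \((0,1)\); the generator loss shown is the non-saturating variant recommended in the original paper.

```julia
using Statistics

# Discriminator loss, negated so that both players minimize:
# -E[log D(x)] - E[log(1 - D(G(z)))]
d_loss(d_real, d_fake) = -mean(log.(d_real)) - mean(log.(1 .- d_fake))

# Non-saturating generator loss: maximize log D(G(z)) instead of
# minimizing log(1 - D(G(z))), which gives stronger gradients early
# in training when the discriminator wins easily.
g_loss(d_fake) = -mean(log.(d_fake))

# At the theoretical equilibrium, D outputs 1/2 everywhere and the
# discriminator loss equals 2 log 2 ≈ 1.386.
d_loss([0.5, 0.5], [0.5, 0.5])
```

A confident discriminator on well-separated scores drives `d_loss` toward zero; a fooled discriminator drives `g_loss` toward zero.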

13.2 Architecture

The original GAN used feedforward networks for both \(G\) and \(D\). The DCGAN (Radford et al., 2016) established best practices for using convolutional architectures:

  • Generator: uses transposed convolutions to upsample noise into an image.
  • Discriminator: uses standard convolutions to classify images as real or fake.
  • Batch normalization, LeakyReLU in the discriminator, and ReLU in the generator improve stability.
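The up- and downsampling in these two networks can be sanity-checked with the standard output-size formulas. The helper names below are our own, not from any library; DCGAN's typical kernel size 4, stride 2, padding 1 doubles (or halves) the spatial resolution at each layer:

```julia
# Spatial output size of a transposed convolution (generator upsampling)
# and of a standard convolution (discriminator downsampling).
convtranspose_out(n, k, s, p) = (n - 1) * s - 2p + k
conv_out(n, k, s, p) = (n + 2p - k) ÷ s + 1

# DCGAN-style generator path with k = 4, s = 2, p = 1: 4 → 8 → 16 → 32 pixels.
sizes = accumulate((n, _) -> convtranspose_out(n, 4, 2, 1), 1:3; init = 4)
```

The discriminator runs the same arithmetic in reverse: `conv_out(32, 4, 2, 1)` gives 16, and so on back down to a single score.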

13.3 Code example: a more stable GAN for a rock-physics crossplot

An earlier version of this example used the standard GAN objective, which had the right idea but trained too unstably. Here we keep the same geoscience target, synthetic porosity-velocity points, but switch to a more standard Wasserstein-style setup with a critic and weight clipping. This is still simple enough for a chapter example, yet much less likely to collapse onto a single cluster.
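For reference, the Wasserstein formulation (the WGAN of Arjovsky et al., 2017) replaces the log-losses of the minimax game above with raw critic scores, constrained so that the critic is approximately 1-Lipschitz (enforced here by weight clipping):

\[ \min_G \max_{D \in \mathcal{D}} \; \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\!\bigl[D(\mathbf{x})\bigr] - \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}\!\bigl[D(G(\mathbf{z}))\bigr] \]

where \(\mathcal{D}\) is the set of 1-Lipschitz functions. The critic is no longer a classifier, so it has no sigmoid output; the difference of mean scores it computes estimates a (scaled) Wasserstein-1 distance between the real and generated distributions.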

The problem formulation differs from the one in the CNN chapter: we are not classifying anything here. The generator takes random noise as input and produces a synthetic \((\phi, v_p)\) point. The critic looks at real and generated points and learns to score which ones look more realistic. After training, the output we care about is the cloud of generated samples and whether it matches the shape of the real crossplot.

using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie

rng = Xoshiro(42)

# Generate synthetic porosity-velocity pairs for two rock populations.
function make_crossplot_sample(rng)
    if rand(rng) < 0.5f0
        # High-porosity, lower-velocity population
        ϕ = clamp(0.28f0 + 0.025f0 * randn(rng, Float32), 0.18f0, 0.36f0)
        vₚ = 3.15f0 - 1.35f0 * (ϕ - 0.28f0) + 0.05f0 * randn(rng, Float32)
    else
        # Tight, lower-porosity, higher-velocity population
        ϕ = clamp(0.085f0 + 0.015f0 * randn(rng, Float32), 0.03f0, 0.14f0)
        vₚ = 4.55f0 - 0.85f0 * (ϕ - 0.085f0) + 0.04f0 * randn(rng, Float32)
    end
    return Float32[ϕ, vₚ]
end

n_real = 512
real_data = zeros(Float32, 2, n_real)
for i in 1:n_real
    real_data[:, i] = make_crossplot_sample(rng)
end

μ_cross = mean(real_data, dims = 2)
σ_cross = std(real_data, dims = 2)
real_data_scaled = (real_data .- μ_cross) ./ σ_cross

# Generator: noise → synthetic porosity-velocity pair
latent_dim = 4
generator = Chain(
    Dense(latent_dim => 48, relu),
    Dense(48 => 48, relu),
    Dense(48 => 2)
)

# Critic: porosity-velocity pair → realism score
critic = Chain(
    Dense(2 => 48, leakyrelu),
    Dense(48 => 48, leakyrelu),
    Dense(48 => 1)
)

ps_g, st_g = Lux.setup(rng, generator)
ps_c, st_c = Lux.setup(rng, critic)

function sample_batch(rng, data, batch_size)
    idx = rand(rng, 1:size(data, 2), batch_size)
    data[:, idx]
end

function clip_params(x, clip_value)
    if x isa NamedTuple
        names = fieldnames(typeof(x))
        values = map(name -> clip_params(getfield(x, name), clip_value), names)
        return NamedTuple{names}(Tuple(values))
    elseif x isa AbstractArray
        return clamp.(x, -clip_value, clip_value)
    else
        return x
    end
end
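As a quick sanity check on the recursion, the same clipping can be written compactly with multiple dispatch and exercised on a toy parameter tree (the helper name `clip_tree` and the values below are made up for illustration; it behaves identically to `clip_params`):

```julia
# Recurse through Lux-style nested NamedTuples, clamping every array.
clip_tree(x::NamedTuple, c) = map(v -> clip_tree(v, c), x)
clip_tree(x::AbstractArray, c) = clamp.(x, -c, c)
clip_tree(x, c) = x

ps = (layer_1 = (weight = [0.5 -0.9; 0.01 0.02], bias = [1.0, -2.0]),)
clipped = clip_tree(ps, 0.03)
# Every entry now lies in [-0.03, 0.03]; entries already inside
# the interval are untouched.
```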

function train_wgan(generator, critic, ps_g, st_g, ps_c, st_c;
                    epochs = 700, batch_size = 128, critic_steps = 4,
                    lr_g = 0.00025f0, lr_c = 0.00025f0, clip_value = 0.03f0)
    opt_state_g = Optimisers.setup(RMSProp(lr_g), ps_g)
    opt_state_c = Optimisers.setup(RMSProp(lr_c), ps_c)

    for epoch in 1:epochs
        critic_loss = 0.0f0
        for _ in 1:critic_steps
            x_real = sample_batch(rng, real_data_scaled, batch_size)
            z = randn(rng, Float32, latent_dim, batch_size)

            (c_loss, st_c_new), c_grads = Zygote.withgradient(ps_c) do pc
                fake, _ = generator(z, ps_g, st_g)
                real_score, st_c1 = critic(x_real, pc, st_c)
                fake_score, st_c2 = critic(fake, pc, st_c1)
                loss = mean(fake_score) - mean(real_score)
                return loss, st_c2
            end

            opt_state_c, ps_c = Optimisers.update(opt_state_c, ps_c, c_grads[1])
            ps_c = clip_params(ps_c, clip_value)
            st_c = st_c_new
            critic_loss = c_loss
        end

        z = randn(rng, Float32, latent_dim, batch_size)
        (g_loss, st_g_new), g_grads = Zygote.withgradient(ps_g) do pg
            fake, st_g_ = generator(z, pg, st_g)
            fake_score, _ = critic(fake, ps_c, st_c)
            loss = -mean(fake_score)
            return loss, st_g_
        end

        opt_state_g, ps_g = Optimisers.update(opt_state_g, ps_g, g_grads[1])
        st_g = st_g_new

        if epoch == 1 || epoch % 140 == 0
            @printf "Epoch %3d  critic loss: %.4f  generator loss: %.4f\n" epoch critic_loss g_loss
        end
    end

    return ps_g, st_g, ps_c, st_c
end

ps_g, st_g, ps_c, st_c = train_wgan(generator, critic, ps_g, st_g, ps_c, st_c)
Epoch   1  critic loss: -0.0010  generator loss: -0.0293
Epoch 140  critic loss: -0.0033  generator loss: -0.0304
Epoch 280  critic loss: -0.0037  generator loss: -0.0350
Epoch 420  critic loss: -0.0027  generator loss: -0.0358
Epoch 560  critic loss: -0.0030  generator loss: -0.0353
Epoch 700  critic loss: -0.0027  generator loss: -0.0356

# Generate synthetic crossplot samples
z_test = randn(rng, Float32, latent_dim, n_real)
fake_points, _ = generator(z_test, ps_g, st_g)
fake_points = fake_points .* σ_cross .+ μ_cross

# Simple distribution checks (not sufficient alone, but useful sanity checks)
@printf "Real porosity mean/std: %.3f / %.3f\n" mean(real_data[1, :]) std(real_data[1, :])
@printf "Fake porosity mean/std: %.3f / %.3f\n" mean(fake_points[1, :]) std(fake_points[1, :])
@printf "Real velocity mean/std: %.3f / %.3f\n" mean(real_data[2, :]) std(real_data[2, :])
@printf "Fake velocity mean/std: %.3f / %.3f\n" mean(fake_points[2, :]) std(fake_points[2, :])

fig = Figure(size = (620, 330))
ax1 = Axis(fig[1, 1], title = "Real crossplot",
           xlabel = "Porosity", ylabel = "P-wave velocity (km/s)")
scatter!(ax1, real_data[1, :], real_data[2, :], color = (:black, 0.45), markersize = 7)

ax2 = Axis(fig[1, 2], title = "GAN-generated crossplot",
           xlabel = "Porosity", ylabel = "P-wave velocity (km/s)")
scatter!(ax2, fake_points[1, :], fake_points[2, :], color = (:coral, 0.45), markersize = 7)

Label(fig[0, :], "GAN: Real vs Generated Rock-Physics Crossplots",
      fontsize = 16)
fig
Real porosity mean/std: 0.181 / 0.101
Fake porosity mean/std: 0.180 / 0.111
Real velocity mean/std: 3.866 / 0.700
Fake velocity mean/std: 3.867 / 0.738

The printed means and standard deviations are only a quick sanity check. The real test is the scatter plot: if the generated cloud matches the two-cluster structure, then the GAN is at least learning the gross geometry of the data distribution. If it collapses to one cluster or one narrow streak, the training has failed even if one or two scalar summary numbers look acceptable.

13.4 When to use GANs

GANs excel at generating realistic samples from a learned distribution. They are particularly useful when:

  • You need to augment limited training data with realistic synthetic examples.
  • You want to sample from a complex distribution (e.g., geological models consistent with observations).
  • You need to transfer styles or transform between domains (e.g., converting sketches to realistic geological cross-sections).

GANs can be harder to train than other generative models (VAEs, diffusion models). Mode collapse (the generator produces only a few types of output) and training instability are common challenges.
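For low-dimensional examples like the crossplot in this chapter, mode collapse is easy to screen for numerically. A crude sketch (the threshold and helper name are our own, chosen for the two porosity populations in the example above):

```julia
# Fraction of samples in the high-porosity population; with both modes
# covered and equal mixture weights, this should be near 0.5.
high_ϕ_fraction(ϕs; threshold = 0.16) = count(>(threshold), ϕs) / length(ϕs)

high_ϕ_fraction([0.28, 0.27, 0.09, 0.10])   # both modes present → 0.5
high_ϕ_fraction([0.28, 0.27, 0.29, 0.30])   # collapsed onto one mode → 1.0
```

A fraction near 0 or 1 after training is a strong hint that the generator has collapsed, even if per-variable means and standard deviations look plausible.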

13.5 Geoscience applications

  • Porous media reconstruction — Mosser et al. (2017) used DCGANs to generate 3D micro-CT-scale porous media samples that match the statistical properties of real rock samples, enabling rapid generation of representative geological volumes for flow simulation.
  • Geostatistical inversion — Laloy et al. (2018) used a spatial GAN as a geological prior for inverse problems, generating training-image-consistent geological models during inversion. This allows the inversion to stay within geologically realistic model space.
  • Stochastic seismic inversion — Mosser et al. (2020) combined GANs with seismic inversion, using the generator as a geological prior to produce multiple subsurface models consistent with both seismic data and geological knowledge.
  • Geological facies modeling — GANs have been used to generate realistic 2D and 3D geological facies models that honor well data and geological constraints, as an alternative to traditional geostatistical simulation.