This chapter on generative adversarial networks (GANs) sits inside a short block of chapters on generative models. Earlier chapters focused on architectures used for prediction, classification, or representation learning. In this part of the book, autoencoders, GANs, diffusion models, and flow matching are grouped together because they all learn data distributions and produce new samples, even though their training mechanisms differ substantially.
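A GAN consists of two neural networks trained simultaneously in competition: a generator \(G\), which maps latent noise \(\mathbf{z}\) to synthetic samples, and a discriminator \(D\), which scores whether a sample comes from the real data or from the generator.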
Training alternates between improving the discriminator (so it better distinguishes real from fake) and improving the generator (so its fakes become more convincing). The result is a generator that can produce realistic, novel samples from the data distribution.
The original GAN training objective is a minimax game:
\[ \min_G \max_D \; \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\!\bigl[\log D(\mathbf{x})\bigr] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}\!\bigl[\log\bigl(1 - D(G(\mathbf{z}))\bigr)\bigr] \]
At convergence, the generator produces samples that the discriminator cannot distinguish from real data. In practice, GANs can be difficult to train: the generator and discriminator must be balanced, the log-loss saturates when the discriminator gets ahead, and training can be unstable.
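To make the objective concrete, the sketch below evaluates the value function on batches of discriminator outputs. The function name `gan_value` and the toy scores are illustrative assumptions, not part of the chapter's code. When the discriminator outputs \(0.5\) everywhere, meaning it cannot tell real from fake, the value equals \(-\log 4\); a confident and correct discriminator pushes the value towards its maximum of \(0\).

```julia
using Statistics

# Monte Carlo estimate of the GAN value function from discriminator outputs.
# d_real: D(x) on real samples; d_fake: D(G(z)) on generated samples, both in (0, 1).
gan_value(d_real, d_fake) = mean(log.(d_real)) + mean(log.(1 .- d_fake))

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
v_equilibrium = gan_value(fill(0.5, 64), fill(0.5, 64))

# A confident, correct discriminator: D(x) near 1 and D(G(z)) near 0.
v_confident = gan_value(fill(0.99, 64), fill(0.01, 64))

println("value at D = 0.5:       ", v_equilibrium)   # equals -log(4)
println("value with confident D: ", v_confident)
```

The generator tries to lower this value and the discriminator tries to raise it, which is why a single scalar loss curve is hard to interpret during GAN training.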
The least-squares GAN (Mao et al., 2017) replaces the log loss with mean-squared error:
\[ \begin{aligned} \mathcal{L}_D &= \tfrac{1}{2}\,\mathbb{E}\bigl[(D(\mathbf{x}) - 1)^2\bigr] + \tfrac{1}{2}\,\mathbb{E}\bigl[D(G(\mathbf{z}))^2\bigr], \\ \mathcal{L}_G &= \tfrac{1}{2}\,\mathbb{E}\bigl[(D(G(\mathbf{z})) - 1)^2\bigr]. \end{aligned} \]
Conceptually the discriminator is now a regressor that scores realism rather than a classifier. The gradient does not vanish when fakes are easy to spot, so the generator keeps getting useful signal. For small chapter-sized problems LSGAN is almost always the right starting point — it is one of the simplest reliable GAN variants.
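The vanishing-gradient argument can be checked at the scalar level. The sketch below is an illustration under stated assumptions: `grad_log_loss` differentiates the original generator loss \(\log(1 - D)\) with \(D = \sigma(a)\) with respect to the discriminator logit \(a\), and `grad_ls_loss` differentiates the LSGAN generator loss with respect to the raw score \(s\). Both helper names are hypothetical.

```julia
# Logistic function used for the original GAN discriminator output.
sigmoid(a) = 1 / (1 + exp(-a))

# Original (saturating) generator loss log(1 - sigmoid(a)):
# d/da log(1 - sigmoid(a)) = -sigmoid(a).
grad_log_loss(a) = -sigmoid(a)

# LSGAN generator loss 0.5 * (s - 1)^2 on a raw score s: d/ds = s - 1.
grad_ls_loss(s) = s - 1

# The discriminator confidently rejects a fake: logit a = -8, LSGAN score s = 0.
println("log-loss gradient: ", grad_log_loss(-8.0))
println("LSGAN gradient:    ", grad_ls_loss(0.0))
```

With an easily spotted fake, the log-loss gradient is close to zero while the least-squares gradient has magnitude one, so the generator still receives a usable signal.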
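The original GAN used feedforward networks for both \(G\) and \(D\). The DCGAN (Radford et al., 2016) established best practices for convolutional GANs on images: strided convolutions in place of pooling layers, batch normalisation in both networks, no fully connected hidden layers, ReLU activations in the generator with a tanh output, and LeakyReLU activations in the discriminator.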
For 1D signals such as well logs or seismic traces, the same recipe works with 1D convolutions or — for short sequences — with plain dense layers, which is what we use below.
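For comparison with the dense networks used in this chapter, here is a minimal sketch of what a 1D-convolutional discriminator could look like in Lux. The layer widths, kernel size, and the name `conv_discriminator` are assumptions chosen for illustration; the chapter's example does not use this network.

```julia
using Lux, Random

rng = Xoshiro(0)

# Hypothetical 1D-convolutional discriminator for 128-sample traces.
# Lux convolutions expect input shaped (length, channels, batch).
conv_discriminator = Chain(
    Conv((5,), 1 => 16, leakyrelu; stride = 2, pad = 2),   # 128 -> 64 samples
    Conv((5,), 16 => 32, leakyrelu; stride = 2, pad = 2),  # 64 -> 32 samples
    FlattenLayer(),                                         # (32 * 32, batch)
    Dense(32 * 32 => 1),                                    # raw realism score
)

ps, st = Lux.setup(rng, conv_discriminator)
x = rand(rng, Float32, 128, 1, 4)            # four random stand-in traces
scores, _ = conv_discriminator(x, ps, st)
println(size(scores))
```

The strided convolutions halve the trace length at each layer, following the DCGAN idea of downsampling with strides rather than pooling.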
We train a GAN to generate synthetic 128-sample well logs that look like noisy gamma-ray traces. Each real log is a sequence of layered facies with sharp transitions — a low (clean-sand-like) value of about \(0.25\) in one facies, a higher (shale-like) value of about \(0.75\) in the other, with random segment lengths and within-segment noise. This is a useful GAN target because the real distribution has clear structure that is not captured by simple per-sample statistics: the mean of any well log will sit somewhere near \(0.5\), but a random number near \(0.5\) is not a plausible log. A successful generator must reproduce the blocky, sharp-transition character of the data.
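The claim that per-sample statistics miss the structure can be illustrated with a toy signal. In the sketch below, which uses only the standard library, a blocky trace and a shuffled copy of it share the same mean and standard deviation, yet only the blocky one has strong sample-to-sample correlation. The signal and the helper `lag1` are illustrative assumptions and are separate from the data generator defined next.

```julia
using Random, Statistics

rng = Xoshiro(1)

# Toy blocky signal: alternating segments of 0.25 and 0.75, 16 samples each.
blocky = repeat(vcat(fill(0.25, 16), fill(0.75, 16)), 4)   # 128 samples
shuffled = shuffle(rng, blocky)                             # same values, random order

# Lag-1 autocorrelation: high for blocky signals, near zero after shuffling.
lag1(x) = cor(x[1:end-1], x[2:end])

println("mean: ", mean(blocky), " vs ", mean(shuffled))
println("std:  ", std(blocky), " vs ", std(shuffled))
println("lag-1 autocorrelation: ", lag1(blocky), " vs ", lag1(shuffled))
```

A generator that only matched the mean and spread could produce something like the shuffled trace, which no geologist would accept as a log.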
using Lux, Random, Optimisers, Zygote, Statistics, Printf, CairoMakie
rng = Xoshiro(42)
n_depth = 128
# Synthetic gamma-ray-like well log: blocky facies + small within-segment noise.
function make_well_log(rng, n = n_depth)
    log_curve = zeros(Float32, n)
    pos = 1
    while pos <= n
        facies_value = rand(rng) < 0.5f0 ? 0.25f0 : 0.75f0
        seg_len = rand(rng, 8:25)
        seg_end = min(pos + seg_len - 1, n)
        for i in pos:seg_end
            log_curve[i] = facies_value + 0.04f0 * randn(rng, Float32)
        end
        pos = seg_end + 1
    end
    # Tiny moving-average smooth so transitions are not perfectly vertical.
    smoothed = copy(log_curve)
    for i in 2:n-1
        smoothed[i] = (log_curve[i-1] + log_curve[i] + log_curve[i+1]) / 3f0
    end
    return clamp.(smoothed, 0f0, 1f0)
end

n_real = 1024
real_data = zeros(Float32, n_depth, n_real)
for i in 1:n_real
    real_data[:, i] = make_well_log(rng)
end

# Generator: latent noise → 128-sample log in [0, 1].
latent_dim = 8
generator = Chain(
    Dense(latent_dim => 64, leakyrelu),
    Dense(64 => 128, leakyrelu),
    Dense(128 => n_depth, sigmoid),
)

# Discriminator: 128-sample log → realism score (no output activation for LSGAN).
discriminator = Chain(
    Dense(n_depth => 64, leakyrelu),
    Dense(64 => 32, leakyrelu),
    Dense(32 => 1),
)

ps_g, st_g = Lux.setup(rng, generator)
ps_d, st_d = Lux.setup(rng, discriminator)
function sample_batch(rng, data, batch_size)
    data[:, rand(rng, 1:size(data, 2), batch_size)]
end

function train_lsgan(ps_g, st_g, ps_d, st_d;
                     epochs = 1500, batch_size = 128,
                     lr = 2f-4)
    opt_state_g = Optimisers.setup(Adam(lr), ps_g)
    opt_state_d = Optimisers.setup(Adam(lr), ps_d)
    for epoch in 1:epochs
        #----- Discriminator step -----
        x_real = sample_batch(rng, real_data, batch_size)
        z = randn(rng, Float32, latent_dim, batch_size)
        (d_loss, st_d_new), d_grads = Zygote.withgradient(ps_d) do pd
            fake, _ = generator(z, ps_g, st_g)
            real_score, st_d_a = discriminator(x_real, pd, st_d)
            fake_score, st_d_b = discriminator(fake, pd, st_d_a)
            loss = 0.5f0 * (mean((real_score .- 1f0) .^ 2) +
                            mean(fake_score .^ 2))
            return loss, st_d_b
        end
        opt_state_d, ps_d = Optimisers.update(opt_state_d, ps_d, d_grads[1])
        st_d = st_d_new

        #----- Generator step -----
        z = randn(rng, Float32, latent_dim, batch_size)
        (g_loss, st_g_new), g_grads = Zygote.withgradient(ps_g) do pg
            fake, st_g_a = generator(z, pg, st_g)
            fake_score, _ = discriminator(fake, ps_d, st_d)
            loss = 0.5f0 * mean((fake_score .- 1f0) .^ 2)
            return loss, st_g_a
        end
        opt_state_g, ps_g = Optimisers.update(opt_state_g, ps_g, g_grads[1])
        st_g = st_g_new

        if epoch == 1 || epoch % 300 == 0
            @printf "Epoch %4d D loss = %.4f G loss = %.4f\n" epoch d_loss g_loss
        end
    end
    return ps_g, st_g, ps_d, st_d
end

ps_g, st_g, ps_d, st_d = train_lsgan(ps_g, st_g, ps_d, st_d)

Epoch 1 D loss = 0.4374 G loss = 0.4185
Epoch 300 D loss = 0.1149 G loss = 0.2558
Epoch 600 D loss = 0.1120 G loss = 0.2898
Epoch 900 D loss = 0.0740 G loss = 0.4282
Epoch 1200 D loss = 0.0390 G loss = 0.4430
Epoch 1500 D loss = 0.0475 G loss = 0.3697
# Sanity-check distributions at the per-sample level.
z_test = randn(rng, Float32, latent_dim, 256)
fake_data, _ = generator(z_test, ps_g, st_g)
@printf "Real mean / std: %.3f / %.3f\n" mean(real_data) std(real_data)
@printf "Fake mean / std: %.3f / %.3f\n" mean(fake_data) std(fake_data)
# The real test: visual similarity between real and generated logs.
fig = Figure(size = (760, 460))
Label(fig[0, 1], "Real well logs", fontsize = 13)
Label(fig[0, 2], "GAN-generated logs", fontsize = 13)
depth = collect(1:n_depth)
n_show = 4
for k in 1:n_show
    ax_r = Axis(fig[k, 1], xlabel = (k == n_show ? "depth sample" : ""),
                ylabel = "GR", yticks = ([0, 0.5, 1], ["0", "0.5", "1"]))
    lines!(ax_r, depth, real_data[:, k], color = :black)
    ylims!(ax_r, 0, 1)
    ax_f = Axis(fig[k, 2], xlabel = (k == n_show ? "depth sample" : ""),
                ylabel = "", yticks = ([0, 0.5, 1], ["", "", ""]))
    lines!(ax_f, depth, fake_data[:, k], color = :coral)
    ylims!(ax_f, 0, 1)
end
fig

Real mean / std: 0.498 / 0.245
Fake mean / std: 0.466 / 0.230
A successful LSGAN here is recognisable on two grounds:
The means and standard deviations are reported only as a sanity check; as the previous chapters warn, scalar summaries are not sufficient to declare GAN training successful.
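One structure-aware check that could complement the visual comparison is the distribution of facies run lengths. The helper below is a sketch, not part of the chapter's code: `run_lengths` and its `threshold` keyword are hypothetical names, and thresholding at \(0.5\) assumes the two facies sit near \(0.25\) and \(0.75\) as in the synthetic data.

```julia
# Lengths of consecutive runs above or below a threshold in one log.
function run_lengths(log_curve; threshold = 0.5)
    labels = log_curve .> threshold
    lengths = Int[]
    current = 1
    for i in 2:length(labels)
        if labels[i] == labels[i-1]
            current += 1            # still inside the same facies block
        else
            push!(lengths, current) # block ended, start counting the next one
            current = 1
        end
    end
    push!(lengths, current)         # close the final block
    return lengths
end

# Toy check: three facies blocks of 10, 20 and 8 samples.
toy_log = vcat(fill(0.25, 10), fill(0.75, 20), fill(0.25, 8))
println(run_lengths(toy_log))   # [10, 20, 8]
```

Applied column by column to `real_data` and `fake_data`, the two sets of run lengths can be compared as histograms. In the real logs, segments are drawn with lengths of 8 to 25 samples and neighbouring segments of the same facies merge, so a generator that produces many very short runs is not reproducing the blocky character.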
GANs excel at generating realistic samples from a learned distribution. They are particularly useful when:
GANs can be harder to train than other generative models (VAEs, diffusion models). Mode collapse (the generator produces only a few types of output) and training instability are common challenges; switching from the original log loss to LSGAN, WGAN, or WGAN-GP almost always helps when problems arise.
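Mode collapse can be screened for with a simple diversity measure on a batch of generated samples. The sketch below computes the mean pairwise Euclidean distance between samples; the function name `mean_pairwise_distance` and the random stand-in batches are assumptions for illustration, and a low value is only a warning sign, not a complete diagnosis.

```julia
using Random, Statistics

# Mean Euclidean distance between all pairs of columns (one sample per column).
function mean_pairwise_distance(samples)
    n = size(samples, 2)
    dists = [sqrt(sum(abs2, samples[:, i] .- samples[:, j]))
             for i in 1:n-1 for j in i+1:n]
    return mean(dists)
end

rng = Xoshiro(2)
diverse = rand(rng, Float32, 128, 16)                  # stand-in for varied samples
collapsed = repeat(rand(rng, Float32, 128), 1, 16)     # every column identical

println("diverse:   ", mean_pairwise_distance(diverse))
println("collapsed: ", mean_pairwise_distance(collapsed))   # 0.0
```

Comparing this number for a generated batch against the same number for a batch of real logs gives a rough indication of whether the generator covers the variety in the data.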