Skip to content

s13 图像生成 — demo.py 代码详解

Download demo.py

运行方式

bash
cd s13_image_generation/code
python demo.py

代码逐段详解

第1步:导入库与全局配置 —— 每个库是做什么的

python
import torch
import torch.nn as nn            # GAN 的 Linear+BN、VAE 的 Encoder/Decoder
import torch.nn.functional as F  # BCE、MSE、relu、sigmoid 等损失和激活函数
import torch.optim as optim      # Adam 优化器(GAN 和 VAE 都用 Adam)
import torchvision               # MNIST 数据集
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
在此 demo 中的角色
torch.nn构建 Generator、Discriminator、VAE Encoder/Decoder
torch.optimAdam 优化器(GAN 用 lr=0.0002, betas=(0.5, 0.999)
torchvisionMNIST 数据集(28x28 手写数字)

为什么 MNIST? 手写数字生成是图像生成模型的"Hello World"。28x28 灰度图维度低、结构简单,可以让 GAN 和 VAE 在 CPU 上也能快速训练出可见效果。同时 MNIST 的 10 个类别(0-9)提供了清晰的聚类结构,便于分析 VAE 潜空间的 t-SNE 分布。

数据归一化策略

python
transforms.Normalize((0.5,), (0.5,))  # [0,1] → [-1,1]
模型输出激活输出范围目标范围匹配?
GAN GeneratorTanh[1,1][1,1]完美匹配
VAE DecoderSigmoid[0,1][1,1] (输入)需要转换!

VAE 的特别注意:VAE 解码器用 Sigmoid 输出 [0,1],但 MNIST 数据被归一化到 [1,1]。在计算 VAE 损失时,代码将目标反归一化回 [0,1] 再用 BCE 损失——详见第4步。

第2步:GAN —— 生成对抗网络

GAN 由两个网络组成:生成器 G判别器 D,它们在一个极小极大博弈中对抗训练。

GAN 的数学形式

minGmaxDV(D,G)=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]
  • 判别器 D 想最大化 V:真图为 1,假图为 0
  • 生成器 G 想最小化 V:让 D(G(z)) 接近 1(以假乱真)

2.1 生成器 Generator —— 从噪声到图像

python
class Generator(nn.Module):
    def __init__(self, latent_dim=128):
        self.model = nn.Sequential(
            # Block 1: 128 → 256
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),          # BN 稳定训练
            nn.ReLU(inplace=True),
            # Block 2: 256 → 512
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            # Block 3: 512 → 1024
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            # Block 4: 1024 → 784 (28×28)
            nn.Linear(1024, 784),
            nn.Tanh(),  # 输出 [-1, 1]
        )

为什么用全连接而不是转置卷积? 对于 MNIST 这种 28x28 的小图像,全连接层足以生成合理的结果(参数约 1M)。转置卷积(DCGAN)在大图像上效果更好,但全连接版本更简洁,适合教学。反卷积版本通常能生成更平滑的纹理。

为什么 Generator 用 BatchNorm 而 Discriminator 不用? BatchNorm 在生成器中非常关键——它防止生成器的输出分布漂移,确保各层的激活保持在合理范围。判别器不需要 BN(甚至 BN 可能有害),因为判别器的任务是对单张图像做判断,BN 引入的 mini-batch 统计依赖会干扰逐样本判断。

为什么最后一层用 Tanh? Tanh 输出范围 [1,1],与 MNIST 的归一化范围一致。如果用 Sigmoid(输出 [0,1]),需要改变数据预处理,且梯度在两端更平缓。

2.2 判别器 Discriminator —— 真假判断

python
class Discriminator(nn.Module):
    def __init__(self):
        self.model = nn.Sequential(
            # Block 1: 784 → 512
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),        # LeakyReLU 防止死神经元
            # Block 2: 512 → 256
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            # Block 3: 256 → 1
            nn.Linear(256, 1),
            nn.Sigmoid(),  # 输出概率 [0, 1]
        )

为什么判别器用 LeakyReLU 而不是 ReLU? LeakyReLU 在负半轴保留了微小斜率(α=0.2),避免了 ReLU 的"死神经元"问题——当 ReLU 输入恒为负时,梯度为零,神经元永远无法恢复。在 GAN 训练中,判别器的梯度质量直接影响生成器的学习,死神经元是 GAN 训练失败的常见原因之一。

LeakyReLU(x)={xif x>00.2xif x0

2.3 GAN 训练循环 —— 交替优化

python
for epoch in range(n_epochs):
    for imgs, _ in dataloader:
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)

        real_labels = torch.ones(batch_size, 1, device=device)   # 真实=1
        fake_labels = torch.zeros(batch_size, 1, device=device)   # 虚假=0

        # ===== 1. 训练判别器 D =====
        optimizer_D.zero_grad()

        # D(真实图像) → 应该接近 1
        real_pred = discriminator(real_imgs)
        d_real_loss = adversarial_loss(real_pred, real_labels)

        # D(生成图像) → 应该接近 0
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)
        fake_pred = discriminator(fake_imgs.detach())  # ⚠️ detach() 防止梯度传回 G
        d_fake_loss = adversarial_loss(fake_pred, fake_labels)

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ===== 2. 训练生成器 G =====
        optimizer_G.zero_grad()

        z = torch.randn(batch_size, latent_dim, device=device)
        gen_imgs = generator(z)
        gen_pred = discriminator(gen_imgs)  # 注意:这次不用 detach()
        g_loss = adversarial_loss(gen_pred, real_labels)  # G 的目标:让 D 判为真

        g_loss.backward()
        optimizer_G.step()

关键代码细节

  1. fake_imgs.detach():训练判别器时,生成的假图必须从计算图中切断。否则梯度会通过 fake_imgs 流回生成器,导致我们在训练 D 的同时误改了 G 的参数。

  2. 生成器的损失目标LG=BCE(D(G(z)),1)。注意这里用的目标标签是 1(真实)——生成器想"骗过"判别器,让判别器认为生成图像是真的。

  3. 交替训练:每轮先训 D 再训 G。理论上应该是 D 训 k 步 G 训 1 步(原始 GAN 论文的推荐),但实践中 1:1 交替训练通常效果不错。

GAN 训练的直观理解

角色目标策略
判别器 D区分真假给真图高分,给假图低分
生成器 G以假乱真让 D 给生成图高分
理想平衡点D(x)=12D 完全无法区分,G 生成完美

第3步:VAE —— 变分自编码器

VAE 的学习目标与 GAN 完全不同:它不是对抗博弈,而是最大化数据的证据下界(ELBO)

VAE 的数学核心

LVAE=Eqϕ(z|x)[logpθ(x|z)]重构损失(越大越好)DKL(qϕ(z|x)p(z))KL 散度(越小越好)

3.1 VAE 网络结构

python
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        # --- 编码器:x → μ 和 log(σ²) ---
        self.encoder = nn.Sequential(
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256, latent_dim)       # 均值 μ
        self.fc_logvar = nn.Linear(256, latent_dim)   # log(σ²) —— 为什么是 log?

        # --- 解码器:z → 重建 x̂ ---
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Sigmoid(),        # Sigmoid 输出 [0, 1]
        )

为什么编码器输出 logvar 而不是 σ2 直接? 方差 σ2 必须为正,如果网络直接输出 σ2,我们需要加一个激活函数(如 softplus)来保证正值。输出 log(σ2) 不需要任何约束——它可以取任意实数值,然后通过 exp 恢复为正的 σ2

σ2=exp(log(σ2)),σ=exp(0.5log(σ2))

这在数学上比约束式输出更干净,训练更稳定。

3.2 重参数化技巧 —— VAE 的关键创新

如果直接从 N(μ,σ2) 采样 z,采样操作不可微,梯度无法从 decoder 传回 encoder。重参数化将采样分解为确定性部分 + 随机噪声:

z=μ+σε,εN(0,I)
python
def reparameterize(self, mu, logvar):
    std = torch.exp(0.5 * logvar)    # σ = e^(0.5 * log(σ²))
    eps = torch.randn_like(std)      # ε ~ N(0, 1)
    z = mu + std * eps               # z = μ + σ ⊙ ε
    return z

梯度流分析

zμ=1,zσ=ε

μσ 直接出现在 z 的计算中,梯度可以毫无障碍地传回 encoder。ε 是一个与模型参数无关的随机变量,反向传播时被当作常数。

3.3 VAE 损失函数

python
def vae_loss(x_recon, x, mu, logvar):
    # ---- 重构损失:二元交叉熵 ----
    x_target = (x.view(x.size(0), -1) + 1) / 2.0  # [-1,1] → [0,1]
    recon_loss = F.binary_cross_entropy(x_recon, x_target, reduction='sum') / x.size(0)

    # ---- KL 散度(解析解)----
    # 对于 q = N(μ, σ²), p = N(0, 1):
    # D_KL(q || p) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)

    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss

KL 散度解析公式的推导

对于两个 d 维高斯分布 q=N(μ,Σ)p=N(0,I)

DKL(qp)=12[tr(Σ)+μμdlogdet(Σ)]=12j=1d[σj2+μj21log(σj2)]=12j=1d[1+log(σj2)μj2σj2]

KL 散度的作用

  • 如果 μ=0σ2=1(后验 = 先验),KL = 0
  • 如果 μ 远离 0 或 σ2 远离 1,KL >> 0
  • KL 正则化迫使潜空间保持结构化和平滑——两个相似的 x 映射到相近的 z

3.4 VAE 训练循环

python
for imgs, _ in dataloader:
    imgs = imgs.to(device)
    optimizer.zero_grad()

    x_recon, mu, logvar = model(imgs)             # 前向传播
    total_loss, recon_loss, kl_loss = vae_loss(   # 计算损失
        x_recon, imgs, mu, logvar
    )
    total_loss.backward()                          # 梯度可以穿过重参数化
    optimizer.step()

VAE 训练比 GAN 稳定得多——只有一个优化器,一个损失函数,没有博弈过程。

第4步:可视化 —— 四种对比图

GAN 生成样本:从随机噪声生成 16 张数字图像,排列成 4x4 网格。观察:

  • 生成质量:数字是否清晰可辨
  • 多样性:是否覆盖了 0-9 全部数字(还是模式坍塌只生成少数几个)

VAE 重建对比:上方原始图像 vs 下方 VAE 重建。典型情况是重建比原始模糊——这是 VAE 的"均值化效应":逐像素 BCE 损失倾向于预测像素值的条件期望,导致模糊。

VAE 潜空间 t-SNE:将测试集图像通过 VAE 编码器得到的 μ 向量做 t-SNE 投影到 2D。如果潜空间结构良好,不同数字类别应该在投影空间中形成清晰的聚类。

训练曲线对比:三张子图:

  1. GAN 的 G Loss 和 D Loss 曲线 —— 理想状态是两个 loss 在合理范围内波动,不收敛也不发散
  2. VAE 的总损失 / 重构损失 / KL 散度 —— 重构损失下降、KL 散度上升后稳定
  3. 文字对比总结 —— GAN vs VAE 的特性

关键概念速查表

概念公式代码对应
GAN 极小极大minGmaxDV(D,G)交替训练 D 和 G
D 的损失logD(x)log(1D(G(z)))BCE(d_real_pred, 1) + BCE(d_fake_pred, 0)
G 的损失logD(G(z))BCE(gen_pred, 1)
VAE ELBOEq[logp(x|z)]DKL(q|p)recon_loss + kl_loss
重参数化z=μ+σεmu + std * eps
KL 散度(高斯)12[1+logσ2μ2σ2]-0.5 * torch.sum(1 + logvar - mu^2 - exp(logvar))
模式坍塌G 只生成少数模式生成样本多样性低
LeakyReLUmax(0.2x,x)nn.LeakyReLU(0.2)

完整代码

py
# -*- coding: utf-8 -*-
"""
s13 图像生成 demo:从零实现 GAN 和 VAE 用于 MNIST 数字生成
===========================================================
使用 PyTorch 实现简单的 GAN 和 VAE,在 MNIST 上训练,
对比两种生成方法的效果。

运行方式:python demo.py(从 s13_image_generation/code/ 目录运行)
依赖:torch, torchvision, matplotlib, numpy
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# GPU 自动检测
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"使用设备: {DEVICE}")
if DEVICE.type == 'cpu':
    print("(未检测到 GPU,使用 CPU 运行。如有 GPU,请安装 CUDA 版 PyTorch 以获得加速)")

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
matplotlib.rcParams['axes.unicode_minus'] = False
import numpy as np
import os
import time

# 图片保存目录:固定为本章节的 images/ 目录(相对于本脚本的 ../images/)
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_IMAGES_DIR = os.path.join(_SCRIPT_DIR, '..', 'images')
os.makedirs(_IMAGES_DIR, exist_ok=True)


# ============================================================
# 第 0 部分:通用工具
# ============================================================

def load_mnist(batch_size: int = 128) -> DataLoader:
    """
    加载 MNIST 数据集

    参数:
        batch_size: 批大小
    返回:
        train_loader: 训练数据加载器(仅含图像,不需要标签)
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # MNIST 图像是 [0,1],GAN 的 tanh 输出是 [-1,1]
        # 因此将图像归一化到 [-1, 1]
        transforms.Normalize((0.5,), (0.5,)),
    ])

    try:
        train_set = torchvision.datasets.MNIST(
            root='../data', train=True, download=True,
            transform=transform
        )
    except Exception as e:
        print(f"[警告] MNIST 下载失败 ({e}),使用合成数据")
        # 回退:合成 28x28 单通道图像
        from torch.utils.data import TensorDataset
        np.random.seed(42)
        synth_X = torch.rand(10000, 1, 28, 28) * 2 - 1  # [-1, 1]
        synth_y = torch.randint(0, 10, (10000,))
        train_set = TensorDataset(synth_X, synth_y)

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=0, drop_last=True
    )

    return train_loader


def to_image(tensor: torch.Tensor) -> np.ndarray:
    """
    将张量转换为可显示的图像数组

    参数:
        tensor: 形状 (C, H, W) 或 (N, C, H, W),值范围 [-1, 1]
    返回:
        img: numpy 数组,值范围 [0, 1]
    """
    # 反归一化: [-1, 1] → [0, 1]
    img = (tensor.detach().cpu().numpy() + 1) / 2.0
    img = np.clip(img, 0, 1)

    if img.ndim == 4 and img.shape[1] == 1:
        img = img[:, 0, :, :]  # (N, 1, H, W) → (N, H, W)
    elif img.ndim == 3 and img.shape[0] == 1:
        img = img[0]  # (1, H, W) → (H, W)
    return img


# ============================================================
# 第 1 部分:GAN —— 生成对抗网络
# ============================================================

class Generator(nn.Module):
    """
    GAN 生成器

    将随机噪声 z 映射为一张 28×28 的 MNIST 图像。

    架构: FC(128→256)→BN→ReLU → FC(256→512)→BN→ReLU → FC(512→784)→Tanh
    最后 reshape 为 (1, 28, 28),Tanh 输出范围 [-1, 1] 与归一化的图像匹配。
    """

    def __init__(self, latent_dim: int = 128):
        """
        初始化生成器

        参数:
            latent_dim: 输入噪声 z 的维度
        """
        super(Generator, self).__init__()

        self.latent_dim = latent_dim

        # ---------- 构建全连接网络:逐步放大维度 ----------
        self.model = nn.Sequential(
            # Block 1: latent_dim → 256
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),          # BN 稳定训练,加速收敛
            nn.ReLU(inplace=True),
            # Block 2: 256 → 512
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            # Block 3: 512 → 1024
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            # Block 4: 1024 → 784 (MNIST 像素数)
            nn.Linear(1024, 784),
            nn.Tanh(),  # 输出范围 [-1, 1],与 MNIST 归一化一致
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        前向传播:噪声 → 图像

        参数:
            z: 随机噪声,形状 (N, latent_dim)
        返回:
            img: 生成的图像,形状 (N, 1, 28, 28),值范围 [-1, 1]
        """
        img = self.model(z)         # (N, 784)
        img = img.view(-1, 1, 28, 28)  # reshape 为图像形状
        return img


class Discriminator(nn.Module):
    """
    GAN 判别器

    判断输入图像是真实图像(来自 MNIST)还是生成器伪造的假图像。

    架构: FC(784→512)→LeakyReLU → FC(512→256)→LeakyReLU → FC(256→1)→Sigmoid
    输出一个 [0, 1] 的标量,表示图像为真的概率。
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # Block 1: 784 → 512
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),  # 使用 LeakyReLU 防止 dead neurons
            # Block 2: 512 → 256
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 3: 256 → 1
            nn.Linear(256, 1),
            nn.Sigmoid(),  # 输出概率 [0, 1]
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        前向传播:图像 → 真实性概率

        参数:
            img: 输入图像,形状 (N, 1, 28, 28) 或 (N, 784)
        返回:
            validity: 图像为真的概率,形状 (N, 1)
        """
        # 展平图像: (N, 1, 28, 28) → (N, 784)
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


def train_gan(dataloader: DataLoader, device: torch.device,
              n_epochs: int = 50, latent_dim: int = 128) -> dict:
    """
    训练 GAN

    参数:
        dataloader: MNIST 数据加载器
        device: 计算设备
        n_epochs: 训练轮数
        latent_dim: 潜变量维度

    返回:
        history: 包含每 epoch 的 G loss 和 D loss
    """
    print(f"\n  {'='*50}")
    print(f"  训练 GAN (epochs={n_epochs})")
    print(f"  {'='*50}")

    # ---------- 初始化模型 ----------
    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    print(f"  Generator 参数: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"  Discriminator 参数: {sum(p.numel() for p in discriminator.parameters()):,}")

    # ---------- 损失函数和优化器 ----------
    adversarial_loss = nn.BCELoss()  # 二元交叉熵损失

    # 两个独立的优化器(交替训练)
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # ---------- 训练循环 ----------
    history = {"g_loss": [], "d_loss": []}
    fixed_noise = torch.randn(16, latent_dim, device=device)  # 用于定期可视化

    for epoch in range(1, n_epochs + 1):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        n_batches = 0

        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)

            # 创建标签(真实=1,假=0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ========== 训练判别器 D ==========
            optimizer_D.zero_grad()

            # 真实图像的损失:D(real_img) → 1
            real_pred = discriminator(real_imgs)
            d_real_loss = adversarial_loss(real_pred, real_labels)

            # 假图像的损失:D(G(z)) → 0
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)  # 生成假图像
            fake_pred = discriminator(fake_imgs.detach())  # detach() 防止梯度传回 G
            d_fake_loss = adversarial_loss(fake_pred, fake_labels)

            # 判别器的总损失 = 真实损失 + 假损失
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # ========== 训练生成器 G ==========
            optimizer_G.zero_grad()

            # 生成器的目标:让判别器认为假图像是真的 D(G(z)) → 1
            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)
            gen_pred = discriminator(gen_imgs)  # 注意:这里不用 detach()
            g_loss = adversarial_loss(gen_pred, real_labels)  # 目标是"真实"

            g_loss.backward()
            optimizer_G.step()

            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            n_batches += 1

        # 记录 epoch 平均损失
        avg_g_loss = epoch_g_loss / n_batches
        avg_d_loss = epoch_d_loss / n_batches
        history["g_loss"].append(avg_g_loss)
        history["d_loss"].append(avg_d_loss)

        if epoch % 5 == 0 or epoch == 1:
            print(f"  Epoch {epoch:3d}/{n_epochs} | "
                  f"D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f}")

    return history, generator, discriminator


# ============================================================
# 第 2 部分:VAE —— 变分自编码器
# ============================================================

class VAE(nn.Module):
    """
    变分自编码器(VAE)

    包含:
    - 编码器: 输入 x → μ 和 log(σ²)
    - 重参数化: z = μ + σ ⊙ ε, ε ~ N(0, 1)
    - 解码器: z → 重建图像 x̂

    损失 = 重构损失 (MSE/BCE) + KL 散度 (D_KL(q(z|x) || p(z)))
    """

    def __init__(self, latent_dim: int = 20):
        """
        初始化 VAE

        参数:
            latent_dim: 潜变量 z 的维度
        """
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # ---------- 编码器: x (784) → μ (latent_dim), logvar (latent_dim) ----------
        self.encoder = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
        )
        # μ 和 log(σ²) 分别由两个独立的全连接层预测
        self.fc_mu = nn.Linear(256, latent_dim)       # 均值 μ
        self.fc_logvar = nn.Linear(256, latent_dim)   # log(σ²),用 log 保证 σ² > 0

        # ---------- 解码器: z (latent_dim) → 重建 x̂ (784) ----------
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 784),
            nn.Sigmoid(),  # 输出 [0, 1],对应归一化的像素值
        )

    def encode(self, x: torch.Tensor) -> tuple:
        """
        编码:输入图像 → μ 和 log(σ²)

        参数:
            x: 输入图像展平,形状 (N, 784)
        返回:
            (mu, logvar): 均值和 log 方差,形状均为 (N, latent_dim)
        """
        h = self.encoder(x)               # 共享的特征提取
        mu = self.fc_mu(h)                # 预测均值
        logvar = self.fc_logvar(h)        # 预测 log(σ²)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        重参数化技巧: z = μ + σ ⊙ ε

        这是 VAE 的核心创新。直接从 N(μ, σ²) 采样 z 是不可微的,
        通过将随机性"外包"给 ε ~ N(0,1),使得 z 对 μ 和 σ 可微。

        参数:
            mu: 均值,形状 (N, latent_dim)
            logvar: log(σ²),形状 (N, latent_dim)
        返回:
            z: 采样后的潜变量,形状 (N, latent_dim)
        """
        std = torch.exp(0.5 * logvar)     # σ = exp(0.5 * log(σ²))
        eps = torch.randn_like(std)       # ε ~ N(0, 1)
        z = mu + std * eps                # 重参数化: z = μ + σ ⊙ ε
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        解码:潜变量 z → 重建图像 x̂

        参数:
            z: 潜变量,形状 (N, latent_dim)
        返回:
            x_recon: 重建图像,形状 (N, 784),值范围 [0, 1]
        """
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        VAE 完整前向传播

        参数:
            x: 输入图像,形状 (N, 1, 28, 28)
        返回:
            (x_recon, mu, logvar):
                - x_recon: 重建图像,形状 (N, 784)
                - mu: 编码均值
                - logvar: 编码 log 方差
        """
        # 展平图像: (N, 1, 28, 28) → (N, 784)
        x_flat = x.view(x.size(0), -1)

        # 编码 → μ, log(σ²)
        mu, logvar = self.encode(x_flat)

        # 重参数化采样 z
        z = self.reparameterize(mu, logvar)

        # 解码 → 重建
        x_recon = self.decode(z)

        return x_recon, mu, logvar


def vae_loss(x_recon: torch.Tensor, x: torch.Tensor,
             mu: torch.Tensor, logvar: torch.Tensor) -> tuple:
    """
    计算 VAE 的损失函数

    L_VAE = 重构损失 + KL 散度

    KL 散度的解析形式(高斯分布):
        D_KL( N(μ,σ²) || N(0,1) ) = -0.5 * sum(1 + log(σ²) - μ² - σ²)

    参数:
        x_recon: 重建图像,形状 (N, 784)
        x: 原始图像展平,形状 (N, 784)
        mu: 编码均值,形状 (N, latent_dim)
        logvar: 编码 log 方差,形状 (N, latent_dim)

    返回:
        (total_loss, recon_loss, kl_loss)
    """
    # ---------- 重构损失:二元交叉熵(适用于 [0,1] 范围的图像)----------
    # MNIST 图像经过 Normalize((0.5,),(0.5,)) 后值域为 [-1, 1],
    # 而 VAE 解码器的 Sigmoid 输出为 [0, 1],BCE 要求 target ∈ [0,1],
    # 因此需要将目标图像从 [-1, 1] 反归一化回 [0, 1]
    x_target = (x.view(x.size(0), -1) + 1) / 2.0  # [-1, 1] → [0, 1]
    recon_loss = F.binary_cross_entropy(x_recon, x_target,
                                         reduction='sum') / x.size(0)

    # ---------- KL 散度(解析解)----------
    # KL = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)

    # ---------- 总损失 ----------
    total_loss = recon_loss + kl_loss

    return total_loss, recon_loss, kl_loss


def train_vae(dataloader: DataLoader, device: torch.device,
              n_epochs: int = 30, latent_dim: int = 20) -> dict:
    """
    训练 VAE

    参数:
        dataloader: MNIST 数据加载器
        device: 计算设备
        n_epochs: 训练轮数
        latent_dim: 潜变量维度

    返回:
        history: 包含每 epoch 的损失
    """
    print(f"\n  {'='*50}")
    print(f"  训练 VAE (epochs={n_epochs}, latent_dim={latent_dim})")
    print(f"  {'='*50}")

    model = VAE(latent_dim).to(device)
    print(f"  VAE 参数: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    history = {"total_loss": [], "recon_loss": [], "kl_loss": []}

    for epoch in range(1, n_epochs + 1):
        model.train()
        epoch_total = 0.0
        epoch_recon = 0.0
        epoch_kl = 0.0
        n_batches = 0

        for imgs, _ in dataloader:
            imgs = imgs.to(device)

            optimizer.zero_grad()
            x_recon, mu, logvar = model(imgs)
            total_loss, recon_loss, kl_loss = vae_loss(x_recon, imgs, mu, logvar)
            total_loss.backward()
            optimizer.step()

            epoch_total += total_loss.item()
            epoch_recon += recon_loss.item()
            epoch_kl += kl_loss.item()
            n_batches += 1

        history["total_loss"].append(epoch_total / n_batches)
        history["recon_loss"].append(epoch_recon / n_batches)
        history["kl_loss"].append(epoch_kl / n_batches)

        if epoch % 5 == 0 or epoch == 1:
            print(f"  Epoch {epoch:3d}/{n_epochs} | "
                  f"Total: {epoch_total/n_batches:.4f} | "
                  f"Recon: {epoch_recon/n_batches:.4f} | "
                  f"KL: {epoch_kl/n_batches:.4f}")

    return history, model


# ============================================================
# 第 3 部分:可视化工具
# ============================================================

def visualize_generated_samples(generator, device, latent_dim,
                                 save_path: str, n_samples: int = 16):
    """
    可视化 GAN 生成的图像样本

    参数:
        generator: 训练好的 GAN 生成器
        device: 计算设备
        latent_dim: 潜变量维度
        save_path: 保存路径
        n_samples: 生成的样本数
    """
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim, device=device)
        samples = generator(z)
        samples = to_image(samples)  # (N, H, W)

    # 排列为网格
    ncols = 4
    nrows = int(np.ceil(n_samples / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2))
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < n_samples:
            ax.imshow(samples[i], cmap='gray')
        ax.axis('off')

    plt.suptitle('GAN Generated Digits', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=120, bbox_inches='tight')
    plt.close()
    print(f"  [可视化] GAN 生成的样本已保存到 {save_path}")


def visualize_vae_reconstructions(model, test_loader, device,
                                   save_path: str, n_samples: int = 8):
    """
    可视化 VAE 重建结果(原始图像 vs 重建图像)

    参数:
        model: 训练好的 VAE 模型
        test_loader: 测试数据加载器
        device: 计算设备
        save_path: 保存路径
        n_samples: 显示的样本数
    """
    model.eval()
    imgs, _ = next(iter(test_loader))
    imgs = imgs[:n_samples].to(device)

    with torch.no_grad():
        x_recon, mu, logvar = model(imgs)
        x_recon = x_recon.view(n_samples, 1, 28, 28)

    originals = to_image(imgs)
    recons = to_image(x_recon)

    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 1.5, 3))
    for i in range(n_samples):
        axes[0, i].imshow(originals[i], cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_ylabel('Original', fontsize=10)

        axes[1, i].imshow(recons[i], cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_ylabel('Reconstructed', fontsize=10)

    plt.suptitle('VAE Reconstruction (Top: Original, Bottom: Reconstructed)', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=120, bbox_inches='tight')
    plt.close()
    print(f"  [可视化] VAE 重建结果已保存到 {save_path}")


def visualize_vae_latent_space(model, test_loader, device,
                                save_path: str, n_samples: int = 1000):
    """
    可视化 VAE 的潜空间(2D t-SNE 投影)

    参数:
        model: 训练好的 VAE
        test_loader: 测试数据加载器
        device: 计算设备
        save_path: 保存路径
        n_samples: 采样的潜变量数量
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        print("  [跳过] 潜空间可视化需要 scikit-learn: pip install scikit-learn")
        return

    model.eval()
    latent_vectors = []
    labels = []

    with torch.no_grad():
        for imgs, targets in test_loader:
            imgs = imgs.to(device)
            x_flat = imgs.view(imgs.size(0), -1)
            mu, logvar = model.encode(x_flat)
            latent_vectors.append(mu.cpu().numpy())
            labels.append(targets.numpy())

            if len(latent_vectors) * imgs.size(0) >= n_samples:
                break

    latent_vectors = np.concatenate(latent_vectors)[:n_samples]
    labels = np.concatenate(labels)[:n_samples]

    # t-SNE 降维到 2D
    print("    正在运行 t-SNE 降维(可能需要几秒)...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    latent_2d = tsne.fit_transform(latent_vectors)

    # 绘制
    fig, ax = plt.subplots(figsize=(8, 6))
    scatter = ax.scatter(latent_2d[:, 0], latent_2d[:, 1],
                          c=labels, cmap='tab10', alpha=0.6, s=10)
    plt.colorbar(scatter, ticks=range(10), label='Digit Class')
    ax.set_title('t-SNE Projection of VAE Latent Space', fontsize=14)
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  [可视化] VAE 潜空间 t-SNE 已保存到 {save_path}")


def plot_training_curves(gan_history: dict, vae_history: dict,
                          save_dir: str):
    """
    绘制训练曲线对比图

    参数:
        gan_history: GAN 训练历史(含 g_loss, d_loss)
        vae_history: VAE 训练历史(含 total_loss, recon_loss, kl_loss)
        save_dir: 保存目录
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # ---- GAN 损失曲线 ----
    ax = axes[0]
    epochs = range(1, len(gan_history["g_loss"]) + 1)
    ax.plot(epochs, gan_history["g_loss"], 'b-', label='G Loss', linewidth=1.5)
    ax.plot(epochs, gan_history["d_loss"], 'r-', label='D Loss', linewidth=1.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('GAN Training Curves')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # ---- VAE 损失曲线 ----
    ax = axes[1]
    epochs = range(1, len(vae_history["total_loss"]) + 1)
    ax.plot(epochs, vae_history["total_loss"], 'purple', label='Total', linewidth=1.5)
    ax.plot(epochs, vae_history["recon_loss"], 'orange', label='Recon', linewidth=1)
    ax.plot(epochs, vae_history["kl_loss"], 'green', label='KL', linewidth=1)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('VAE Training Curves')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # ---- 生成质量直观对比说明 ----
    ax = axes[2]
    ax.axis('off')
    comparison_text = (
        "GAN vs VAE Comparison\n\n"
        "GAN:\n"
        "  - Sharper images (adversarial training optimizes visual quality)\n"
        "  - No explicit latent space structure\n"
        "  - Unstable training (needs to balance G and D)\n"
        "  - May suffer from mode collapse\n\n"
        "VAE:\n"
        "  - Blurrier images (pixel-wise MSE/BCE averaging effect)\n"
        "  - Smooth structured latent space (enables interpolation)\n"
        "  - Stable training (explicit optimization objective)\n"
        "  - Better coverage of data distribution"
    )
    ax.text(0.05, 0.95, comparison_text, transform=ax.transAxes,
            fontsize=9, verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    save_path = os.path.join(save_dir, 'training_curves.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  [可视化] 训练曲线已保存到 {save_path}")


# ============================================================
# 第 4 部分:主函数
# ============================================================

def main():
    """主函数:训练 GAN 和 VAE,生成对比可视化"""
    print("=" * 60)
    print("s13 图像生成 Demo")
    print("GAN vs VAE: MNIST 数字生成对比")
    print("=" * 60)

    # ---------- 设备选择 ----------
    device = DEVICE
    print(f"\n计算设备: {device}")

    # ---------- 加载数据 ----------
    print("\n[1/5] 加载 MNIST 数据集...")
    batch_size = 128
    train_loader = load_mnist(batch_size)

    # 测试集(VAE 重建可视化用)
    try:
        test_set = torchvision.datasets.MNIST(
            root='../data', train=False, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        )
    except Exception as e:
        print(f"[警告] MNIST 测试集下载失败 ({e}),使用合成数据")
        from torch.utils.data import TensorDataset
        synth_X = torch.rand(1000, 1, 28, 28) * 2 - 1
        synth_y = torch.randint(0, 10, (1000,))
        test_set = TensorDataset(synth_X, synth_y)
    test_loader = DataLoader(test_set, batch_size=16, shuffle=True, num_workers=0)

    # ---------- 训练 GAN ----------
    print("\n[2/5] 训练 GAN...")
    if device.type == 'cpu':
        n_gan_epochs = 2
        n_train_subset = 1000
        print("[配置] CPU 模式:使用轻量参数快速演示(GAN 2 epochs, 1000 样本)。GPU 模式下将使用完整训练配置。")
        # 缩减训练集(仅当数据集有 .data 属性时,如标准 MNIST)
        if hasattr(train_loader.dataset, 'data'):
            train_loader.dataset.data = train_loader.dataset.data[:n_train_subset]
            if hasattr(train_loader.dataset, 'targets'):
                train_loader.dataset.targets = train_loader.dataset.targets[:n_train_subset]
    else:
        n_gan_epochs = 30
        print(f"[配置] 训练 {n_gan_epochs} 个 epoch")
    gan_history, generator, discriminator = train_gan(
        train_loader, device, n_epochs=n_gan_epochs, latent_dim=128
    )

    # ---------- 训练 VAE ----------
    print("\n[3/5] 训练 VAE...")
    # 创建新的 train_loader(因为之前的被消耗了)
    train_loader_vae = load_mnist(batch_size)
    # VAE 使用 [0,1] 范围的图像(Sigmoid 输出),需要重新处理
    # 简化处理:VAE 内部接受 [-1,1] 输入但输出 Sigmoid [0,1],损失用 BCE

    if device.type == 'cpu':
        n_vae_epochs = 2
        if hasattr(train_loader_vae.dataset, 'data'):
            train_loader_vae.dataset.data = train_loader_vae.dataset.data[:n_train_subset]
            if hasattr(train_loader_vae.dataset, 'targets'):
                train_loader_vae.dataset.targets = train_loader_vae.dataset.targets[:n_train_subset]
        print(f"[配置] VAE 训练 {n_vae_epochs} 个 epoch(CPU 模式时长较短)")
    else:
        n_vae_epochs = 30
        print(f"[配置] VAE 训练 {n_vae_epochs} 个 epoch")
    vae_history, vae_model = train_vae(
        train_loader_vae, device, n_epochs=n_vae_epochs, latent_dim=20
    )

    # ---------- 可视化 ----------
    print("\n[4/5] 生成可视化...")
    output_dir = _IMAGES_DIR

    # GAN 生成的样本
    visualize_generated_samples(generator, device, latent_dim=128,
                                 save_path=os.path.join(output_dir, "gan_samples.png"))

    # VAE 重建结果
    visualize_vae_reconstructions(vae_model, test_loader, device,
                                   save_path=os.path.join(output_dir, "vae_reconstructions.png"))

    # VAE 潜空间可视化
    visualize_vae_latent_space(vae_model, test_loader, device,
                                save_path=os.path.join(output_dir, "vae_latent_space.png"))

    # 训练曲线对比
    plot_training_curves(gan_history, vae_history, output_dir)

    # ---------- 总结 ----------
    print("\n[5/5] 总结")
    print("=" * 60)
    print("GAN:")
    print(f"  最终 G Loss: {gan_history['g_loss'][-1]:.4f}")
    print(f"  最终 D Loss: {gan_history['d_loss'][-1]:.4f}")
    print(f"  (理想状态: D Loss ≈ 0.693, G Loss 适中 — D 无法分辨真伪)")
    print(f"\nVAE:")
    print(f"  最终 Total Loss: {vae_history['total_loss'][-1]:.4f}")
    print(f"  重构损失: {vae_history['recon_loss'][-1]:.4f}")
    print(f"  KL 散度: {vae_history['kl_loss'][-1]:.4f}")
    print(f"  (KL 散度越小,潜空间越接近标准正态分布)")
    print("=" * 60)
    print(f"\nDemo 完成!查看 {_IMAGES_DIR} 目录下的可视化结果。")
    print("=" * 60)


if __name__ == "__main__":
    main()