Skip to content

s13 图像生成 — exercise.py 练习指南

Download exercise.py

练习目标

通过手写 GAN 和 VAE 的核心损失函数与算法组件,深入理解两种生成范式:

  1. 实现 GAN 判别器/生成器损失 —— 理解对抗博弈的优化目标
  2. 实现 VAE 重参数化技巧 —— 理解如何让采样可微
  3. 实现 KL 散度 —— 理解潜空间正则化的数学
  4. 分析模式坍塌 —— 理解 GAN 最常见的失败模式
  5. 对比 GAN vs VAE 的设计理念 —— 两种范式各有何优势和代价

预备知识

  • GAN 目标minGmaxDE[logD(x)]+E[log(1D(G(z)))]
  • VAE 目标:最大化 ELBO = E[logp(x|z)]DKL(q(z|x)p(z))
  • BCE 损失BCE(y^,y)=[ylogy^+(1y)log(1y^)]
  • KL 散度(高斯)12[σj2+μj21logσj2]
  • 重参数化z=μ+σε,εN(0,1)

任务清单

练习 1:实现 GAN 判别器和生成器的损失函数

1a. 判别器损失

任务:实现 gan_discriminator_loss(d_real_pred, d_fake_pred)

数学目标

  • 对真实图像:D(x) 应接近 1,损失为 logD(x)
  • 对生成图像:D(G(z)) 应接近 0,损失为 log(1D(G(z)))
  • 总损失:LD=12[ElogD(x)+Elog(1D(G(z)))]

代码框架

python
def gan_discriminator_loss(d_real_pred, d_fake_pred):
    real_labels = torch.ones_like(d_real_pred)   # 全 1
    fake_labels = torch.zeros_like(d_fake_pred)  # 全 0

    real_loss = F.binary_cross_entropy(d_real_pred, real_labels)
    fake_loss = F.binary_cross_entropy(d_fake_pred, fake_labels)

    return (real_loss + fake_loss) / 2  # 取平均

BCE 的数学

BCE(y^,1)=log(y^),BCE(y^,0)=log(1y^)

因此 BCE(D(x), 1) + BCE(D(G(z)), 0) 等价于 logD(x)log(1D(G(z)))

1b. 生成器损失

任务:实现 gan_generator_loss(d_fake_pred)

数学目标:让判别器认为生成的图像是真的,即 D(G(z))1

为什么不是 log(1D(G(z))) 原始 GAN 论文确实使用这个公式,但实践中改用 logD(G(z))

  • D(G(z))0(生成器还很差),log(1D(G(z)))0,梯度几乎为零——梯度消失
  • 使用 logD(G(z)) 时,D(G(z))0 给出非常大的梯度,帮助生成器快速改进
  • 这个修改被称为 non-saturating GAN loss,是现代 GAN 训练的标准做法

代码框架

python
def gan_generator_loss(d_fake_pred):
    target = torch.ones_like(d_fake_pred)  # 全 1
    loss = F.binary_cross_entropy(d_fake_pred, target)
    return loss

测试用例

python
d_real = torch.tensor([[0.9], [0.8], [0.95]])  # D 正确识别真图
d_fake = torch.tensor([[0.1], [0.2], [0.05]])  # D 正确识别假图

d_loss = gan_discriminator_loss(d_real, d_fake)  # D 做得好 → loss 小
g_loss = gan_generator_loss(d_fake)               # G 做得差 → loss 大(~2.3)

练习 2:实现 VAE 的重参数化技巧

任务:实现 reparameterize(mu, logvar)

为什么需要重参数化? 如果写 z = torch.normal(mu, std),这个采样操作不可微——mustd 是模型参数,但我们无法计算 zμ 用于反向传播。重参数化将随机性"外包"给 ε

z=μ+σε,εN(0,I)

此时 zμ=1zσ=ε,梯度可以正常传播。

代码框架

python
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)    # σ = e^(0.5 * log(σ²))
    eps = torch.randn_like(std)      # ε ~ N(0, 1)
    z = mu + std * eps               # z = μ + σ ⊙ ε
    return z

测试用例:给定 μlog(σ2),采样得到的 z 应该在 μ 附近随机波动,波动幅度由 σ 控制。

练习 3:计算 VAE 的 KL 散度

任务:实现 compute_kl_divergence(mu, logvar)

公式(两个高斯分布的 KL 散度解析解):

DKL(N(μ,σ2I)N(0,I))=12j=1d(1+logσj2μj2σj2)

代码框架

python
def compute_kl_divergence(mu, logvar):
    # 逐元素计算: 1 + log(σ²) - μ² - σ²
    kl_element = 1 + logvar - mu.pow(2) - logvar.exp()
    # 对 latent_dim 求和,对 batch 取平均,乘以 -0.5
    kl = -0.5 * torch.sum(kl_element, dim=1).mean()
    return kl

测试用例

μlog(σ2)σ2KL散度解释
0010后验 = 先验 N(0,1)
21e12.72>0后验远离先验

KL 散度的直觉:它是正则化项,防止编码器将所有输入映射到截然不同的 z 区域(这会导致潜空间支离破碎,插值毫无意义)。KL 项越大,潜空间越接近标准正态分布,越平滑和结构良好。

练习 4:解释 GAN 训练中的模式坍塌

任务:用文字回答三个问题,写入 explain_mode_collapse() 返回的字符串。

1. 什么是模式坍塌(给出具体例子)?

模式坍塌(Mode Collapse)指 GAN 的生成器学会了"作弊"——不管输入什么 z,都输出几乎相同的少数几张图像。例如:训练一个生成 MNIST 数字的 GAN,最终无论输入什么噪声,都只生成"1"和"7"(或更极端——只生成某一种"1")。生成的分布只覆盖了真实分布(0-9)的少数模式。

2. 从优化角度,为什么 GAN 容易发生模式坍塌?

GAN 的优化是 minGmaxDV(D,G)。如果 G 发现生成某几种模式足够"骗过"当前 D,它就没有动力去探索其他模式。D 学会识别这些模式后,G 跳转到另一组模式——而不是学习覆盖所有模式。这就产生了模式间的"旋转门"效应。本质原因是:GAN 的损失只惩罚"生成质量差"(不够真),不直接惩罚"多样性不足"(覆盖不全)。

3. 至少两种缓解方法

a) Minibatch Discrimination:让判别器在同一 batch 内比较不同样本的相似度——如果所有样本都很像,给出惩罚信号。 b) Wasserstein GAN(WGAN):用 Wasserstein 距离替代 JS 散度,提供更平滑的梯度信号,大幅缓解模式坍塌。WGAN-GP 通过梯度惩罚实现了稳定的训练。 c) Unrolled GAN:展开优化步骤——生成器考虑判别器"未来会怎样学",选择更具前瞻性的策略。

练习 5:对比 GAN 和 VAE 的损失函数设计理念

任务:分析 compare_gan_vae_objectives() 中的两个问题。

1. 为什么 GAN 的损失导致锐利但可能有模式坍塌的图像?

GAN 的判别器学习区分真假——这是一个感知质量的判断。生成器不需要逐像素匹配训练数据,只需要"看起来真"。因此 GAN 能生成锐利、细节丰富的图像。但生成器可以"偷懒"——只生成最能骗过当前判别器的少数模式。GAN 的损失中没有显式的"多样性"惩罚,模式坍塌是结构性问题。

2. 为什么 VAE 的损失导致模糊但覆盖完整的图像?

VAE 的逐像素重构损失(BCE/MSE)是一个"保守"的损失——当输入存在多种可能的重构时,最优预测是它们的期望值(即平均值)。例如一个像素在训练集中有时是白色有时是黑色,VAE 最优输出是灰色——这就是 VAE 图像模糊的根源。但 VAE 的 KL 散度正则化强制潜空间保持连续和平滑,确保了采样时能覆盖整个数据分布(多样性好)。

特性GANVAE
图像质量锐利(对抗性目标优化视觉质量)模糊(逐像素损失的均值化效应)
多样性低(模式坍塌风险)高(KL 约束覆盖分布)
训练稳定性差(博弈可能不收敛)好(单一优化目标)
潜空间无可解释结构平滑、可插值、结构良好

完整代码

py
# -*- coding: utf-8 -*-
"""
s13 图像生成 练习
==================
完成以下 TODO 练习来加深对 GAN 和 VAE 核心算法的理解。
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


# ============================================================
# 练习 1:实现 GAN 判别器的损失函数
# ============================================================

def gan_discriminator_loss(d_real_pred: torch.Tensor,
                            d_fake_pred: torch.Tensor) -> torch.Tensor:
    """
    TODO: 实现 GAN 判别器的损失函数

    GAN 的判别器 D 有两个目标:
    1. 对于真实图像 x,D(x) 应该接近 1(判定为真)
    2. 对于生成图像 G(z),D(G(z)) 应该接近 0(判定为假)

    因此判别器的损失为:
        L_D = -E[log D(x)] - E[log(1 - D(G(z)))]

    或者等价的 BCE 形式:
        L_D = BCE(D(x), 1) + BCE(D(G(z)), 0)

    参数:
        d_real_pred: 判别器对真实图像的输出,形状 (N, 1),期望接近 1
        d_fake_pred: 判别器对假图像的输出,形状 (N, 1),期望接近 0

    返回:
        loss: 判别器的总损失(标量)

    提示:
    1. 使用 F.binary_cross_entropy 或手动计算
    2. real_target 应为全 1 张量,fake_target 应为全 0 张量
    3. 两种实现可选:
       a) loss = 0.5 * (BCE(d_real, ones) + BCE(d_fake, zeros))
       b) loss = -torch.mean(torch.log(d_real) + torch.log(1 - d_fake))
    """
    # TODO: 创建目标标签
    # real_labels = torch.ones_like(???)  # 全 1,形状与 d_real_pred 相同
    # fake_labels = torch.zeros_like(???) # 全 0,形状与 d_fake_pred 相同

    # TODO: 计算二元交叉熵损失
    # real_loss = F.binary_cross_entropy(???, ???)  # 真实图像损失
    # fake_loss = F.binary_cross_entropy(???, ???)  # 假图像损失

    # TODO: 返回总损失(两者的平均或求和)
    # return (real_loss + fake_loss) / 2

    return torch.tensor(0.0)  # 替换为你的实现


def gan_generator_loss(d_fake_pred: torch.Tensor) -> torch.Tensor:
    """
    TODO: 实现 GAN 生成器的损失函数

    生成器 G 的目标: 让判别器认为生成的图像是真的
    D(G(z)) 应该接近 1

    标准 GAN 损失:
        L_G = -E[log D(G(z))]

    实际实现中常用:
        L_G = BCE(D(G(z)), 1)  # 让 D(G(z)) 接近 1

    参数:
        d_fake_pred: 判别器对生成图像的输出,形状 (N, 1)

    返回:
        loss: 生成器的损失(标量)

    提示:
    1. 目标标签是全 1 张量
    2. 使用 F.binary_cross_entropy
    """
    # TODO: 创建全 1 的目标标签
    # target = torch.ones_like(???)

    # TODO: 计算 BCE 损失
    # loss = F.binary_cross_entropy(???, ???)

    return torch.tensor(0.0)  # 替换为你的实现


# ============================================================
# 练习 2:实现 VAE 的重参数化技巧
# ============================================================

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """
    TODO: 实现 VAE 的重参数化技巧

    重参数化将采样操作变为可微:
        z = μ + σ ⊙ ε
    其中:
        - σ = exp(0.5 * logvar) (因为 logvar = log(σ²))
        - ε ~ N(0, 1)

    为什么需要重参数化?
    如果直接从 N(μ, σ²) 采样 z,这个采样操作是不可微的,
    梯度无法从 decoder 传回 encoder 的 μ 和 σ。
    重参数化将随机性"外包"给 ε,使 μ 和 σ 成为可微的部分。

    参数:
        mu: 编码器预测的均值,形状 (N, latent_dim)
        logvar: 编码器预测的 log(σ²),形状 (N, latent_dim)

    返回:
        z: 采样后的潜变量,形状 (N, latent_dim)

    提示:
    1. std = torch.exp(0.5 * logvar)   # σ = e^(0.5*log(σ²))
    2. eps = torch.randn_like(std)     # ε ~ N(0, 1)
    3. z = mu + std * eps              # 重参数化
    """
    # TODO: 计算标准差 σ
    # std = ???

    # TODO: 从标准正态分布采样 ε
    # eps = ???

    # TODO: 重参数化: z = μ + σ ⊙ ε
    # z = ???

    return None  # 替换为你的实现


# ============================================================
# 练习 3:计算 VAE 的 KL 散度
# ============================================================

def compute_kl_divergence(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """
    TODO: 计算 VAE 的 KL 散度 D_KL(q(z|x) || p(z))

    假设:
    - q(z|x) = N(z; μ, σ²)  (编码器的后验分布)
    - p(z) = N(z; 0, 1)     (先验分布,标准正态)

    两个高斯分布之间的 KL 散度有解析解:
        D_KL(N(μ,σ²) || N(0,1)) = -0.5 * Σ_j (1 + log(σ_j²) - μ_j² - σ_j²)

    参数:
        mu: 编码器输出的均值,形状 (N, latent_dim)
        logvar: 编码器输出的 log(σ²),形状 (N, latent_dim)

    返回:
        kl: KL 散度值(标量,对 batch 取平均)

    提示:
    1. 逐元素计算: 1 + logvar - mu^2 - exp(logvar)
    2. 对 latent_dim 维度求和
    3. 对 batch 取平均
    4. 乘以 -0.5

    KL 散度的直觉:
    - 如果 μ=0 且 σ²=1(后验 = 先验),KL = 0
    - 如果 μ 远离 0 或 σ² 远不等于 1,KL 变大
    - KL 项起正则化作用:让潜空间保持平滑,接近标准正态
    """
    # TODO: 计算每个样本的 KL 散度
    # 逐元素: 1 + logvar - mu^2 - exp(logvar)
    # kl_element = 1 + logvar - mu.pow(2) - logvar.exp()
    # 对 latent_dim 求和,对 batch 取平均
    # kl = -0.5 * torch.sum(kl_element, dim=1).mean()

    return torch.tensor(0.0)  # 替换为你的实现


# ============================================================
# 练习 4:解释 GAN 训练中的模式坍塌
# ============================================================

def explain_mode_collapse():
    """
    TODO: 用文字结合代码逻辑解释 GAN 的模式坍塌(Mode Collapse)问题

    请在下方写出你对模式坍塌的理解(中文回答以下问题):

    1. 什么是模式坍塌?给出一个具体的例子
    2. 从优化的角度,为什么 GAN 容易发生模式坍塌?
    3. 至少列举两种缓解模式坍塌的方法

    将你的回答写在下方字符串中并返回。
    """
    explanation = """
    请在此处写下你对模式坍塌的理解,回答上述三个问题。

    (提示示例)
    1. 模式坍塌是指...
    2. 从优化角度看...
    3. 缓解方法:
       a) ...
       b) ...
    """
    return explanation.strip()


# ============================================================
# 练习 5:对比 GAN 和 VAE 的损失函数设计理念
# ============================================================

def compare_gan_vae_objectives():
    """
    TODO: 从数学和直觉两个角度对比 GAN 和 VAE 的损失函数设计

    GAN 的目标: min_G max_D V(D,G) = E[log D(x)] + E[log(1-D(G(z)))]
    VAE 的目标: max E[log p(x|z)] - D_KL(q(z|x) || p(z))

    请分析:
    1. 为什么 GAN 的损失会导致锐利但可能有模式坍塌的图像?
    2. 为什么 VAE 的损失会导致模糊但覆盖完整的图像?
    """
    comparison = """
    请在此处写出你的分析。
    """
    return comparison.strip()


# ============================================================
# 测试代码
# ============================================================

if __name__ == "__main__":
    print("=" * 50)
    print("s13 图像生成 — 练习测试")
    print("=" * 50)

    # ---- 测试练习 1:GAN 损失函数 ----
    print("\n[练习 1] GAN 损失函数测试:")

    # 模拟输出
    d_real = torch.tensor([[0.9], [0.8], [0.95]])  # 较好的真实判别
    d_fake = torch.tensor([[0.1], [0.2], [0.05]])  # 较好的假判别

    d_loss = gan_discriminator_loss(d_real, d_fake)
    g_loss = gan_generator_loss(d_fake)

    print(f"  判别器损失: {d_loss.item():.4f}")
    print(f"  生成器损失: {g_loss.item():.4f}")
    print(f"  (判别器损失小 = D 能区分真假; 生成器损失大 = G 还需改进)")
    print(f"  (期望 d_loss ≈ 0.5~1.0, g_loss ≈ 2~3)")

    # ---- 测试练习 2:重参数化 ----
    print("\n[练习 2] 重参数化技巧测试:")

    mu = torch.tensor([[0.5, -0.3], [0.0, 0.8]])
    logvar = torch.tensor([[0.1, 0.2], [-0.5, 0.0]])

    z = reparameterize(mu, logvar)

    if z is not None:
        print(f"  μ:\n{mu}")
        print(f"  logvar:\n{logvar}")
        print(f"  σ:\n{torch.exp(0.5 * logvar)}")  # 期望标准差
        print(f"  采样 z:\n{z}")
        print(f"  z 的形状: {z.shape}")
        print(f"  (z 应该在 μ 附近随机波动)")
    else:
        print("  请完成 reparameterize 实现")

    # ---- 测试练习 3:KL 散度 ----
    print("\n[练习 3] KL 散度测试:")

    # 测试 1: 后验 = 先验 → KL ≈ 0
    mu1 = torch.zeros(10, 5)    # μ = 0
    logvar1 = torch.zeros(10, 5)  # log(σ²) = 0 → σ² = 1
    kl1 = compute_kl_divergence(mu1, logvar1)
    print(f"  μ=0, σ²=1 (后验=先验): KL = {kl1.item():.6f} (期望 ≈ 0)")

    # 测试 2: 后验偏离先验 → KL > 0
    mu2 = torch.ones(10, 5) * 2.0  # μ 远离 0
    logvar2 = torch.ones(10, 5) * 1.0  # σ² = e¹ ≈ 2.72
    kl2 = compute_kl_divergence(mu2, logvar2)
    print(f"  μ=2, σ²≈2.72 (后验≠先验): KL = {kl2.item():.4f} (期望 > 0)")

    # ---- 练习 4 和 5:概念题 ----
    print("\n[练习 4] 模式坍塌解释:")
    print(explain_mode_collapse())

    print("\n[练习 5] GAN vs VAE 损失对比:")
    print(compare_gan_vae_objectives())

    print("\n" + "=" * 50)
    print("完成所有练习后,运行 demo.py 查看完整的生成对比实验。")
    print("=" * 50)