s13 图像生成 — exercise.py 练习指南
练习目标
通过手写 GAN 和 VAE 的核心损失函数与算法组件,深入理解两种生成范式:
- 实现 GAN 判别器/生成器损失 —— 理解对抗博弈的优化目标
- 实现 VAE 重参数化技巧 —— 理解如何让采样可微
- 实现 KL 散度 —— 理解潜空间正则化的数学
- 分析模式坍塌 —— 理解 GAN 最常见的失败模式
- 对比 GAN vs VAE 的设计理念 —— 两种范式各有何优势和代价
预备知识
- GAN 目标:
- VAE 目标:最大化 ELBO =
- BCE 损失:
- KL 散度(高斯):
- 重参数化:
任务清单
练习 1:实现 GAN 判别器和生成器的损失函数
1a. 判别器损失
任务:实现 gan_discriminator_loss(d_real_pred, d_fake_pred)。
数学目标:
- 对真实图像:
应接近 1,损失为 - 对生成图像:
应接近 0,损失为 - 总损失:
代码框架:
def gan_discriminator_loss(d_real_pred, d_fake_pred):
real_labels = torch.ones_like(d_real_pred) # 全 1
fake_labels = torch.zeros_like(d_fake_pred) # 全 0
real_loss = F.binary_cross_entropy(d_real_pred, real_labels)
fake_loss = F.binary_cross_entropy(d_fake_pred, fake_labels)
return (real_loss + fake_loss) / 2 # 取平均BCE 的数学:
因此 BCE(D(x), 1) + BCE(D(G(z)), 0) 等价于
1b. 生成器损失
任务:实现 gan_generator_loss(d_fake_pred)。
数学目标:让判别器认为生成的图像是真的,即
为什么不是
- 当
(生成器还很差), ,梯度几乎为零——梯度消失 - 使用
时, 给出非常大的梯度,帮助生成器快速改进 - 这个修改被称为 non-saturating GAN loss,是现代 GAN 训练的标准做法
代码框架:
def gan_generator_loss(d_fake_pred):
target = torch.ones_like(d_fake_pred) # 全 1
loss = F.binary_cross_entropy(d_fake_pred, target)
return loss测试用例:
d_real = torch.tensor([[0.9], [0.8], [0.95]]) # D 正确识别真图
d_fake = torch.tensor([[0.1], [0.2], [0.05]]) # D 正确识别假图
d_loss = gan_discriminator_loss(d_real, d_fake) # D 做得好 → loss 小
g_loss = gan_generator_loss(d_fake) # G 做得差 → loss 大(~2.3)练习 2:实现 VAE 的重参数化技巧
任务:实现 reparameterize(mu, logvar)。
为什么需要重参数化? 如果写 z = torch.normal(mu, std),这个采样操作不可微——mu 和 std 是模型参数,但我们无法计算
此时
代码框架:
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar) # σ = e^(0.5 * log(σ²))
eps = torch.randn_like(std) # ε ~ N(0, 1)
z = mu + std * eps # z = μ + σ ⊙ ε
return z测试用例:给定
练习 3:计算 VAE 的 KL 散度
任务:实现 compute_kl_divergence(mu, logvar)。
公式(两个高斯分布的 KL 散度解析解):
代码框架:
def compute_kl_divergence(mu, logvar):
# 逐元素计算: 1 + log(σ²) - μ² - σ²
kl_element = 1 + logvar - mu.pow(2) - logvar.exp()
# 对 latent_dim 求和,对 batch 取平均,乘以 -0.5
kl = -0.5 * torch.sum(kl_element, dim=1).mean()
return kl测试用例:
| KL散度 | 解释 | |||
|---|---|---|---|---|
| 0 | 0 | 1 | 后验 = 先验 | |
| 2 | 1 | 后验远离先验 |
KL 散度的直觉:它是正则化项,防止编码器将所有输入映射到截然不同的
练习 4:解释 GAN 训练中的模式坍塌
任务:用文字回答三个问题,写入 explain_mode_collapse() 返回的字符串。
1. 什么是模式坍塌(给出具体例子)?
模式坍塌(Mode Collapse)指 GAN 的生成器学会了"作弊"——不管输入什么
,都输出几乎相同的少数几张图像。例如:训练一个生成 MNIST 数字的 GAN,最终无论输入什么噪声,都只生成"1"和"7"(或更极端——只生成某一种"1")。生成的分布只覆盖了真实分布(0-9)的少数模式。
2. 从优化角度,为什么 GAN 容易发生模式坍塌?
GAN 的优化是
。如果 发现生成某几种模式足够"骗过"当前 ,它就没有动力去探索其他模式。 学会识别这些模式后, 跳转到另一组模式——而不是学习覆盖所有模式。这就产生了模式间的"旋转门"效应。本质原因是:GAN 的损失只惩罚"生成质量差"(不够真),不直接惩罚"多样性不足"(覆盖不全)。
3. 至少两种缓解方法
a) Minibatch Discrimination:让判别器在同一 batch 内比较不同样本的相似度——如果所有样本都很像,给出惩罚信号。 b) Wasserstein GAN(WGAN):用 Wasserstein 距离替代 JS 散度,提供更平滑的梯度信号,大幅缓解模式坍塌。WGAN-GP 通过梯度惩罚实现了稳定的训练。 c) Unrolled GAN:展开优化步骤——生成器考虑判别器"未来会怎样学",选择更具前瞻性的策略。
练习 5:对比 GAN 和 VAE 的损失函数设计理念
任务:分析 compare_gan_vae_objectives() 中的两个问题。
1. 为什么 GAN 的损失导致锐利但可能有模式坍塌的图像?
GAN 的判别器学习区分真假——这是一个感知质量的判断。生成器不需要逐像素匹配训练数据,只需要"看起来真"。因此 GAN 能生成锐利、细节丰富的图像。但生成器可以"偷懒"——只生成最能骗过当前判别器的少数模式。GAN 的损失中没有显式的"多样性"惩罚,模式坍塌是结构性问题。
2. 为什么 VAE 的损失导致模糊但覆盖完整的图像?
VAE 的逐像素重构损失(BCE/MSE)是一个"保守"的损失——当输入存在多种可能的重构时,最优预测是它们的期望值(即平均值)。例如一个像素在训练集中有时是白色有时是黑色,VAE 最优输出是灰色——这就是 VAE 图像模糊的根源。但 VAE 的 KL 散度正则化强制潜空间保持连续和平滑,确保了采样时能覆盖整个数据分布(多样性好)。
| 特性 | GAN | VAE |
|---|---|---|
| 图像质量 | 锐利(对抗性目标优化视觉质量) | 模糊(逐像素损失的均值化效应) |
| 多样性 | 低(模式坍塌风险) | 高(KL 约束覆盖分布) |
| 训练稳定性 | 差(博弈可能不收敛) | 好(单一优化目标) |
| 潜空间 | 无可解释结构 | 平滑、可插值、结构良好 |
完整代码
# -*- coding: utf-8 -*-
"""
s13 图像生成 练习
==================
完成以下 TODO 练习来加深对 GAN 和 VAE 核心算法的理解。
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
# ============================================================
# 练习 1:实现 GAN 判别器的损失函数
# ============================================================
def gan_discriminator_loss(d_real_pred: torch.Tensor,
d_fake_pred: torch.Tensor) -> torch.Tensor:
"""
TODO: 实现 GAN 判别器的损失函数
GAN 的判别器 D 有两个目标:
1. 对于真实图像 x,D(x) 应该接近 1(判定为真)
2. 对于生成图像 G(z),D(G(z)) 应该接近 0(判定为假)
因此判别器的损失为:
L_D = -E[log D(x)] - E[log(1 - D(G(z)))]
或者等价的 BCE 形式:
L_D = BCE(D(x), 1) + BCE(D(G(z)), 0)
参数:
d_real_pred: 判别器对真实图像的输出,形状 (N, 1),期望接近 1
d_fake_pred: 判别器对假图像的输出,形状 (N, 1),期望接近 0
返回:
loss: 判别器的总损失(标量)
提示:
1. 使用 F.binary_cross_entropy 或手动计算
2. real_target 应为全 1 张量,fake_target 应为全 0 张量
3. 两种实现可选:
a) loss = 0.5 * (BCE(d_real, ones) + BCE(d_fake, zeros))
b) loss = -torch.mean(torch.log(d_real) + torch.log(1 - d_fake))
"""
# TODO: 创建目标标签
# real_labels = torch.ones_like(???) # 全 1,形状与 d_real_pred 相同
# fake_labels = torch.zeros_like(???) # 全 0,形状与 d_fake_pred 相同
# TODO: 计算二元交叉熵损失
# real_loss = F.binary_cross_entropy(???, ???) # 真实图像损失
# fake_loss = F.binary_cross_entropy(???, ???) # 假图像损失
# TODO: 返回总损失(两者的平均或求和)
# return (real_loss + fake_loss) / 2
return torch.tensor(0.0) # 替换为你的实现
def gan_generator_loss(d_fake_pred: torch.Tensor) -> torch.Tensor:
"""
TODO: 实现 GAN 生成器的损失函数
生成器 G 的目标: 让判别器认为生成的图像是真的
D(G(z)) 应该接近 1
标准 GAN 损失:
L_G = -E[log D(G(z))]
实际实现中常用:
L_G = BCE(D(G(z)), 1) # 让 D(G(z)) 接近 1
参数:
d_fake_pred: 判别器对生成图像的输出,形状 (N, 1)
返回:
loss: 生成器的损失(标量)
提示:
1. 目标标签是全 1 张量
2. 使用 F.binary_cross_entropy
"""
# TODO: 创建全 1 的目标标签
# target = torch.ones_like(???)
# TODO: 计算 BCE 损失
# loss = F.binary_cross_entropy(???, ???)
return torch.tensor(0.0) # 替换为你的实现
# ============================================================
# 练习 2:实现 VAE 的重参数化技巧
# ============================================================
def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
TODO: 实现 VAE 的重参数化技巧
重参数化将采样操作变为可微:
z = μ + σ ⊙ ε
其中:
- σ = exp(0.5 * logvar) (因为 logvar = log(σ²))
- ε ~ N(0, 1)
为什么需要重参数化?
如果直接从 N(μ, σ²) 采样 z,这个采样操作是不可微的,
梯度无法从 decoder 传回 encoder 的 μ 和 σ。
重参数化将随机性"外包"给 ε,使 μ 和 σ 成为可微的部分。
参数:
mu: 编码器预测的均值,形状 (N, latent_dim)
logvar: 编码器预测的 log(σ²),形状 (N, latent_dim)
返回:
z: 采样后的潜变量,形状 (N, latent_dim)
提示:
1. std = torch.exp(0.5 * logvar) # σ = e^(0.5*log(σ²))
2. eps = torch.randn_like(std) # ε ~ N(0, 1)
3. z = mu + std * eps # 重参数化
"""
# TODO: 计算标准差 σ
# std = ???
# TODO: 从标准正态分布采样 ε
# eps = ???
# TODO: 重参数化: z = μ + σ ⊙ ε
# z = ???
return None # 替换为你的实现
# ============================================================
# 练习 3:计算 VAE 的 KL 散度
# ============================================================
def compute_kl_divergence(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
TODO: 计算 VAE 的 KL 散度 D_KL(q(z|x) || p(z))
假设:
- q(z|x) = N(z; μ, σ²) (编码器的后验分布)
- p(z) = N(z; 0, 1) (先验分布,标准正态)
两个高斯分布之间的 KL 散度有解析解:
D_KL(N(μ,σ²) || N(0,1)) = -0.5 * Σ_j (1 + log(σ_j²) - μ_j² - σ_j²)
参数:
mu: 编码器输出的均值,形状 (N, latent_dim)
logvar: 编码器输出的 log(σ²),形状 (N, latent_dim)
返回:
kl: KL 散度值(标量,对 batch 取平均)
提示:
1. 逐元素计算: 1 + logvar - mu^2 - exp(logvar)
2. 对 latent_dim 维度求和
3. 对 batch 取平均
4. 乘以 -0.5
KL 散度的直觉:
- 如果 μ=0 且 σ²=1(后验 = 先验),KL = 0
- 如果 μ 远离 0 或 σ² 远不等于 1,KL 变大
- KL 项起正则化作用:让潜空间保持平滑,接近标准正态
"""
# TODO: 计算每个样本的 KL 散度
# 逐元素: 1 + logvar - mu^2 - exp(logvar)
# kl_element = 1 + logvar - mu.pow(2) - logvar.exp()
# 对 latent_dim 求和,对 batch 取平均
# kl = -0.5 * torch.sum(kl_element, dim=1).mean()
return torch.tensor(0.0) # 替换为你的实现
# ============================================================
# 练习 4:解释 GAN 训练中的模式坍塌
# ============================================================
def explain_mode_collapse():
"""
TODO: 用文字结合代码逻辑解释 GAN 的模式坍塌(Mode Collapse)问题
请在下方写出你对模式坍塌的理解(中文回答以下问题):
1. 什么是模式坍塌?给出一个具体的例子
2. 从优化的角度,为什么 GAN 容易发生模式坍塌?
3. 至少列举两种缓解模式坍塌的方法
将你的回答写在下方字符串中并返回。
"""
explanation = """
请在此处写下你对模式坍塌的理解,回答上述三个问题。
(提示示例)
1. 模式坍塌是指...
2. 从优化角度看...
3. 缓解方法:
a) ...
b) ...
"""
return explanation.strip()
# ============================================================
# 练习 5:对比 GAN 和 VAE 的损失函数设计理念
# ============================================================
def compare_gan_vae_objectives():
"""
TODO: 从数学和直觉两个角度对比 GAN 和 VAE 的损失函数设计
GAN 的目标: min_G max_D V(D,G) = E[log D(x)] + E[log(1-D(G(z)))]
VAE 的目标: max E[log p(x|z)] - D_KL(q(z|x) || p(z))
请分析:
1. 为什么 GAN 的损失会导致锐利但可能有模式坍塌的图像?
2. 为什么 VAE 的损失会导致模糊但覆盖完整的图像?
"""
comparison = """
请在此处写出你的分析。
"""
return comparison.strip()
# ============================================================
# 测试代码
# ============================================================
if __name__ == "__main__":
print("=" * 50)
print("s13 图像生成 — 练习测试")
print("=" * 50)
# ---- 测试练习 1:GAN 损失函数 ----
print("\n[练习 1] GAN 损失函数测试:")
# 模拟输出
d_real = torch.tensor([[0.9], [0.8], [0.95]]) # 较好的真实判别
d_fake = torch.tensor([[0.1], [0.2], [0.05]]) # 较好的假判别
d_loss = gan_discriminator_loss(d_real, d_fake)
g_loss = gan_generator_loss(d_fake)
print(f" 判别器损失: {d_loss.item():.4f}")
print(f" 生成器损失: {g_loss.item():.4f}")
print(f" (判别器损失小 = D 能区分真假; 生成器损失大 = G 还需改进)")
print(f" (期望 d_loss ≈ 0.5~1.0, g_loss ≈ 2~3)")
# ---- 测试练习 2:重参数化 ----
print("\n[练习 2] 重参数化技巧测试:")
mu = torch.tensor([[0.5, -0.3], [0.0, 0.8]])
logvar = torch.tensor([[0.1, 0.2], [-0.5, 0.0]])
z = reparameterize(mu, logvar)
if z is not None:
print(f" μ:\n{mu}")
print(f" logvar:\n{logvar}")
print(f" σ:\n{torch.exp(0.5 * logvar)}") # 期望标准差
print(f" 采样 z:\n{z}")
print(f" z 的形状: {z.shape}")
print(f" (z 应该在 μ 附近随机波动)")
else:
print(" 请完成 reparameterize 实现")
# ---- 测试练习 3:KL 散度 ----
print("\n[练习 3] KL 散度测试:")
# 测试 1: 后验 = 先验 → KL ≈ 0
mu1 = torch.zeros(10, 5) # μ = 0
logvar1 = torch.zeros(10, 5) # log(σ²) = 0 → σ² = 1
kl1 = compute_kl_divergence(mu1, logvar1)
print(f" μ=0, σ²=1 (后验=先验): KL = {kl1.item():.6f} (期望 ≈ 0)")
# 测试 2: 后验偏离先验 → KL > 0
mu2 = torch.ones(10, 5) * 2.0 # μ 远离 0
logvar2 = torch.ones(10, 5) * 1.0 # σ² = e¹ ≈ 2.72
kl2 = compute_kl_divergence(mu2, logvar2)
print(f" μ=2, σ²≈2.72 (后验≠先验): KL = {kl2.item():.4f} (期望 > 0)")
# ---- 练习 4 和 5:概念题 ----
print("\n[练习 4] 模式坍塌解释:")
print(explain_mode_collapse())
print("\n[练习 5] GAN vs VAE 损失对比:")
print(compare_gan_vae_objectives())
print("\n" + "=" * 50)
print("完成所有练习后,运行 demo.py 查看完整的生成对比实验。")
print("=" * 50)