s13 图像生成 — demo.py 代码详解
运行方式
cd s13_image_generation/code
python demo.py代码逐段详解
第1步:导入库与全局配置 —— 每个库是做什么的
import torch
import torch.nn as nn # GAN 的 Linear+BN、VAE 的 Encoder/Decoder
import torch.nn.functional as F # BCE、MSE、relu、sigmoid 等损失和激活函数
import torch.optim as optim # Adam 优化器(GAN 和 VAE 都用 Adam)
import torchvision # MNIST 数据集
import torchvision.transforms as transforms
from torch.utils.data import DataLoader| 库 | 在此 demo 中的角色 |
|---|---|
torch.nn | 构建 Generator、Discriminator、VAE Encoder/Decoder |
torch.optim | Adam 优化器(GAN 用 lr=0.0002, betas=(0.5, 0.999)) |
torchvision | MNIST 数据集(28x28 手写数字) |
为什么 MNIST? 手写数字生成是图像生成模型的"Hello World"。28x28 灰度图维度低、结构简单,可以让 GAN 和 VAE 在 CPU 上也能快速训练出可见效果。同时 MNIST 的 10 个类别(0-9)提供了清晰的聚类结构,便于分析 VAE 潜空间的 t-SNE 分布。
数据归一化策略:
transforms.Normalize((0.5,), (0.5,)) # [0,1] → [-1,1]| 模型 | 输出激活 | 输出范围 | 目标范围 | 匹配? |
|---|---|---|---|---|
| GAN Generator | Tanh | 完美匹配 | ||
| VAE Decoder | Sigmoid | 需要转换! |
VAE 的特别注意:VAE 解码器用 Sigmoid 输出
,但 MNIST 数据被归一化到 。在计算 VAE 损失时,代码将目标反归一化回 再用 BCE 损失——详见第4步。
第2步:GAN —— 生成对抗网络
GAN 由两个网络组成:生成器
GAN 的数学形式:
- 判别器
想最大化 :真图为 1,假图为 0 - 生成器
想最小化 :让 接近 1(以假乱真)
2.1 生成器 Generator —— 从噪声到图像
class Generator(nn.Module):
def __init__(self, latent_dim=128):
self.model = nn.Sequential(
# Block 1: 128 → 256
nn.Linear(latent_dim, 256),
nn.BatchNorm1d(256), # BN 稳定训练
nn.ReLU(inplace=True),
# Block 2: 256 → 512
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
# Block 3: 512 → 1024
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(inplace=True),
# Block 4: 1024 → 784 (28×28)
nn.Linear(1024, 784),
nn.Tanh(), # 输出 [-1, 1]
)为什么用全连接而不是转置卷积? 对于 MNIST 这种 28x28 的小图像,全连接层足以生成合理的结果(参数约 1M)。转置卷积(DCGAN)在大图像上效果更好,但全连接版本更简洁,适合教学。反卷积版本通常能生成更平滑的纹理。
为什么 Generator 用 BatchNorm 而 Discriminator 不用? BatchNorm 在生成器中非常关键——它防止生成器的输出分布漂移,确保各层的激活保持在合理范围。判别器不需要 BN(甚至 BN 可能有害),因为判别器的任务是对单张图像做判断,BN 引入的 mini-batch 统计依赖会干扰逐样本判断。
为什么最后一层用 Tanh? Tanh 输出范围
2.2 判别器 Discriminator —— 真假判断
class Discriminator(nn.Module):
def __init__(self):
self.model = nn.Sequential(
# Block 1: 784 → 512
nn.Linear(784, 512),
nn.LeakyReLU(0.2), # LeakyReLU 防止死神经元
# Block 2: 512 → 256
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
# Block 3: 256 → 1
nn.Linear(256, 1),
nn.Sigmoid(), # 输出概率 [0, 1]
)为什么判别器用 LeakyReLU 而不是 ReLU? LeakyReLU 在负半轴保留了微小斜率(
2.3 GAN 训练循环 —— 交替优化
for epoch in range(n_epochs):
for imgs, _ in dataloader:
batch_size = imgs.size(0)
real_imgs = imgs.to(device)
real_labels = torch.ones(batch_size, 1, device=device) # 真实=1
fake_labels = torch.zeros(batch_size, 1, device=device) # 虚假=0
# ===== 1. 训练判别器 D =====
optimizer_D.zero_grad()
# D(真实图像) → 应该接近 1
real_pred = discriminator(real_imgs)
d_real_loss = adversarial_loss(real_pred, real_labels)
# D(生成图像) → 应该接近 0
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z)
fake_pred = discriminator(fake_imgs.detach()) # ⚠️ detach() 防止梯度传回 G
d_fake_loss = adversarial_loss(fake_pred, fake_labels)
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# ===== 2. 训练生成器 G =====
optimizer_G.zero_grad()
z = torch.randn(batch_size, latent_dim, device=device)
gen_imgs = generator(z)
gen_pred = discriminator(gen_imgs) # 注意:这次不用 detach()
g_loss = adversarial_loss(gen_pred, real_labels) # G 的目标:让 D 判为真
g_loss.backward()
optimizer_G.step()关键代码细节:
fake_imgs.detach():训练判别器时,生成的假图必须从计算图中切断。否则梯度会通过fake_imgs流回生成器,导致我们在训练 D 的同时误改了 G 的参数。生成器的损失目标:
。注意这里用的目标标签是 1(真实)——生成器想"骗过"判别器,让判别器认为生成图像是真的。 交替训练:每轮先训 D 再训 G。理论上应该是 D 训 k 步 G 训 1 步(原始 GAN 论文的推荐),但实践中 1:1 交替训练通常效果不错。
GAN 训练的直观理解:
| 角色 | 目标 | 策略 |
|---|---|---|
| 判别器 D | 区分真假 | 给真图高分,给假图低分 |
| 生成器 G | 以假乱真 | 让 D 给生成图高分 |
| 理想平衡点 | D 完全无法区分,G 生成完美 |
第3步:VAE —— 变分自编码器
VAE 的学习目标与 GAN 完全不同:它不是对抗博弈,而是最大化数据的证据下界(ELBO)。
VAE 的数学核心:
3.1 VAE 网络结构
class VAE(nn.Module):
def __init__(self, latent_dim=20):
# --- 编码器:x → μ 和 log(σ²) ---
self.encoder = nn.Sequential(
nn.Linear(784, 512), nn.ReLU(),
nn.Linear(512, 256), nn.ReLU(),
)
self.fc_mu = nn.Linear(256, latent_dim) # 均值 μ
self.fc_logvar = nn.Linear(256, latent_dim) # log(σ²) —— 为什么是 log?
# --- 解码器:z → 重建 x̂ ---
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256), nn.ReLU(),
nn.Linear(256, 512), nn.ReLU(),
nn.Linear(512, 784), nn.Sigmoid(), # Sigmoid 输出 [0, 1]
)为什么编码器输出 logvar 而不是
这在数学上比约束式输出更干净,训练更稳定。
3.2 重参数化技巧 —— VAE 的关键创新
如果直接从
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar) # σ = e^(0.5 * log(σ²))
eps = torch.randn_like(std) # ε ~ N(0, 1)
z = mu + std * eps # z = μ + σ ⊙ ε
return z梯度流分析:
3.3 VAE 损失函数
def vae_loss(x_recon, x, mu, logvar):
# ---- 重构损失:二元交叉熵 ----
x_target = (x.view(x.size(0), -1) + 1) / 2.0 # [-1,1] → [0,1]
recon_loss = F.binary_cross_entropy(x_recon, x_target, reduction='sum') / x.size(0)
# ---- KL 散度(解析解)----
# 对于 q = N(μ, σ²), p = N(0, 1):
# D_KL(q || p) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
total_loss = recon_loss + kl_loss
return total_loss, recon_loss, kl_lossKL 散度解析公式的推导:
对于两个
KL 散度的作用:
- 如果
且 (后验 = 先验),KL = 0 - 如果
远离 0 或 远离 1,KL >> 0 - KL 正则化迫使潜空间保持结构化和平滑——两个相似的
映射到相近的
3.4 VAE 训练循环
for imgs, _ in dataloader:
imgs = imgs.to(device)
optimizer.zero_grad()
x_recon, mu, logvar = model(imgs) # 前向传播
total_loss, recon_loss, kl_loss = vae_loss( # 计算损失
x_recon, imgs, mu, logvar
)
total_loss.backward() # 梯度可以穿过重参数化
optimizer.step()VAE 训练比 GAN 稳定得多——只有一个优化器,一个损失函数,没有博弈过程。
第4步:可视化 —— 四种对比图
GAN 生成样本:从随机噪声生成 16 张数字图像,排列成 4x4 网格。观察:
- 生成质量:数字是否清晰可辨
- 多样性:是否覆盖了 0-9 全部数字(还是模式坍塌只生成少数几个)
VAE 重建对比:上方原始图像 vs 下方 VAE 重建。典型情况是重建比原始模糊——这是 VAE 的"均值化效应":逐像素 BCE 损失倾向于预测像素值的条件期望,导致模糊。
VAE 潜空间 t-SNE:将测试集图像通过 VAE 编码器得到的
训练曲线对比:三张子图:
- GAN 的 G Loss 和 D Loss 曲线 —— 理想状态是两个 loss 在合理范围内波动,不收敛也不发散
- VAE 的总损失 / 重构损失 / KL 散度 —— 重构损失下降、KL 散度上升后稳定
- 文字对比总结 —— GAN vs VAE 的特性
关键概念速查表
| 概念 | 公式 | 代码对应 |
|---|---|---|
| GAN 极小极大 | 交替训练 D 和 G | |
| D 的损失 | BCE(d_real_pred, 1) + BCE(d_fake_pred, 0) | |
| G 的损失 | BCE(gen_pred, 1) | |
| VAE ELBO | recon_loss + kl_loss | |
| 重参数化 | mu + std * eps | |
| KL 散度(高斯) | -0.5 * torch.sum(1 + logvar - mu^2 - exp(logvar)) | |
| 模式坍塌 | G 只生成少数模式 | 生成样本多样性低 |
| LeakyReLU | nn.LeakyReLU(0.2) |
完整代码
# -*- coding: utf-8 -*-
"""
s13 图像生成 demo:从零实现 GAN 和 VAE 用于 MNIST 数字生成
===========================================================
使用 PyTorch 实现简单的 GAN 和 VAE,在 MNIST 上训练,
对比两种生成方法的效果。
运行方式:python demo.py(从 s13_image_generation/code/ 目录运行)
依赖:torch, torchvision, matplotlib, numpy
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# GPU 自动检测
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"使用设备: {DEVICE}")
if DEVICE.type == 'cpu':
print("(未检测到 GPU,使用 CPU 运行。如有 GPU,请安装 CUDA 版 PyTorch 以获得加速)")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
matplotlib.rcParams['axes.unicode_minus'] = False
import numpy as np
import os
import time
# 图片保存目录:固定为本章节的 images/ 目录(相对于本脚本的 ../images/)
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_IMAGES_DIR = os.path.join(_SCRIPT_DIR, '..', 'images')
os.makedirs(_IMAGES_DIR, exist_ok=True)
# ============================================================
# 第 0 部分:通用工具
# ============================================================
def load_mnist(batch_size: int = 128) -> DataLoader:
"""
加载 MNIST 数据集
参数:
batch_size: 批大小
返回:
train_loader: 训练数据加载器(仅含图像,不需要标签)
"""
transform = transforms.Compose([
transforms.ToTensor(),
# MNIST 图像是 [0,1],GAN 的 tanh 输出是 [-1,1]
# 因此将图像归一化到 [-1, 1]
transforms.Normalize((0.5,), (0.5,)),
])
try:
train_set = torchvision.datasets.MNIST(
root='../data', train=True, download=True,
transform=transform
)
except Exception as e:
print(f"[警告] MNIST 下载失败 ({e}),使用合成数据")
# 回退:合成 28x28 单通道图像
from torch.utils.data import TensorDataset
np.random.seed(42)
synth_X = torch.rand(10000, 1, 28, 28) * 2 - 1 # [-1, 1]
synth_y = torch.randint(0, 10, (10000,))
train_set = TensorDataset(synth_X, synth_y)
train_loader = DataLoader(
train_set, batch_size=batch_size, shuffle=True,
num_workers=0, drop_last=True
)
return train_loader
def to_image(tensor: torch.Tensor) -> np.ndarray:
"""
将张量转换为可显示的图像数组
参数:
tensor: 形状 (C, H, W) 或 (N, C, H, W),值范围 [-1, 1]
返回:
img: numpy 数组,值范围 [0, 1]
"""
# 反归一化: [-1, 1] → [0, 1]
img = (tensor.detach().cpu().numpy() + 1) / 2.0
img = np.clip(img, 0, 1)
if img.ndim == 4 and img.shape[1] == 1:
img = img[:, 0, :, :] # (N, 1, H, W) → (N, H, W)
elif img.ndim == 3 and img.shape[0] == 1:
img = img[0] # (1, H, W) → (H, W)
return img
# ============================================================
# 第 1 部分:GAN —— 生成对抗网络
# ============================================================
class Generator(nn.Module):
"""
GAN 生成器
将随机噪声 z 映射为一张 28×28 的 MNIST 图像。
架构: FC(128→256)→BN→ReLU → FC(256→512)→BN→ReLU → FC(512→784)→Tanh
最后 reshape 为 (1, 28, 28),Tanh 输出范围 [-1, 1] 与归一化的图像匹配。
"""
def __init__(self, latent_dim: int = 128):
"""
初始化生成器
参数:
latent_dim: 输入噪声 z 的维度
"""
super(Generator, self).__init__()
self.latent_dim = latent_dim
# ---------- 构建全连接网络:逐步放大维度 ----------
self.model = nn.Sequential(
# Block 1: latent_dim → 256
nn.Linear(latent_dim, 256),
nn.BatchNorm1d(256), # BN 稳定训练,加速收敛
nn.ReLU(inplace=True),
# Block 2: 256 → 512
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
# Block 3: 512 → 1024
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(inplace=True),
# Block 4: 1024 → 784 (MNIST 像素数)
nn.Linear(1024, 784),
nn.Tanh(), # 输出范围 [-1, 1],与 MNIST 归一化一致
)
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
前向传播:噪声 → 图像
参数:
z: 随机噪声,形状 (N, latent_dim)
返回:
img: 生成的图像,形状 (N, 1, 28, 28),值范围 [-1, 1]
"""
img = self.model(z) # (N, 784)
img = img.view(-1, 1, 28, 28) # reshape 为图像形状
return img
class Discriminator(nn.Module):
"""
GAN 判别器
判断输入图像是真实图像(来自 MNIST)还是生成器伪造的假图像。
架构: FC(784→512)→LeakyReLU → FC(512→256)→LeakyReLU → FC(256→1)→Sigmoid
输出一个 [0, 1] 的标量,表示图像为真的概率。
"""
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
# Block 1: 784 → 512
nn.Linear(784, 512),
nn.LeakyReLU(0.2, inplace=True), # 使用 LeakyReLU 防止 dead neurons
# Block 2: 512 → 256
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
# Block 3: 256 → 1
nn.Linear(256, 1),
nn.Sigmoid(), # 输出概率 [0, 1]
)
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""
前向传播:图像 → 真实性概率
参数:
img: 输入图像,形状 (N, 1, 28, 28) 或 (N, 784)
返回:
validity: 图像为真的概率,形状 (N, 1)
"""
# 展平图像: (N, 1, 28, 28) → (N, 784)
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
def train_gan(dataloader: DataLoader, device: torch.device,
n_epochs: int = 50, latent_dim: int = 128) -> dict:
"""
训练 GAN
参数:
dataloader: MNIST 数据加载器
device: 计算设备
n_epochs: 训练轮数
latent_dim: 潜变量维度
返回:
history: 包含每 epoch 的 G loss 和 D loss
"""
print(f"\n {'='*50}")
print(f" 训练 GAN (epochs={n_epochs})")
print(f" {'='*50}")
# ---------- 初始化模型 ----------
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
print(f" Generator 参数: {sum(p.numel() for p in generator.parameters()):,}")
print(f" Discriminator 参数: {sum(p.numel() for p in discriminator.parameters()):,}")
# ---------- 损失函数和优化器 ----------
adversarial_loss = nn.BCELoss() # 二元交叉熵损失
# 两个独立的优化器(交替训练)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# ---------- 训练循环 ----------
history = {"g_loss": [], "d_loss": []}
fixed_noise = torch.randn(16, latent_dim, device=device) # 用于定期可视化
for epoch in range(1, n_epochs + 1):
epoch_g_loss = 0.0
epoch_d_loss = 0.0
n_batches = 0
for i, (imgs, _) in enumerate(dataloader):
batch_size = imgs.size(0)
real_imgs = imgs.to(device)
# 创建标签(真实=1,假=0)
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)
# ========== 训练判别器 D ==========
optimizer_D.zero_grad()
# 真实图像的损失:D(real_img) → 1
real_pred = discriminator(real_imgs)
d_real_loss = adversarial_loss(real_pred, real_labels)
# 假图像的损失:D(G(z)) → 0
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z) # 生成假图像
fake_pred = discriminator(fake_imgs.detach()) # detach() 防止梯度传回 G
d_fake_loss = adversarial_loss(fake_pred, fake_labels)
# 判别器的总损失 = 真实损失 + 假损失
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# ========== 训练生成器 G ==========
optimizer_G.zero_grad()
# 生成器的目标:让判别器认为假图像是真的 D(G(z)) → 1
z = torch.randn(batch_size, latent_dim, device=device)
gen_imgs = generator(z)
gen_pred = discriminator(gen_imgs) # 注意:这里不用 detach()
g_loss = adversarial_loss(gen_pred, real_labels) # 目标是"真实"
g_loss.backward()
optimizer_G.step()
epoch_g_loss += g_loss.item()
epoch_d_loss += d_loss.item()
n_batches += 1
# 记录 epoch 平均损失
avg_g_loss = epoch_g_loss / n_batches
avg_d_loss = epoch_d_loss / n_batches
history["g_loss"].append(avg_g_loss)
history["d_loss"].append(avg_d_loss)
if epoch % 5 == 0 or epoch == 1:
print(f" Epoch {epoch:3d}/{n_epochs} | "
f"D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f}")
return history, generator, discriminator
# ============================================================
# 第 2 部分:VAE —— 变分自编码器
# ============================================================
class VAE(nn.Module):
"""
变分自编码器(VAE)
包含:
- 编码器: 输入 x → μ 和 log(σ²)
- 重参数化: z = μ + σ ⊙ ε, ε ~ N(0, 1)
- 解码器: z → 重建图像 x̂
损失 = 重构损失 (MSE/BCE) + KL 散度 (D_KL(q(z|x) || p(z)))
"""
def __init__(self, latent_dim: int = 20):
"""
初始化 VAE
参数:
latent_dim: 潜变量 z 的维度
"""
super(VAE, self).__init__()
self.latent_dim = latent_dim
# ---------- 编码器: x (784) → μ (latent_dim), logvar (latent_dim) ----------
self.encoder = nn.Sequential(
nn.Linear(784, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 256),
nn.ReLU(inplace=True),
)
# μ 和 log(σ²) 分别由两个独立的全连接层预测
self.fc_mu = nn.Linear(256, latent_dim) # 均值 μ
self.fc_logvar = nn.Linear(256, latent_dim) # log(σ²),用 log 保证 σ² > 0
# ---------- 解码器: z (latent_dim) → 重建 x̂ (784) ----------
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 784),
nn.Sigmoid(), # 输出 [0, 1],对应归一化的像素值
)
def encode(self, x: torch.Tensor) -> tuple:
"""
编码:输入图像 → μ 和 log(σ²)
参数:
x: 输入图像展平,形状 (N, 784)
返回:
(mu, logvar): 均值和 log 方差,形状均为 (N, latent_dim)
"""
h = self.encoder(x) # 共享的特征提取
mu = self.fc_mu(h) # 预测均值
logvar = self.fc_logvar(h) # 预测 log(σ²)
return mu, logvar
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
"""
重参数化技巧: z = μ + σ ⊙ ε
这是 VAE 的核心创新。直接从 N(μ, σ²) 采样 z 是不可微的,
通过将随机性"外包"给 ε ~ N(0,1),使得 z 对 μ 和 σ 可微。
参数:
mu: 均值,形状 (N, latent_dim)
logvar: log(σ²),形状 (N, latent_dim)
返回:
z: 采样后的潜变量,形状 (N, latent_dim)
"""
std = torch.exp(0.5 * logvar) # σ = exp(0.5 * log(σ²))
eps = torch.randn_like(std) # ε ~ N(0, 1)
z = mu + std * eps # 重参数化: z = μ + σ ⊙ ε
return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""
解码:潜变量 z → 重建图像 x̂
参数:
z: 潜变量,形状 (N, latent_dim)
返回:
x_recon: 重建图像,形状 (N, 784),值范围 [0, 1]
"""
return self.decoder(z)
def forward(self, x: torch.Tensor) -> tuple:
"""
VAE 完整前向传播
参数:
x: 输入图像,形状 (N, 1, 28, 28)
返回:
(x_recon, mu, logvar):
- x_recon: 重建图像,形状 (N, 784)
- mu: 编码均值
- logvar: 编码 log 方差
"""
# 展平图像: (N, 1, 28, 28) → (N, 784)
x_flat = x.view(x.size(0), -1)
# 编码 → μ, log(σ²)
mu, logvar = self.encode(x_flat)
# 重参数化采样 z
z = self.reparameterize(mu, logvar)
# 解码 → 重建
x_recon = self.decode(z)
return x_recon, mu, logvar
def vae_loss(x_recon: torch.Tensor, x: torch.Tensor,
mu: torch.Tensor, logvar: torch.Tensor) -> tuple:
"""
计算 VAE 的损失函数
L_VAE = 重构损失 + KL 散度
KL 散度的解析形式(高斯分布):
D_KL( N(μ,σ²) || N(0,1) ) = -0.5 * sum(1 + log(σ²) - μ² - σ²)
参数:
x_recon: 重建图像,形状 (N, 784)
x: 原始图像展平,形状 (N, 784)
mu: 编码均值,形状 (N, latent_dim)
logvar: 编码 log 方差,形状 (N, latent_dim)
返回:
(total_loss, recon_loss, kl_loss)
"""
# ---------- 重构损失:二元交叉熵(适用于 [0,1] 范围的图像)----------
# MNIST 图像经过 Normalize((0.5,),(0.5,)) 后值域为 [-1, 1],
# 而 VAE 解码器的 Sigmoid 输出为 [0, 1],BCE 要求 target ∈ [0,1],
# 因此需要将目标图像从 [-1, 1] 反归一化回 [0, 1]
x_target = (x.view(x.size(0), -1) + 1) / 2.0 # [-1, 1] → [0, 1]
recon_loss = F.binary_cross_entropy(x_recon, x_target,
reduction='sum') / x.size(0)
# ---------- KL 散度(解析解)----------
# KL = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
# ---------- 总损失 ----------
total_loss = recon_loss + kl_loss
return total_loss, recon_loss, kl_loss
def train_vae(dataloader: DataLoader, device: torch.device,
n_epochs: int = 30, latent_dim: int = 20) -> dict:
"""
训练 VAE
参数:
dataloader: MNIST 数据加载器
device: 计算设备
n_epochs: 训练轮数
latent_dim: 潜变量维度
返回:
history: 包含每 epoch 的损失
"""
print(f"\n {'='*50}")
print(f" 训练 VAE (epochs={n_epochs}, latent_dim={latent_dim})")
print(f" {'='*50}")
model = VAE(latent_dim).to(device)
print(f" VAE 参数: {sum(p.numel() for p in model.parameters()):,}")
optimizer = optim.Adam(model.parameters(), lr=1e-3)
history = {"total_loss": [], "recon_loss": [], "kl_loss": []}
for epoch in range(1, n_epochs + 1):
model.train()
epoch_total = 0.0
epoch_recon = 0.0
epoch_kl = 0.0
n_batches = 0
for imgs, _ in dataloader:
imgs = imgs.to(device)
optimizer.zero_grad()
x_recon, mu, logvar = model(imgs)
total_loss, recon_loss, kl_loss = vae_loss(x_recon, imgs, mu, logvar)
total_loss.backward()
optimizer.step()
epoch_total += total_loss.item()
epoch_recon += recon_loss.item()
epoch_kl += kl_loss.item()
n_batches += 1
history["total_loss"].append(epoch_total / n_batches)
history["recon_loss"].append(epoch_recon / n_batches)
history["kl_loss"].append(epoch_kl / n_batches)
if epoch % 5 == 0 or epoch == 1:
print(f" Epoch {epoch:3d}/{n_epochs} | "
f"Total: {epoch_total/n_batches:.4f} | "
f"Recon: {epoch_recon/n_batches:.4f} | "
f"KL: {epoch_kl/n_batches:.4f}")
return history, model
# ============================================================
# 第 3 部分:可视化工具
# ============================================================
def visualize_generated_samples(generator, device, latent_dim,
save_path: str, n_samples: int = 16):
"""
可视化 GAN 生成的图像样本
参数:
generator: 训练好的 GAN 生成器
device: 计算设备
latent_dim: 潜变量维度
save_path: 保存路径
n_samples: 生成的样本数
"""
generator.eval()
with torch.no_grad():
z = torch.randn(n_samples, latent_dim, device=device)
samples = generator(z)
samples = to_image(samples) # (N, H, W)
# 排列为网格
ncols = 4
nrows = int(np.ceil(n_samples / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2))
axes = axes.flatten()
for i, ax in enumerate(axes):
if i < n_samples:
ax.imshow(samples[i], cmap='gray')
ax.axis('off')
plt.suptitle('GAN Generated Digits', fontsize=14)
plt.tight_layout()
plt.savefig(save_path, dpi=120, bbox_inches='tight')
plt.close()
print(f" [可视化] GAN 生成的样本已保存到 {save_path}")
def visualize_vae_reconstructions(model, test_loader, device,
save_path: str, n_samples: int = 8):
"""
可视化 VAE 重建结果(原始图像 vs 重建图像)
参数:
model: 训练好的 VAE 模型
test_loader: 测试数据加载器
device: 计算设备
save_path: 保存路径
n_samples: 显示的样本数
"""
model.eval()
imgs, _ = next(iter(test_loader))
imgs = imgs[:n_samples].to(device)
with torch.no_grad():
x_recon, mu, logvar = model(imgs)
x_recon = x_recon.view(n_samples, 1, 28, 28)
originals = to_image(imgs)
recons = to_image(x_recon)
fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 1.5, 3))
for i in range(n_samples):
axes[0, i].imshow(originals[i], cmap='gray')
axes[0, i].axis('off')
if i == 0:
axes[0, i].set_ylabel('Original', fontsize=10)
axes[1, i].imshow(recons[i], cmap='gray')
axes[1, i].axis('off')
if i == 0:
axes[1, i].set_ylabel('Reconstructed', fontsize=10)
plt.suptitle('VAE Reconstruction (Top: Original, Bottom: Reconstructed)', fontsize=14)
plt.tight_layout()
plt.savefig(save_path, dpi=120, bbox_inches='tight')
plt.close()
print(f" [可视化] VAE 重建结果已保存到 {save_path}")
def visualize_vae_latent_space(model, test_loader, device,
save_path: str, n_samples: int = 1000):
"""
可视化 VAE 的潜空间(2D t-SNE 投影)
参数:
model: 训练好的 VAE
test_loader: 测试数据加载器
device: 计算设备
save_path: 保存路径
n_samples: 采样的潜变量数量
"""
try:
from sklearn.manifold import TSNE
except ImportError:
print(" [跳过] 潜空间可视化需要 scikit-learn: pip install scikit-learn")
return
model.eval()
latent_vectors = []
labels = []
with torch.no_grad():
for imgs, targets in test_loader:
imgs = imgs.to(device)
x_flat = imgs.view(imgs.size(0), -1)
mu, logvar = model.encode(x_flat)
latent_vectors.append(mu.cpu().numpy())
labels.append(targets.numpy())
if len(latent_vectors) * imgs.size(0) >= n_samples:
break
latent_vectors = np.concatenate(latent_vectors)[:n_samples]
labels = np.concatenate(labels)[:n_samples]
# t-SNE 降维到 2D
print(" 正在运行 t-SNE 降维(可能需要几秒)...")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
latent_2d = tsne.fit_transform(latent_vectors)
# 绘制
fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(latent_2d[:, 0], latent_2d[:, 1],
c=labels, cmap='tab10', alpha=0.6, s=10)
plt.colorbar(scatter, ticks=range(10), label='Digit Class')
ax.set_title('t-SNE Projection of VAE Latent Space', fontsize=14)
ax.set_xlabel('t-SNE Dimension 1')
ax.set_ylabel('t-SNE Dimension 2')
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
print(f" [可视化] VAE 潜空间 t-SNE 已保存到 {save_path}")
def plot_training_curves(gan_history: dict, vae_history: dict,
save_dir: str):
"""
绘制训练曲线对比图
参数:
gan_history: GAN 训练历史(含 g_loss, d_loss)
vae_history: VAE 训练历史(含 total_loss, recon_loss, kl_loss)
save_dir: 保存目录
"""
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# ---- GAN 损失曲线 ----
ax = axes[0]
epochs = range(1, len(gan_history["g_loss"]) + 1)
ax.plot(epochs, gan_history["g_loss"], 'b-', label='G Loss', linewidth=1.5)
ax.plot(epochs, gan_history["d_loss"], 'r-', label='D Loss', linewidth=1.5)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('GAN Training Curves')
ax.legend()
ax.grid(True, alpha=0.3)
# ---- VAE 损失曲线 ----
ax = axes[1]
epochs = range(1, len(vae_history["total_loss"]) + 1)
ax.plot(epochs, vae_history["total_loss"], 'purple', label='Total', linewidth=1.5)
ax.plot(epochs, vae_history["recon_loss"], 'orange', label='Recon', linewidth=1)
ax.plot(epochs, vae_history["kl_loss"], 'green', label='KL', linewidth=1)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('VAE Training Curves')
ax.legend()
ax.grid(True, alpha=0.3)
# ---- 生成质量直观对比说明 ----
ax = axes[2]
ax.axis('off')
comparison_text = (
"GAN vs VAE Comparison\n\n"
"GAN:\n"
" - Sharper images (adversarial training optimizes visual quality)\n"
" - No explicit latent space structure\n"
" - Unstable training (needs to balance G and D)\n"
" - May suffer from mode collapse\n\n"
"VAE:\n"
" - Blurrier images (pixel-wise MSE/BCE averaging effect)\n"
" - Smooth structured latent space (enables interpolation)\n"
" - Stable training (explicit optimization objective)\n"
" - Better coverage of data distribution"
)
ax.text(0.05, 0.95, comparison_text, transform=ax.transAxes,
fontsize=9, verticalalignment='top', fontfamily='monospace',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
plt.tight_layout()
save_path = os.path.join(save_dir, 'training_curves.png')
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
print(f" [可视化] 训练曲线已保存到 {save_path}")
# ============================================================
# 第 4 部分:主函数
# ============================================================
def main():
"""主函数:训练 GAN 和 VAE,生成对比可视化"""
print("=" * 60)
print("s13 图像生成 Demo")
print("GAN vs VAE: MNIST 数字生成对比")
print("=" * 60)
# ---------- 设备选择 ----------
device = DEVICE
print(f"\n计算设备: {device}")
# ---------- 加载数据 ----------
print("\n[1/5] 加载 MNIST 数据集...")
batch_size = 128
train_loader = load_mnist(batch_size)
# 测试集(VAE 重建可视化用)
try:
test_set = torchvision.datasets.MNIST(
root='../data', train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
)
except Exception as e:
print(f"[警告] MNIST 测试集下载失败 ({e}),使用合成数据")
from torch.utils.data import TensorDataset
synth_X = torch.rand(1000, 1, 28, 28) * 2 - 1
synth_y = torch.randint(0, 10, (1000,))
test_set = TensorDataset(synth_X, synth_y)
test_loader = DataLoader(test_set, batch_size=16, shuffle=True, num_workers=0)
# ---------- 训练 GAN ----------
print("\n[2/5] 训练 GAN...")
if device.type == 'cpu':
n_gan_epochs = 2
n_train_subset = 1000
print("[配置] CPU 模式:使用轻量参数快速演示(GAN 2 epochs, 1000 样本)。GPU 模式下将使用完整训练配置。")
# 缩减训练集(仅当数据集有 .data 属性时,如标准 MNIST)
if hasattr(train_loader.dataset, 'data'):
train_loader.dataset.data = train_loader.dataset.data[:n_train_subset]
if hasattr(train_loader.dataset, 'targets'):
train_loader.dataset.targets = train_loader.dataset.targets[:n_train_subset]
else:
n_gan_epochs = 30
print(f"[配置] 训练 {n_gan_epochs} 个 epoch")
gan_history, generator, discriminator = train_gan(
train_loader, device, n_epochs=n_gan_epochs, latent_dim=128
)
# ---------- 训练 VAE ----------
print("\n[3/5] 训练 VAE...")
# 创建新的 train_loader(因为之前的被消耗了)
train_loader_vae = load_mnist(batch_size)
# VAE 使用 [0,1] 范围的图像(Sigmoid 输出),需要重新处理
# 简化处理:VAE 内部接受 [-1,1] 输入但输出 Sigmoid [0,1],损失用 BCE
if device.type == 'cpu':
n_vae_epochs = 2
if hasattr(train_loader_vae.dataset, 'data'):
train_loader_vae.dataset.data = train_loader_vae.dataset.data[:n_train_subset]
if hasattr(train_loader_vae.dataset, 'targets'):
train_loader_vae.dataset.targets = train_loader_vae.dataset.targets[:n_train_subset]
print(f"[配置] VAE 训练 {n_vae_epochs} 个 epoch(CPU 模式时长较短)")
else:
n_vae_epochs = 30
print(f"[配置] VAE 训练 {n_vae_epochs} 个 epoch")
vae_history, vae_model = train_vae(
train_loader_vae, device, n_epochs=n_vae_epochs, latent_dim=20
)
# ---------- 可视化 ----------
print("\n[4/5] 生成可视化...")
output_dir = _IMAGES_DIR
# GAN 生成的样本
visualize_generated_samples(generator, device, latent_dim=128,
save_path=os.path.join(output_dir, "gan_samples.png"))
# VAE 重建结果
visualize_vae_reconstructions(vae_model, test_loader, device,
save_path=os.path.join(output_dir, "vae_reconstructions.png"))
# VAE 潜空间可视化
visualize_vae_latent_space(vae_model, test_loader, device,
save_path=os.path.join(output_dir, "vae_latent_space.png"))
# 训练曲线对比
plot_training_curves(gan_history, vae_history, output_dir)
# ---------- 总结 ----------
print("\n[5/5] 总结")
print("=" * 60)
print("GAN:")
print(f" 最终 G Loss: {gan_history['g_loss'][-1]:.4f}")
print(f" 最终 D Loss: {gan_history['d_loss'][-1]:.4f}")
print(f" (理想状态: D Loss ≈ 0.693, G Loss 适中 — D 无法分辨真伪)")
print(f"\nVAE:")
print(f" 最终 Total Loss: {vae_history['total_loss'][-1]:.4f}")
print(f" 重构损失: {vae_history['recon_loss'][-1]:.4f}")
print(f" KL 散度: {vae_history['kl_loss'][-1]:.4f}")
print(f" (KL 散度越小,潜空间越接近标准正态分布)")
print("=" * 60)
print(f"\nDemo 完成!查看 {_IMAGES_DIR} 目录下的可视化结果。")
print("=" * 60)
if __name__ == "__main__":
main()