Skip to content

s11 经典CNN架构演进 — demo.py 代码详解

Download demo.py

运行方式

bash
cd s11_cnn_architectures/code
python demo.py

代码逐段详解

第1步:导入库 —— 每个库是做什么的

python
import torch                          # PyTorch 核心框架,提供张量计算和自动求导
import torch.nn as nn                 # 神经网络模块(Linear, Conv2d, BatchNorm等)
import torch.nn.functional as F       # 函数式接口(relu, cross_entropy等,无状态)
import torch.optim as optim           # 优化器(SGD, Adam)和学习率调度器
import torchvision                    # 视觉库,提供预训练模型和数据集(CIFAR-10)
import torchvision.transforms as transforms  # 数据预处理/增强管道

# GPU 自动检测:CUDA(NVIDIA) > MPS(Apple Silicon) > CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
| 库 | 角色 |
| --- | --- |
| torch | 张量计算引擎 + 自动求导 |
| torch.nn | 有状态(含可学习参数)的网络层 |
| torch.nn.functional | 无状态的函数(激活函数、池化等) |
| torch.optim | 优化器(SGD/Adam)+ 学习率衰减策略 |
| torchvision | 视觉数据集 + 预训练模型 |
| torchvision.transforms | 图像增广管道(裁剪、翻转、归一化) |

第2步:数据准备 —— CIFAR-10 数据集

python
def get_cifar10_loaders(batch_size: int = 128):
    # 训练数据增强:随机裁剪+填充(抗过拟合)、水平翻转、归一化
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),     # 32x32 图像先填充到 40x40,再随机裁回 32x32
        transforms.RandomHorizontalFlip(),         # 50% 概率水平翻转
        transforms.ToTensor(),                     # PIL Image → Tensor,值域 [0,1]
        transforms.Normalize((0.4914, 0.4822, 0.4465),  # CIFAR-10 各通道均值
                             (0.2023, 0.1994, 0.2010)),  # CIFAR-10 各通道标准差
    ])
    # 测试数据:仅归一化,不做增强
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(...),
    ])

为什么使用数据增强? CIFAR-10 只有 50000 张训练图。随机裁剪和翻转在不增加存储的前提下增大了有效数据量,迫使模型学习平移/翻转不变性,是 AlexNet 时代延续至今的标准做法。

为什么用这些均值和标准差? 它们是整个 CIFAR-10 训练集各通道的统计量。归一化后每个通道的输入分布接近 N(0,1),梯度更稳定,收敛更快。

回退机制:如果 CIFAR-10 下载失败(如网络问题),代码自动创建合成数据确保 demo 可运行。CPU 模式下将训练集缩减到 1000 样本以在合理时间内完成演示。

第3步:ResNet BasicBlock —— 残差学习的核心

这是 ResNet 最关键的组件。让我们逐行剖析:

python
class BasicBlock(nn.Module):
    expansion = 1  # BasicBlock 不改变通道数倍数(对比 Bottleneck 的 expansion=4)

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        super(BasicBlock, self).__init__()

        # 第一个 3×3 卷积:输入 in_planes 通道 → 输出 planes 通道
        # padding=1 保持空间尺寸不变(输入 H×W → 输出 H×W,stride=1 时)
        # bias=False 因为后面紧跟 BatchNorm(BN 自带 bias 项,卷积的 bias 浪费)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)   # 批归一化:稳定训练、加速收敛

        # 第二个 3×3 卷积:planes → planes(同维度卷积)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 跳跃连接(Skip Connection / Identity Shortcut)
        # 如果输入输出维度不匹配,用 1×1 卷积对齐
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1,
                          stride=stride, bias=False),  # 1×1 卷积仅用于维度对齐
                nn.BatchNorm2d(planes),
            )

为什么 bias=False? 卷积层的 bias 为每个输出通道增加一个可学习偏置。但 BatchNorm 的计算中包含 $\frac{x-\mu}{\sigma}\gamma+\beta$,其中 $\beta$ 本身就是偏置项,而且减去均值 $\mu$ 会把卷积 bias 这类逐通道常数直接抵消掉。卷积后紧接 BN 时,卷积的 bias 会被 BN 的 $\beta$ 完全替代,设置 bias=False 减少无意义的参数。

为什么 shortcut 用 1×1 卷积? 当 stride != 1(空间尺寸减半)或 in_planes != planes(通道数变化)时,identity x 的维度和主路径输出 F(x) 不一致,无法直接相加。带相同 stride 的 1×1 卷积可以同时完成通道数和空间尺寸的对齐,计算开销极小。

前向传播 —— 残差公式的实现

python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    identity = self.shortcut(x)          # 跳跃连接分支

    out = self.conv1(x)                  # Conv 3×3
    out = self.bn1(out)                  # BatchNorm
    out = F.relu(out)                    # ReLU 激活

    out = self.conv2(out)                # Conv 3×3
    out = self.bn2(out)                  # BatchNorm —— 注意:这里还没加 ReLU!

    out += identity                      # 🔑 残差加法 H(x) = F(x) + x
    out = F.relu(out)                    # 加法之后才做 ReLU

    return out

关键细节:残差加法在第二个 BN 之后、第二个 ReLU 之前。 为什么?如果把 ReLU 放在加法之前(即 F(x) 做完 ReLU 再加 x),F(x) 的输出被钳制为非负值(≥ 0),残差分支只能"增大"而无法"减小" x,限制了残差映射的表达能力。加法后再 ReLU 让 F(x)+x 整体通过非线性,保留了更丰富的特征。

数学对应

$$y=\mathrm{ReLU}\big(\mathrm{BN}(\mathrm{Conv}(\mathrm{ReLU}(\mathrm{BN}(\mathrm{Conv}(x)))))+x\big)$$

即:

$$y=\sigma\big(F(x,\{W_i\})+x\big)$$

其中 σ 是 ReLU,F 是两层卷积+BN 组成的残差函数。

梯度高速公路效应

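$$\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}\cdot\frac{\partial y}{\partial x}=\frac{\partial L}{\partial y}\left(1+\frac{\partial F}{\partial x}\right)$$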

那个 "+1" 是残差连接对梯度传播的核心贡献:即使 $\frac{\partial F}{\partial x}$ 非常小(梯度衰减),梯度也能通过恒等路径无损传播。这就是为什么 ResNet 可以训练 152 层而不退化。

第4步:Bottleneck —— 深层 ResNet 的效率秘诀

python
class Bottleneck(nn.Module):
    expansion = 4  # 输出通道 = planes * 4

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        # 第1层: 1×1 降维
        #   in_planes → planes(如 256 → 64),大幅削减计算量
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 第2层: 3×3 卷积(主要特征提取)
        #   planes → planes,在低维度下计算
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 第3层: 1×1 升维
        #   planes → planes * 4(如 64 → 256),恢复高维度
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

为什么 Bottleneck 先降维再升维? 直接做 256→256 的 3×3 卷积,每个输出位置的乘加次数(FLOPs)为 9×256×256=589,824。先用 1×1 卷积降到 64 通道(256×64=16,384),做 3×3 卷积 9×64×64=36,864,再用 1×1 卷积升回 256 通道(64×256=16,384)。总计:两个 1×1 卷积共 256×64×2=32,768,加上 3×3 卷积的 36,864,约 70,000 —— 节省了约 88% 的计算量,而表达能力几乎不损失(1×1 卷积学习通道间的线性组合)。

Bottleneck 的形状像一个沙漏:两头粗(高通道数),中间细(低通道数)。这被用于 ResNet-50/101/152。

第5步:完整 ResNet —— 组装所有组件

python
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        self.in_planes = 64  # 初始通道数

        # 初始卷积:3通道 → 64通道,stride=1 适配 CIFAR-10(32×32 无需大下采样)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # 4个残差层,通道逐层翻倍,空间逐层减半
        self.layer1 = self._make_layer(block, 64,  num_blocks[0], stride=1)  # 64→64,  32×32保持
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)  # 64→128, 32→16
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)  # 128→256, 16→8
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)  # 256→512, 8→4

        # 全局平均池化:将 4×4 特征图压缩为 1×1,替代传统的大全连接层
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

ResNet 的通道-空间 trade-off:每经过一个下采样 layer,通道数翻倍、空间尺寸(宽和高)减半,使每层卷积的计算量大致保持不变(特征图从 64×32×32 逐级变为 128×16×16 → 256×8×8 → 512×4×4,元素总数每级减半)。这与 CNN 的通用设计哲学一致:浅层关注"哪里"(大空间尺寸、少通道),深层关注"是什么"(小空间尺寸、多通道)。

_make_layer 方法

python
def _make_layer(self, block, planes, num_blocks, stride):
    layers = []
    # 第一个 block 可能需要下采样
    layers.append(block(self.in_planes, planes, stride))
    self.in_planes = planes * block.expansion  # 更新全局通道数
    # 后续 block 保持尺寸
    for _ in range(1, num_blocks):
        layers.append(block(self.in_planes, planes, stride=1))
    return nn.Sequential(*layers)

Kaiming 初始化

python
def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                    nonlinearity='relu')

为什么用 Kaiming 初始化而非 Xavier? Xavier(Glorot)初始化假设激活函数是线性的(或 tanh),在 ReLU 下会导致输出方差逐层衰减。Kaiming(He)初始化专门针对 ReLU 设计——它考虑了 ReLU 将一半输入置零的特性,使得正向传播的方差保持稳定:

$$\mathrm{Var}(W)=\frac{2}{\text{fan\_out}}$$

mode='fan_out' 表示用输出侧的连接数(对卷积层即 输出通道数 × 卷积核面积,而非输入侧的 fan_in)来缩放,它保持的是反向传播时梯度的方差稳定;若要严格保持前向激活的方差,则对应 mode='fan_in'。

第6步:Plain CNN —— 对照模型

python
class PlainBlock(nn.Module):
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # ⚠️ 这里没有 out += x —— 缺少了跳跃连接!
        out = F.relu(out)
        return out

PlainCNN 的结构与 ResNet-18 完全一致,唯一的区别是:去掉了所有的跳跃连接(shortcut)。参数量相同,但优化难度截然不同。这个对照实验展示了:

  • Plain CNN:深度增加 → 梯度逐渐消失 → 训练误差上升(退化问题)
  • ResNet-18:梯度通过 shortcut 无损传播 → 深度越大效果越好

第7步:训练流程 —— 梯度范数监控

python
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()          # 清零梯度(PyTorch 默认累积梯度)
        outputs = model(inputs)        # 前向传播
        loss = criterion(outputs, targets)  # 交叉熵损失
        loss.backward()                # 反向传播,计算梯度
        optimizer.step()               # 更新参数

为什么用 SGD + Momentum 而不是 Adam? 对于 ResNet 训练,SGD with Momentum(momentum=0.9)配合 CosineAnnealing 学习率调度是经过验证的最佳实践。Adam 虽然收敛快,但在图像分类任务的最优精度上通常不如精心调参的 SGD。

梯度范数计算

python
def compute_gradient_norms(model):
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms[name] = param.grad.norm().item()  # L2 范数
    return grad_norms

梯度范数 $\lVert\nabla_W L\rVert_2$ 是衡量梯度流动的指标。ResNet 的梯度范数显著大于 Plain CNN,尤其是在深层——这是"梯度高速公路"的数值证据。

第8步:可视化 —— 对比分析

代码生成三张图:

  1. 训练 Loss 曲线:ResNet vs Plain CNN 的 loss 下降速度
  2. 测试准确率曲线:残差连接带来的精度提升
  3. 梯度范数对比:ResNet 梯度范数更大,证明 shortcut 维持了有效梯度

关键概念速查表

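| 概念 | 公式 | 代码对应 |
| --- | --- | --- |
| 残差块 | $y=F(x)+x$ | `out = bn2(conv2(...)); out += identity` |
| 梯度高速公路 | $\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}\left(1+\frac{\partial F}{\partial x}\right)$ | `loss.backward()` 自动传播 |
| BatchNorm | $\frac{x-\mu}{\sigma}\gamma+\beta$ | `nn.BatchNorm2d(planes)` |
| Kaiming 初始化 | $\mathrm{Var}(W)=\frac{2}{\text{fan\_out}}$ | `nn.init.kaiming_normal_` |
| Bottleneck | 1×1降维 → 3×3卷积 → 1×1升维 | `Bottleneck` |
| 退化问题 | 层数增加但训练误差不降反升 | PlainCNN 的糟糕表现 |
| 网络退化 | 不是过拟合(测试误差也上升),是优化困难 | ResNet 解决了这个问题 |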

完整代码

py
# -*- coding: utf-8 -*-
"""
s11 经典架构演进 demo:ResNet-18 从零实现与训练
================================================
使用 PyTorch 从零构建 ResNet-18,在 CIFAR-10 上训练,
与同深度的普通 CNN(无跳跃连接)进行对比。

运行方式:python demo.py(从 s11_cnn_architectures/code/ 目录运行)
依赖:torch, torchvision, matplotlib
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Pick the best available accelerator: CUDA (NVIDIA) > MPS (Apple Silicon) > CPU.
# `torch.backends.mps` is missing on older PyTorch builds, so it is looked up
# defensively instead of letting the import of this script fail with an
# AttributeError.
_MPS_BACKEND = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif _MPS_BACKEND is not None and _MPS_BACKEND.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"使用设备: {DEVICE}")
if DEVICE.type == 'cpu':
    # Tell the user why the run will be slow and how to get acceleration.
    print("(未检测到 GPU,使用 CPU 运行。如有 GPU,请安装 CUDA 版 PyTorch 以获得加速)")

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
matplotlib.rcParams['axes.unicode_minus'] = False
import numpy as np
import os
import time

# Figures are always written to this chapter's images/ directory, i.e. the
# "images" folder that sits one level above the directory holding this script.
# The directory is created at import time so later savefig() calls cannot fail
# on a missing folder.
_SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
_IMAGES_DIR = os.path.join(_SCRIPT_DIR, os.pardir, 'images')
os.makedirs(_IMAGES_DIR, exist_ok=True)
from collections import defaultdict

# ============================================================
# 第 1 部分:ResNet 组件实现
# ============================================================

class BasicBlock(nn.Module):
    """
    Basic two-convolution residual block.

    Data flow::

        x -> Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> (+ shortcut(x)) -> ReLU

    The shortcut is the identity when the input and output shapes already
    agree, and a strided 1x1 convolution followed by BatchNorm otherwise.
    """

    # Channel multiplier of the block output relative to ``planes``.
    expansion = 1

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        """
        Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of output channels.
            stride: Stride of the first convolution and of the projection
                shortcut; values above 1 downsample spatially.
        """
        super().__init__()

        # Both main-path convolutions are 3x3 with padding 1, so only the
        # stride changes the spatial size. The conv bias is disabled because
        # every convolution is followed by a BatchNorm layer, which carries
        # its own learnable shift.
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is only needed when x and F(x) would differ in shape.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            # An empty Sequential returns its input unchanged (identity).
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute ``ReLU(F(x) + shortcut(x))``.

        Args:
            x: Input feature map with ``in_planes`` channels.

        Returns:
            Feature map with ``planes`` channels; the spatial size is reduced
            by ``stride``.
        """
        # Residual branch F(x): Conv -> BN -> ReLU -> Conv -> BN.
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # The final ReLU is applied only after the addition, so the residual
        # branch itself may contribute negative values.
        return F.relu(residual + self.shortcut(x))


class Bottleneck(nn.Module):
    """
    Bottleneck residual block.

    Data flow::

        x -> Conv1x1 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU
          -> Conv1x1 -> BN -> (+ shortcut(x)) -> ReLU

    The first 1x1 convolution reduces the channel count, the 3x3 convolution
    works at that reduced width, and the last 1x1 convolution expands the
    result to ``planes * expansion`` channels.
    """

    # The block outputs ``planes * 4`` channels.
    expansion = 4

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        """
        Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Width of the bottleneck (channels of the middle 3x3
                convolution, before expansion).
            stride: Stride of the 3x3 convolution and of the projection
                shortcut.
        """
        super().__init__()
        out_planes = planes * self.expansion

        # 1x1 reduce: in_planes -> planes.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1,
                               stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3 at the reduced width; this is the only strided main-path conv.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 expand: planes -> planes * expansion.
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1,
                               stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Project the shortcut only when x and F(x) would differ in shape.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ReLU(F(x) + shortcut(x))`` for the bottleneck branch F."""
        branch = F.relu(self.bn1(self.conv1(x)))       # 1x1 reduce
        branch = F.relu(self.bn2(self.conv2(branch)))  # 3x3
        branch = self.bn3(self.conv3(branch))          # 1x1 expand, no ReLU yet
        return F.relu(branch + self.shortcut(x))


# ============================================================
# 第 2 部分:ResNet 完整模型
# ============================================================

class ResNet(nn.Module):
    """
    Generic ResNet assembled from a residual block class.

    Layout:
        - stem: 3x3 convolution to 64 channels with stride 1, then BN + ReLU
        - four stages (``layer1`` .. ``layer4``) of stacked blocks with base
          widths 64/128/256/512 and first-block strides 1/2/2/2
        - global average pooling followed by one linear classifier
    """

    def __init__(self, block: nn.Module, num_blocks: list,
                 num_classes: int = 10):
        """
        Build the network.

        Args:
            block: Block class to stack. It is called as
                ``block(in_planes, planes, stride)`` and must expose an
                ``expansion`` class attribute.
            num_blocks: Number of blocks in each of the four stages, e.g.
                ``[2, 2, 2, 2]``.
            num_classes: Size of the classifier output.
        """
        super().__init__()
        # Running channel count; _make_layer updates it after every stage.
        self.in_planes = 64

        # Stem keeps the input resolution (stride 1, padding 1).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (base width, stride of the first block) for the four stages.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        stages = [
            self._make_layer(block, width, num_blocks[idx], stride)
            for idx, (width, stride) in enumerate(stage_plan)
        ]
        self.layer1, self.layer2, self.layer3, self.layer4 = stages

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block: nn.Module, planes: int,
                    num_blocks: int, stride: int) -> nn.Sequential:
        """
        Stack the blocks of one stage.

        Args:
            block: Block class to instantiate.
            planes: Base width passed to every block of the stage.
            num_blocks: Number of blocks requested for the stage.
            stride: Stride of the first block; all later blocks use stride 1.

        Returns:
            The blocks wrapped in an ``nn.Sequential``.
        """
        # Only the first block may change resolution / channel count.
        head = block(self.in_planes, planes, stride)
        self.in_planes = planes * block.expansion
        tail = [block(self.in_planes, planes, stride=1)
                for _ in range(num_blocks - 1)]
        return nn.Sequential(head, *tail)

    def _initialize_weights(self):
        """Initialise conv (Kaiming normal), BatchNorm (1/0) and linear layers."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                        nonlinearity='relu')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a batch of images to class logits.

        Args:
            x: Image batch with 3 channels.

        Returns:
            Logits of shape ``(N, num_classes)``.
        """
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)             # (N, C, 1, 1)
        flat = pooled.view(pooled.size(0), -1)      # (N, C)
        return self.fc(flat)


def ResNet18(num_classes: int = 10) -> ResNet:
    """Build ResNet-18: BasicBlock stages of depth 2, 2, 2, 2 (8 blocks)."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes)


def ResNet34(num_classes: int = 10) -> ResNet:
    """Build ResNet-34: BasicBlock stages of depth 3, 4, 6, 3 (16 blocks)."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes)


# ============================================================
# 第 3 部分:无跳跃连接的普通 CNN(对照模型)
# ============================================================

class PlainBlock(nn.Module):
    """
    Two-convolution block without a skip connection.

    Same main path as a basic residual block
    (Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU), but nothing is added to
    the output before the final ReLU. It serves as the control for showing
    what happens when a deep network has no residual connections.
    """

    expansion = 1

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        """
        Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of output channels.
            stride: Stride of the first convolution.
        """
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)
        # Deliberately no ``shortcut`` module here.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the two conv/BN/ReLU stages in sequence.

        Args:
            x: Input feature map with ``in_planes`` channels.

        Returns:
            Feature map with ``planes`` channels.
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        # No ``+ x`` before the last activation: this is the plain variant.
        return F.relu(self.bn2(self.conv2(hidden)))


class PlainCNN(nn.Module):
    """
    Plain (no skip connection) counterpart of the CIFAR-style ResNet.

    The network has the same stem, the same four stages (widths
    64/128/256/512, first-block strides 1/2/2/2) and the same pooling +
    linear head as the residual model; every block is a ``PlainBlock``.
    Downsampling and channel changes are handled by the first block of each
    stage, so no extra layers are inserted and the stack of 3x3 convolutions
    has the same depth as the residual model built with the same
    ``num_blocks``.
    """

    def __init__(self, num_blocks: list, num_classes: int = 10):
        """
        Build the network.

        Args:
            num_blocks: Number of blocks in each of the four stages, e.g.
                ``[2, 2, 2, 2]``.
            num_classes: Size of the classifier output.
        """
        super(PlainCNN, self).__init__()
        # Running channel count; _make_layer updates it after every stage.
        self.in_planes = 64

        # Stem keeps the input resolution (stride 1, padding 1).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(64,  num_blocks[0], stride=1)
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        self._initialize_weights()

    def _make_layer(self, planes: int, num_blocks: int,
                    stride: int) -> nn.Sequential:
        """
        Stack the plain blocks of one stage.

        Args:
            planes: Output channels of every block in the stage.
            num_blocks: Number of blocks requested for the stage.
            stride: Stride of the first block; all later blocks use stride 1.

        Returns:
            The blocks wrapped in an ``nn.Sequential``.
        """
        # The first PlainBlock changes resolution and channel count through
        # its own strided 3x3 convolution. No separate 1x1 projection is
        # added, because that would make this control deeper than the
        # residual network it is compared against.
        layers = [PlainBlock(self.in_planes, planes, stride)]
        self.in_planes = planes * PlainBlock.expansion
        layers.extend(PlainBlock(self.in_planes, planes, stride=1)
                      for _ in range(1, num_blocks))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialise conv (Kaiming normal), BatchNorm (1/0) and linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a batch of images to class logits.

        Args:
            x: Image batch with 3 channels.

        Returns:
            Logits of shape ``(N, num_classes)``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)            # (N, 512, 1, 1)
        x = x.view(x.size(0), -1)      # (N, 512)
        x = self.fc(x)
        return x


# ============================================================
# 第 4 部分:训练与评估工具
# ============================================================

def count_parameters(model: nn.Module) -> int:
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total


class TrainingLogger:
    """Per-epoch history of training loss, train/test accuracy and gradient norms."""

    def __init__(self):
        # One entry is appended to each list per call to update().
        self.train_losses = []
        self.train_accs = []
        self.test_accs = []
        # Each entry is the ``grad_norms`` dict passed to update().
        self.grad_norms = []

    def update(self, train_loss: float, train_acc: float,
               test_acc: float, grad_norms: dict = None):
        """
        Append the metrics of one epoch.

        Args:
            train_loss: Training loss of the epoch.
            train_acc: Training accuracy of the epoch.
            test_acc: Test accuracy of the epoch.
            grad_norms: Optional gradient-norm mapping; it is stored only
                when it is not ``None`` (an empty dict is still stored).
        """
        scalar_histories = (
            (self.train_losses, train_loss),
            (self.train_accs, train_acc),
            (self.test_accs, test_acc),
        )
        for history, value in scalar_histories:
            history.append(value)

        if grad_norms is None:
            return
        self.grad_norms.append(grad_norms)


def compute_gradient_norms(model: nn.Module) -> dict:
    """
    Collect the gradient norm of every named parameter.

    Parameters whose ``.grad`` is ``None`` (no backward pass has populated
    them) are skipped. Useful for spotting vanishing or exploding gradients.

    Returns:
        Mapping from parameter name to the norm of its gradient as a float,
        in ``named_parameters()`` order.
    """
    return {
        name: param.grad.norm().item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }


def train_one_epoch(model: nn.Module, train_loader, optimizer,
                    criterion, device: torch.device) -> tuple:
    """
    Train the model for one pass over ``train_loader``.

    The reported loss is weighted by batch size, so a smaller final batch
    does not count as much as a full one. This assumes ``criterion`` returns
    the mean loss over the batch (the default reduction of PyTorch losses).
    ``train_loader`` only needs to be iterable; ``len()`` is not required.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(average loss per sample, accuracy in percent)``.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0   # sum of per-sample losses
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Gradients accumulate by default, so clear them before each step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # Undo the batch mean so every sample carries the same weight.
        loss_sum += loss.item() * batch_size
        total += batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

    avg_loss = loss_sum / total
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy


def evaluate(model: nn.Module, test_loader, device: torch.device) -> float:
    """
    Measure classification accuracy on ``test_loader``.

    The model is switched to evaluation mode and the forward passes run with
    gradient tracking disabled.

    Args:
        model: Network to evaluate.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy in percent.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # The predicted class is the index of the largest logit.
            predictions = model(batch_inputs).max(1)[1]

            n_seen += batch_targets.size(0)
            n_correct += predictions.eq(batch_targets).sum().item()

    return 100.0 * n_correct / n_seen


def get_cifar10_loaders(batch_size: int = 128):
    """
    Build the CIFAR-10 train/test data loaders.

    The datasets are downloaded to ``../data`` (relative to the current
    working directory) when missing. If loading fails for any reason, a
    synthetic dataset of random 32x32 RGB tensors with random labels is used
    instead so the demo can still run; that data is drawn from a dedicated,
    fixed-seed generator and is therefore identical on every run.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader, classes)`` where ``classes`` is the
        tuple of the ten CIFAR-10 class names.
    """
    # Per-channel mean / std used to normalise both splits.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    # Training split: light augmentation, then normalisation.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),     # pad by 4, crop back to 32x32
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Test split: normalisation only.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    try:
        train_set = torchvision.datasets.CIFAR10(
            root='../data', train=True, download=True,
            transform=transform_train
        )
        test_set = torchvision.datasets.CIFAR10(
            root='../data', train=False, download=True,
            transform=transform_test
        )
    except Exception as e:
        # Deliberately broad: any download/extraction/IO failure should fall
        # back to synthetic data rather than abort the demo.
        print(f"[警告] CIFAR-10 下载失败 ({e}),使用合成数据")
        print("[回退] 创建合成 32x32 图像数据集用于演示 CNN 结构")
        from torch.utils.data import TensorDataset
        # The data is drawn with torch, so the seed must go to a torch
        # generator (seeding NumPy has no effect on torch.randn). A private
        # generator keeps the global RNG, and thus model init, untouched.
        rng = torch.Generator().manual_seed(42)
        n_train, n_test = 5000, 1000
        synth_X_train = torch.randn(n_train, 3, 32, 32, generator=rng)
        synth_y_train = torch.randint(0, 10, (n_train,), generator=rng)
        synth_X_test = torch.randn(n_test, 3, 32, 32, generator=rng)
        synth_y_test = torch.randint(0, 10, (n_test,), generator=rng)
        train_set = TensorDataset(synth_X_train, synth_y_train)
        test_set = TensorDataset(synth_X_test, synth_y_test)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=0
    )

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    return train_loader, test_loader, classes


# ============================================================
# 第 5 部分:可视化工具
# ============================================================

def plot_training_comparison(logger_resnet: TrainingLogger,
                              logger_plain: TrainingLogger,
                              save_dir: str):
    """
    Plot ResNet-18 against the plain CNN and save the figure.

    The figure has three side-by-side panels: training loss, test accuracy
    and the mean gradient norm per epoch. The third panel is drawn only when
    both loggers recorded gradient norms. The image is written to
    ``<save_dir>/resnet_vs_plain.png``.

    Args:
        logger_resnet: History of the residual model.
        logger_plain: History of the plain model.
        save_dir: Directory the PNG is written to.
    """
    fig, (loss_ax, acc_ax, grad_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # The x axis is 1-based and sized from the ResNet history.
    epochs = range(1, len(logger_resnet.train_losses) + 1)

    def draw_pair(ax, resnet_series, plain_series, ylabel, title):
        """Draw both curves on ``ax`` with the shared styling."""
        ax.plot(epochs, resnet_series, 'b-o', label='ResNet-18',
                linewidth=2, markersize=4)
        ax.plot(epochs, plain_series, 'r-s', label='Plain CNN',
                linewidth=2, markersize=4)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    def mean_norm_per_epoch(history):
        """Average the per-parameter norms of every recorded epoch."""
        return [np.mean(list(norms.values())) for norms in history]

    # Panel 1: training loss.
    draw_pair(loss_ax, logger_resnet.train_losses, logger_plain.train_losses,
              'Training Loss', 'Training Loss Comparison')

    # Panel 2: test accuracy.
    draw_pair(acc_ax, logger_resnet.test_accs, logger_plain.test_accs,
              'Test Accuracy (%)', 'Test Accuracy Comparison')

    # Panel 3: mean gradient norm over all parameters.
    if logger_resnet.grad_norms and logger_plain.grad_norms:
        draw_pair(grad_ax,
                  mean_norm_per_epoch(logger_resnet.grad_norms),
                  mean_norm_per_epoch(logger_plain.grad_norms),
                  'Average Gradient Norm',
                  'Gradient Norm Comparison (Residual Connections Maintain Gradient Flow)')

    plt.tight_layout()
    save_path = os.path.join(save_dir, 'resnet_vs_plain.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  [可视化] 训练对比图已保存到 {save_path}")


def plot_gradient_distribution(grad_norms: dict, epoch: int,
                                save_dir: str, prefix: str = ""):
    """
    Save a bar chart of the per-parameter gradient norms of one epoch.

    Nothing is drawn or written when ``grad_norms`` is empty. Otherwise the
    chart is saved as ``<save_dir>/grads_<prefix>_epoch<NN>.png``.

    Args:
        grad_norms: Mapping from parameter name to gradient norm.
        epoch: Epoch number used in the title and the file name.
        save_dir: Directory the PNG is written to.
        prefix: Tag distinguishing the models in the title and file name.
    """
    if not grad_norms:
        return

    # Tick labels keep only the last dotted component of each name.
    tick_labels = [full_name.rsplit('.', 1)[-1] for full_name in grad_norms]
    heights = list(grad_norms.values())
    positions = range(len(heights))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.bar(positions, heights, color='steelblue', alpha=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=7)
    ax.set_ylabel('Gradient L2 Norm')
    ax.set_title(f'{prefix} Epoch {epoch} - Per-Layer Gradient Norm Distribution')
    # Thin red baseline at zero.
    ax.axhline(y=0, color='red', linewidth=0.5, linestyle='-')
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    out_path = os.path.join(save_dir, f'grads_{prefix}_epoch{epoch:02d}.png')
    plt.savefig(out_path, dpi=100, bbox_inches='tight')
    plt.close()


# ============================================================
# 第 6 部分:主训练流程
# ============================================================

def _train_and_log(model, train_loader, test_loader, optimizer, scheduler,
                   criterion, device, n_epochs):
    """Train ``model`` for ``n_epochs`` epochs, printing and returning its history."""
    logger = TrainingLogger()
    for epoch in range(1, n_epochs + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        test_acc = evaluate(model, test_loader, device)
        scheduler.step()

        # The gradients still stored on the parameters are those of the last
        # training batch of this epoch.
        grad_norms = compute_gradient_norms(model)
        logger.update(train_loss, train_acc, test_acc, grad_norms)

        print(f"  Epoch {epoch:2d}/{n_epochs} | "
              f"Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Test Acc: {test_acc:.2f}%")
    return logger


def main():
    """
    Run the ResNet-18 vs. plain CNN comparison end to end.

    Steps: load the data, build both models, train each with the same
    SGD + cosine-annealing setup, save the comparison plots to the chapter's
    images directory and print a summary. On CPU the run is shortened to one
    epoch on the first 1000 training samples.
    """
    print("=" * 60)
    print("s11 经典架构演进 Demo")
    print("ResNet-18 vs Plain CNN 对比实验")
    print("=" * 60)

    # ---------- Device ----------
    device = DEVICE
    print(f"\n计算设备: {device}")
    if device.type == 'cuda':
        print(f"  GPU: {torch.cuda.get_device_name(0)}")

    # ---------- Data ----------
    print("\n[1/5] 加载 CIFAR-10 数据集...")
    batch_size = 128
    train_loader, test_loader, classes = get_cifar10_loaders(batch_size)
    # Report the real sizes: the loaders may hold the synthetic fallback.
    print(f"  训练集: {len(train_loader.dataset)} 张, "
          f"测试集: {len(test_loader.dataset)} 张")
    print(f"  类别: {classes}")

    # ---------- Models and run configuration ----------
    print("\n[2/5] 构建模型...")
    if device.type == 'cpu':
        # Heavily reduced configuration so the demo finishes quickly on CPU.
        n_epochs = 1
        n_train_subset = 1000
        print("[配置] CPU 模式:使用轻量参数快速演示(1 epoch, 1000 训练样本)。GPU 模式下将使用完整训练配置。")
        from torch.utils.data import Subset
        # Subset works for any dataset type (CIFAR-10 or the synthetic
        # fallback) and leaves the underlying dataset object unmodified.
        n_keep = min(n_train_subset, len(train_loader.dataset))
        train_loader = torch.utils.data.DataLoader(
            Subset(train_loader.dataset, range(n_keep)),
            batch_size=batch_size, shuffle=True, num_workers=0
        )
    else:
        n_epochs = 10
        print(f"[配置] 训练 {n_epochs} 个 epoch(GPU 模式可充分训练)")

    resnet = ResNet18(num_classes=10).to(device)
    plain_cnn = PlainCNN([2, 2, 2, 2], num_classes=10).to(device)

    print(f"  ResNet-18 参数量: {count_parameters(resnet):,}")
    print(f"  Plain CNN 参数量: {count_parameters(plain_cnn):,}")

    # ---------- Optimizers, loss and schedulers ----------
    criterion = nn.CrossEntropyLoss()
    optimizer_resnet = optim.SGD(resnet.parameters(), lr=0.1,
                                 momentum=0.9, weight_decay=5e-4)
    optimizer_plain = optim.SGD(plain_cnn.parameters(), lr=0.1,
                                momentum=0.9, weight_decay=5e-4)

    # Cosine annealing of the learning rate over the whole run.
    scheduler_resnet = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_resnet, T_max=n_epochs
    )
    scheduler_plain = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_plain, T_max=n_epochs
    )

    # ---------- Training ----------
    print("\n[3/5] 训练 ResNet-18...")
    logger_resnet = _train_and_log(
        resnet, train_loader, test_loader, optimizer_resnet,
        scheduler_resnet, criterion, device, n_epochs
    )

    print("\n[4/5] 训练 Plain CNN(无跳跃连接)...")
    logger_plain = _train_and_log(
        plain_cnn, train_loader, test_loader, optimizer_plain,
        scheduler_plain, criterion, device, n_epochs
    )

    # ---------- Visualisation ----------
    print("\n[5/5] 生成可视化对比...")
    output_dir = _IMAGES_DIR

    plot_training_comparison(logger_resnet, logger_plain, output_dir)

    # Per-parameter gradient distribution of the final epoch.
    for prefix, logger in (("resnet", logger_resnet), ("plain", logger_plain)):
        if logger.grad_norms:
            plot_gradient_distribution(
                logger.grad_norms[-1],
                epoch=n_epochs, save_dir=output_dir,
                prefix=prefix
            )

    # ---------- Summary ----------
    print("\n" + "=" * 60)
    print("训练总结:")
    print(f"  ResNet-18    最终测试准确率: {logger_resnet.test_accs[-1]:.2f}%")
    print(f"  Plain CNN    最终测试准确率: {logger_plain.test_accs[-1]:.2f}%")
    print(f"  ResNet 优势:   {logger_resnet.test_accs[-1] - logger_plain.test_accs[-1]:.2f}%")

    if logger_resnet.grad_norms and logger_plain.grad_norms:
        resnet_avg_grad = np.mean(list(logger_resnet.grad_norms[-1].values()))
        plain_avg_grad = np.mean(list(logger_plain.grad_norms[-1].values()))
        print(f"  ResNet 平均梯度范数: {resnet_avg_grad:.6f}")
        print(f"  Plain 平均梯度范数: {plain_avg_grad:.6f}")
        # Only state the conclusion when the measured numbers support it.
        if resnet_avg_grad > plain_avg_grad:
            print("  (ResNet 的梯度范数更大,说明残差连接有效)")
        else:
            print("  (本次运行中 ResNet 的梯度范数并未更大;训练轮数较少时该对比噪声较大)")

    print("=" * 60)
    print(f"Demo 完成!查看 {_IMAGES_DIR} 目录下的可视化结果。")
    print("=" * 60)


# Run the comparison experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()