Skip to content

s11b Vision Transformer — vit_demo.py 代码详解

Download vit_demo.py

运行方式

bash
cd s11b_vit/code
python vit_demo.py

代码逐段详解

第1步:导入库 —— 每个库是做什么的

python
import torch
import torch.nn as nn           # Transformer 需要 Linear, LayerNorm, Dropout
import torch.nn.functional as F  # softmax, cross_entropy 等
import torch.optim as optim      # AdamW 优化器
import torchvision               # 数据集(Imagenette/CIFAR-100)、预训练 ViT
import torchvision.transforms as transforms
在此 demo 中的角色
torch.nn构建 PatchEmbedding、MultiHeadAttention、TransformerBlock
torch.optimAdamW(解耦权重衰减的 Adam)+ 余弦退火调度
torchvisionImagenette 下载 + 预训练 ViT-B/16 加载

设备自适应:CPU 模式下使用轻量 ViT(embed_dim=192, depth=4),3 个 epoch 快速演示。GPU 模式下使用较大模型(embed_dim=384, depth=8),10 个 epoch 得到有意义的结果。

第2步:数据加载 —— Imagenette 数据集

python
def get_imagenette_loaders(batch_size: int, img_size: int = 224):

为什么用 Imagenette 而不是 CIFAR-10? ViT 的 patch embedding 需要输入分辨率能被 patch_size 整除。原始 ViT 设计为 224×224 输入,切分成 14×14=196 个 patch(每个 16×16)。CIFAR-10 的 32×32 太小,需要大改架构。Imagenette 是 ImageNet 的 10 类子集(tench, springer, cassette player 等),图像原始尺寸足够大,resize 到 224×224 后可以直接适配标准 ViT。

三级回退策略

  1. Imagenette($\sim$1.5GB,自动下载)—— 首选
  2. CIFAR-100 —— 回退方案1
  3. 合成随机数据 —— 回退方案2(仅演示流程)

数据增强

python
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),            # 统一尺寸
    transforms.RandomHorizontalFlip(),         # 水平翻转
    transforms.ColorJitter(0.2, 0.2),         # 颜色抖动:亮度±20%,对比度±20%
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet 统计值
                         std=[0.229, 0.224, 0.225]),
])

颜色抖动(ColorJitter)的关键性:ViT 没有 CNN 的色彩归纳偏置,color jitter 作为数据增强强迫 ViT 学习颜色不变的表示,在小数据集上尤为重要。

第3步:Patch Embedding —— 图像变成词序列

这是 ViT 中唯一与 NLP 不同的组件,是连接两个世界的"桥梁"。

python
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        # Conv2d 的 kernel_size 和 stride 都设为 patch_size
        # 这等价于"均匀切分 + 线性投影"两个操作
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):  # x: (B, 3, 224, 224)
        x = self.proj(x)           # (B, D, 14, 14) — 每个 16×16 patch → D维向量
        x = x.flatten(2)           # (B, D, 196)  — 展平空间维度
        x = x.transpose(1, 2)      # (B, 196, D) — 变成序列格式
        return x

为什么用 Conv2d 而不是手动切分? 数学上等价,但 Conv2d 更高效:

PatchEmbed(I)=Conv2d(Cin=3,Cout=D,kernel=P,stride=P)(I)

卷积核以步长 P 在图像上滑动,每次覆盖 P×P 的区域并输出一个 D 维向量。由于 stride = kernel_size,各窗口互不重叠,恰好实现了"切 patch + 线性投影"。

维度变化跟踪

输入:  (B, 3, 224, 224)     # 一张 224×224 RGB 图
 ↓ Conv2d(3→D, kernel=16, stride=16)
      (B, D, 14, 14)         # 14×14 个 patch,每个 D 维
 ↓ flatten(2)
      (B, D, 196)            # 196 个 token
 ↓ transpose(1,2)
      (B, 196, D)            # token 序列,shape 与 NLP 的 (B, seq_len, d_model) 一致

第4步:多头自注意力(MSA)—— ViT 的大脑

python
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads=12, qkv_bias=False):
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 缩放因子 1/√(d_k)

        # Q、K、V 合并在一个 Linear 中计算,比分别做三次 Linear 更高效
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)  # 多头拼接后的输出投影

为什么 Q、K、V 合并在一个 Linear 中? 矩阵乘法 XWqkv 的计算量是 O(ND3D)。分开做三次 XWQXWKXWV 需要三次 O(ND2)。合并后只需一次大矩阵乘法和 reshape,内存访问更连续,GPU 利用更充分。

前向传播 —— 注意力计算的代码实现

python
def forward(self, x):  # x: (B, N, D)
    B, N, C = x.shape

    # 1. 线性变换得到 Q、K、V
    qkv = self.qkv(x)                              # (B, N, 3D)
    qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
    qkv = qkv.permute(2, 0, 3, 1, 4)              # (3, B, num_heads, N, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]              # 各 (B, num_heads, N, head_dim)

    # 2. 缩放点积注意力
    attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
    attn = attn.softmax(dim=-1)                    # 沿 key 维度 softmax
    attn = self.attn_drop(attn)

    # 3. 加权求和 + 多头拼接
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)                               # (B, N, D)
    return x

数学对应

Attention(Q,K,V)=softmax(QKdk)V
  • q @ k.transpose(-2, -1) 计算的是 QK,产生 (N,N) 的注意力矩阵
  • * self.scale 除以 dk —— 为什么要缩放?dk 很大时(如 64),QK 的元素值可能很大,softmax 会进入饱和区(梯度几乎为零)。缩放使输入 softmax 的值保持在合理范围。
  • softmax(dim=-1) 对每一行(每个 query 对所有 key)归一化,jαij=1
  • attn @ v 用注意力权重对 value 加权求和

ViT 中自注意力的意义:与 CNN 不同,ViT 的每个 patch 在第一层就能"看到"所有其他 patch(全局感受野)。注意力矩阵 αij 编码了"第 i 个 patch 应该关注第 j 个 patch 的程度"。

第5步:Transformer 编码器块

python
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.):
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)   # Pre-Norm: MSA 前归一化
        self.attn = MultiHeadAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)   # Pre-Norm: MLP 前归一化
        # MLP: D → 4D → D,GELU 激活
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # Pre-Norm 残差: MSA
        x = x + self.mlp(self.norm2(x))    # Pre-Norm 残差: MLP
        return x

Pre-Norm vs Post-Norm:原始 Transformer(Vaswani et al., 2017)使用 Post-Norm(残差后再 LayerNorm)。ViT 使用 Pre-Norm(LayerNorm 放在子层之前),这在训练更深的 Transformer 时更稳定,梯度流动更好。

为什么用 GELU 而不是 ReLU? GELU (Gaussian Error Linear Unit) 是 BERT 和 ViT 的标准激活函数:

GELU(x)=xΦ(x)0.5x(1+tanh(2π(x+0.044715x3)))

GELU 在 x=0 附近是光滑的,不像 ReLU 那样有尖锐的不可导点。在 Transformer 中,GELU 通常比 ReLU 带来更好的收敛性和最终精度。

为什么 MLP 有 4 倍扩展率? mlp_ratio=4 意味着 MLP 的隐藏层是 dim * 4。这提供了足够的容量让 Transformer 在注意力聚合全局信息后,对每个 token 的表示做非线性变换。这个比例在原始 Transformer 论文中确定,在 ViT 和 BERT 中保持了一致。

第6步:完整 SimpleViT —— 组装所有组件

python
class SimpleViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.):
        # 1. Patch Embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)

        # 2. CLS Token:可学习的分类 token,放在序列开头
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 3. Position Embedding:可学习的 1D 位置编码
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # 4. Transformer 编码器(L 层)
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        # 5. 分类头
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, num_classes)

CLS Token 的工作原理[CLS] 是一个可学习的参数向量,形状为 (D,)。在每一层自注意力中,它与所有 patch token 交互,逐步聚合整个图像的信息。最终分类时只取 [CLS] 的输出。这个设计与 BERT 的 [CLS] 完全相同。

为什么 1D 位置编码而不是 2D? 直觉上图像有行列结构,应该用 2D 编码。但 ViT 的实验表明 1D 和 2D 效果几乎一样——Transformer 足够强大,可以从 1D 序列中学会推 2D 空间关系。训练完成后,相近位置的编码向量确实更相似,说明模型自动"领悟"了 2D 结构。

完整前向传播

python
def forward(self, x):  # x: (B, 3, 224, 224)
    B = x.shape[0]

    # 1. Patch Embedding: (B, C, H, W) → (B, N, D)
    x = self.patch_embed(x)

    # 2. 追加 CLS Token: (B, 1, D) + (B, N, D) → (B, N+1, D)
    cls_tokens = self.cls_token.expand(B, -1, -1)  # 复制到 batch 维度
    x = torch.cat((cls_tokens, x), dim=1)          # 在序列维度拼接

    # 3. 加上位置编码: (B, N+1, D) + (1, N+1, D) → (B, N+1, D)
    x = x + self.pos_embed  # 广播加法

    # 4. Dropout → Transformer 编码器
    x = self.pos_drop(x)
    x = self.blocks(x)

    # 5. LayerNorm → 取 CLS Token → 分类
    x = self.norm(x)
    x = self.head(x[:, 0])  # x[:, 0] 是 CLS token 的最终输出

    return x

流程图示

Image (224×224)  →  PatchEmbed  →  196 Tokens
                                      + [CLS] Token  =  197 Tokens
                                      + Pos Embed

                                  × L 层 TransformerEncoder

                                  取 [CLS] Token  →  MLP Head  →  分类

第7步:预训练模型加载与微调

python
def load_pretrained_vit(num_classes=10):
    # 方案1: torchvision (≥0.13)
    from torchvision.models import vit_b_16, ViT_B_16_Weights
    model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)  # ImageNet-1K 预训练
    in_features = model.heads.head.in_features                  # 768
    model.heads.head = nn.Linear(in_features, num_classes)      # 替换分类头
    return model

为什么预训练 ViT 在 Imagenette 上需要微调? 预训练模型学到的特征来自 ImageNet-1K 的 1000 类。Imagenette 虽然只取了其中的 10 类,但分类头不同——需要替换最后一层并微调。微调时学习率要小(backbone lr×0.1,head lr),避免破坏预训练权重。

微调的参数分组策略

python
head_params = []  # 分类头:正常学习率
body_params = []  # backbone:1/10 学习率
for name, param in model.named_parameters():
    if 'head' in name or 'fc' in name:
        head_params.append(param)
    else:
        body_params.append(param)

optimizer = optim.AdamW([
    {'params': body_params, 'lr': lr * 0.1},  # backbone 慢速更新
    {'params': head_params, 'lr': lr},          # 分类头正常更新
], weight_decay=0.05)

为什么 AdamW 而不是 SGD? Transformer 对优化器敏感。SGD 在 Transformer 上通常不如自适应优化器。AdamW(Adam with decoupled Weight Decay)是当前训练 Transformer 的标准选择:

  • Adam 部分:自适应学习率,处理不同参数的梯度尺度差异
  • Weight Decay 解耦:正则化项直接作用于参数(θ=θηλθ),而不是通过梯度,效果更好

第8步:可视化 —— 准确率对比与训练曲线

代码生成两张图:

  1. 准确率柱状图:SimpleViT(从零)、ResNet-18(从零)、ViT-B/16(预训练+微调)的对比
  2. 训练曲线:Training Loss 和 Test Accuracy 随 epoch 的变化

预期结果解读

模型预期准确率原因
SimpleViT(从零)60-75%缺少归纳偏置 + 数据量不足
ResNet-18(从零)85-92%CNN 的局部性先验在小数据集上优势明显
ViT-B/16(预训练+微调)95%+预训练弥补了数据需求,微调适配目标域

这个对比验证了 ViT 论文的核心发现:归纳偏置 = 数据效率 vs 灵活性上限的 trade-off。 CNN 的局部性和平移等变性在小数据上是优势;ViT 抛弃这些偏置后,需要大规模预训练才能发挥潜力。

关键概念速查表

概念公式/说明代码对应
Patch EmbeddingConv2d(3,D,kernel=P,stride=P)PatchEmbedding
自注意力softmax(QKTdk)VMultiHeadAttention.forward()
CLS Token可学习参数 xclassRDself.cls_token = nn.Parameter(...)
Pre-Norm 残差x+MSA(LN(x))TransformerBlock.forward()
GELUxΦ(x)nn.GELU()
位置编码可学习 EposR(N+1)×Dself.pos_embed
AdamWAdam + 解耦权重衰减optim.AdamW(...)
微调参数分组backbone 用 lr×0.1,head 用 lrquick_finetune()

完整代码

py
# -*- coding: utf-8 -*-
"""
s11b Vision Transformer (ViT) Demo:从零实现 ViT,对比预训练 ViT 与 CNN
=====================================================================
使用 PyTorch 从零构建一个 Vision Transformer,在 Imagenette 上训练,
与预训练的 ViT-B/16 和 ResNet-18 进行准确率对比。

核心目标:
  1. 理解 ViT 的 Patch Embedding → Transformer Encoder → Classification Head 流程
  2. 对比 ViT(从零训练)、预训练 ViT(微调)、CNN(ResNet-18)的性能
  3. 直观感受 ViT 需要更多数据的特性

运行方式:
  cd s11b_vit/code
  python vit_demo.py

依赖:torch, torchvision, matplotlib, numpy
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
matplotlib.rcParams['axes.unicode_minus'] = False

import numpy as np
import os
import sys
import time

# ============================================================
# 全局配置
# ============================================================

# 设备检测:GPU > MPS (Mac) > CPU
DEVICE = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"使用设备: {DEVICE}")
if DEVICE.type == 'cuda':
    print(f"  GPU 型号: {torch.cuda.get_device_name(0)}")
elif DEVICE.type == 'cpu':
    print("  (未检测到 GPU,使用 CPU 轻量模式)")

# 图片保存路径
_HERE = os.path.dirname(os.path.abspath(__file__))
_IMAGES_DIR = os.path.join(_HERE, '..', 'images')
os.makedirs(_IMAGES_DIR, exist_ok=True)

# ---- 根据设备选择模型规模 ----
if DEVICE.type == 'cpu':
    # CPU 轻量配置:小模型 + 少 epoch + 跳过大模型
    CONFIG = {
        'img_size': 224,
        'patch_size': 16,
        'embed_dim': 192,
        'depth': 4,
        'num_heads': 3,
        'mlp_ratio': 2.0,
        'epochs': 3,
        'batch_size': 32,
        'lr': 1e-3,
        'use_pretrained': False,
    }
else:
    # GPU 全量配置
    CONFIG = {
        'img_size': 224,
        'patch_size': 16,
        'embed_dim': 384,
        'depth': 8,
        'num_heads': 6,
        'mlp_ratio': 4.0,
        'epochs': 10,
        'batch_size': 64,
        'lr': 3e-4,
        'use_pretrained': True,
    }

NUM_CLASSES = 10  # Imagenette 类别数(ImageNet 的 10 类子集)

# Imagenette 类别名(ImageNet 中的 10 个易分类类别)
IMAGENETTE_CLASSES = (
    'tench', 'English springer', 'cassette player', 'chain saw',
    'church', 'French horn', 'garbage truck', 'gas pump',
    'golf ball', 'parachute'
)

IMAGENETTE_URL = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'


# ============================================================
# 第 1 部分:数据加载
# ============================================================

def download_imagenette(data_dir: str) -> str:
    """下载并解压 Imagenette(ImageNet 的 10 类子集,~1.5GB)"""
    import tarfile, urllib.request
    tgz_path = os.path.join(data_dir, 'imagenette2.tgz')
    extracted_dir = os.path.join(data_dir, 'imagenette2')

    # 如果已解压,直接返回
    if os.path.exists(extracted_dir):
        print(f"  Imagenette 已存在于 {extracted_dir}")
        return extracted_dir

    # 下载
    print(f"  正在下载 Imagenette (ImageNet 10 类子集,~1.5GB)...")
    print(f"  URL: {IMAGENETTE_URL}")
    try:
        urllib.request.urlretrieve(IMAGENETTE_URL, tgz_path)
    except Exception as e:
        raise RuntimeError(f"Imagenette 下载失败: {e}")

    # 解压
    print(f"  正在解压...")
    with tarfile.open(tgz_path) as tar:
        tar.extractall(data_dir)

    # 清理压缩包
    os.remove(tgz_path)
    return extracted_dir


def get_imagenette_loaders(batch_size: int, img_size: int = 224):
    """
    加载 Imagenette 数据集(ImageNet 的 10 类子集)。
    自动下载 ~1.5GB 数据。如果下载失败,回退到 CIFAR-100。
    """
    # 数据增强(训练集)
    transform_train = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    data_dir = os.path.join(_HERE, '..', '..', 'data')

    # ---- 方案 1:Imagenette(ImageNet 子集)----
    try:
        extracted = download_imagenette(data_dir)
        train_path = os.path.join(extracted, 'train')
        val_path = os.path.join(extracted, 'val')

        train_set = torchvision.datasets.ImageFolder(
            train_path, transform=transform_train
        )
        test_set = torchvision.datasets.ImageFolder(
            val_path, transform=transform_test
        )
        class_names = train_set.classes
        print(f"  Imagenette 加载成功!类别: {class_names}")
        print(f"  训练集: {len(train_set)} 张, 验证集: {len(test_set)} 张")
        tl, vl = _make_loaders(train_set, test_set, batch_size)
        return tl, vl, class_names

    except Exception as e:
        print(f"  Imagenette 加载失败 ({type(e).__name__}: {e})")
        print(f"  你也可以手动下载: {IMAGENETTE_URL}")
        print(f"  解压到: {data_dir}/imagenette2/")

    # ---- 方案 2:CIFAR-100 ----
    print("  回退方案 1: 尝试 CIFAR-100...")
    try:
        train_set = torchvision.datasets.CIFAR100(
            root=data_dir, train=True, download=True,
            transform=transform_train
        )
        test_set = torchvision.datasets.CIFAR100(
            root=data_dir, train=False, download=True,
            transform=transform_test
        )
        global NUM_CLASSES
        NUM_CLASSES = 100
        print(f"  CIFAR-100 加载成功!({NUM_CLASSES} 类)")
        tl, vl = _make_loaders(train_set, test_set, batch_size)
        return tl, vl, None
    except Exception as e2:
        print(f"  CIFAR-100 也失败了 ({type(e2).__name__})")

    # ---- 方案 3:合成数据 ----
    print("  回退方案 2: 使用合成数据(仅供演示流程)")
    from torch.utils.data import TensorDataset
    np.random.seed(42)
    n_train, n_test = 2000, 400
    synth_X_train = torch.randn(n_train, 3, img_size, img_size)
    synth_y_train = torch.randint(0, NUM_CLASSES, (n_train,))
    synth_X_test = torch.randn(n_test, 3, img_size, img_size)
    synth_y_test = torch.randint(0, NUM_CLASSES, (n_test,))
    train_set = TensorDataset(synth_X_train, synth_y_train)
    test_set = TensorDataset(synth_X_test, synth_y_test)
    tl, vl = _make_loaders(train_set, test_set, batch_size)
    return tl, vl, None


def _make_loaders(train_set, test_set, batch_size):
    """创建 DataLoader"""
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=0
    )
    return train_loader, test_loader


# ============================================================
# 第 2 部分:从零实现 Vision Transformer (ViT)
# ============================================================

class PatchEmbedding(nn.Module):
    """
    Patch Embedding:将图像切分成固定大小的 Patch,并映射为 Token 序列

    输入图像 (B, C, H, W) → 使用 stride=patch_size 的 Conv2d 切分 →
    输出 (B, N, D),其中 N = (H/P) * (W/P) 为 Patch 数量,D 为嵌入维度
    """

    def __init__(self, img_size: int = 224, patch_size: int = 16,
                 in_chans: int = 3, embed_dim: int = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Conv2d 比 unfold + Linear 更高效,本质相同
        # (B, C, 224, 224) → (B, D, 14, 14)
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        参数:
            x: (B, C, H, W) 输入图像
        返回:
            (B, N, D) Patch Token 序列
        """
        x = self.proj(x)              # (B, D, H/P, W/P)
        x = x.flatten(2)              # (B, D, N)
        x = x.transpose(1, 2)         # (B, N, D)
        return x


class MultiHeadAttention(nn.Module):
    """
    多头自注意力 (Multi-Head Self-Attention, MSA)

    让序列中每个 Patch 可以"看到"所有其他 Patch,计算全局依赖关系。
    公式: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    """

    def __init__(self, dim: int, num_heads: int = 12,
                 qkv_bias: bool = False, attn_drop: float = 0.,
                 proj_drop: float = 0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 缩放因子 1/sqrt(d_k)

        # 将 Q、K、V 合并在一个 Linear 中计算,提升效率
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        参数:
            x: (B, N, D) 输入序列
        返回:
            (B, N, D) 注意力输出
        """
        B, N, C = x.shape

        # 计算 Q、K、V: (B, N, D) → (B, N, 3*D) → 拆分为 3 个 (B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 缩放点积注意力
        # Q @ K^T: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)        # 沿最后一个维度做 softmax
        attn = self.attn_drop(attn)

        # 加权求和: (B, num_heads, N, head_dim)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TransformerBlock(nn.Module):
    """
    ViT Transformer 编码器块

    结构: x → LayerNorm → MSA → +残差 → LayerNorm → MLP → +残差

    这与原始 Transformer 编码器完全一致。
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.,
                 qkv_bias: bool = False, drop: float = 0.,
                 attn_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

        # MLP: D → 4D → D (GELU 激活)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """前向传播:Pre-Norm 残差结构"""
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class SimpleViT(nn.Module):
    """
    从零实现的 Vision Transformer (ViT)

    架构:
      图像 → Patch Embedding → + CLS Token → + Position Embedding →
      ×L Transformer Blocks → LayerNorm → CLS Token → MLP Head → 分类

    这是 ViT-Base 的简化版,可通过参数控制规模。
    """

    def __init__(self, img_size: int = 224, patch_size: int = 16,
                 in_chans: int = 3, num_classes: int = 1000,
                 embed_dim: int = 768, depth: int = 12,
                 num_heads: int = 12, mlp_ratio: float = 4.,
                 qkv_bias: bool = False, drop_rate: float = 0.,
                 attn_drop_rate: float = 0.):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # ---- Patch Embedding ----
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_chans, embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # ---- CLS Token:可学习的分类 token,放在序列开头 ----
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # ---- Position Embedding:可学习的 1D 位置编码 ----
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(drop_rate)

        # ---- Transformer 编码器 ----
        self.blocks = nn.Sequential(*[
            TransformerBlock(
                dim=embed_dim, num_heads=num_heads,
                mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate,
            )
            for _ in range(depth)
        ])

        # ---- 分类头:取 CLS token 输出做分类 ----
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, num_classes)

        # ---- 权重初始化 ----
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """对 Linear 和 LayerNorm 进行初始化"""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        前向传播

        参数:
            x: (B, C, H, W) 输入图像
        返回:
            logits: (B, num_classes) 分类 logits
        """
        B = x.shape[0]

        # 1. Patch Embedding: (B, C, H, W) → (B, N, D)
        x = self.patch_embed(x)

        # 2. 追加 CLS Token: (B, 1, D) + (B, N, D) → (B, N+1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # 3. 加上位置编码
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # 4. Transformer 编码器
        x = self.blocks(x)

        # 5. LayerNorm → 取 CLS Token → 分类
        x = self.norm(x)
        x = self.head(x[:, 0])  # 只取位置 0 的 CLS token

        return x


# ============================================================
# 第 3 部分:预训练模型加载
# ============================================================

def load_pretrained_vit(num_classes: int = 10):
    """
    加载预训练的 ViT-B/16,替换分类头为 num_classes

    优先使用 torchvision 内置 ViT,其次尝试 timm 库,失败则返回 None。
    """
    # ---- 尝试 torchvision (>= 0.13) ----
    try:
        from torchvision.models import vit_b_16, ViT_B_16_Weights
        model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        in_features = model.heads.head.in_features
        model.heads.head = nn.Linear(in_features, num_classes)
        print("  [预训练] ViT-B/16 (torchvision, ImageNet-1K 权重)")
        return model
    except (ImportError, AttributeError):
        pass

    # ---- 回退:尝试 timm 库 ----
    try:
        import timm
        model = timm.create_model('vit_base_patch16_224', pretrained=True,
                                  num_classes=num_classes)
        print("  [预训练] ViT-B/16 (timm, ImageNet-21K→1K 权重)")
        return model
    except ImportError:
        pass

    print("  [跳过] 预训练 ViT 不可用(需要 torchvision>=0.13 或 timm)")
    return None


def load_pretrained_resnet(num_classes: int = 10):
    """
    加载预训练的 ResNet-18,替换分类头

    预训练权重来自 ImageNet-1K,替换最后的全连接层适配 Imagenette。
    """
    try:
        from torchvision.models import resnet18, ResNet18_Weights
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)
        print("  [预训练] ResNet-18 (torchvision, ImageNet-1K 权重)")
        return model
    except (ImportError, AttributeError):
        pass

    # 回退:不用预训练权重
    from torchvision.models import resnet18
    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    print("  [从零] ResNet-18 (无预训练权重)")
    return model


# ============================================================
# 第 4 部分:训练与评估工具
# ============================================================

def count_parameters(model: nn.Module) -> int:
    """统计模型可训练参数量"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def train_one_epoch(model, train_loader, optimizer, criterion, device):
    """
    训练一个 epoch

    返回:
        avg_loss: 平均损失
        accuracy: 训练准确率 (%)
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    avg_loss = total_loss / len(train_loader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy


@torch.no_grad()
def evaluate(model, test_loader, criterion, device):
    """
    在测试集上评估模型

    返回:
        avg_loss: 平均损失
        accuracy: 测试准确率 (%)
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    avg_loss = total_loss / len(test_loader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy


def train_full(model, train_loader, test_loader, epochs, lr, device,
               model_name="Model"):
    """
    完整训练流程:训练 epochs 轮,记录 loss 和准确率曲线

    返回:
        history: { 'train_loss': [...], 'train_acc': [...], 'test_acc': [...] }
        final_test_acc: 最终测试准确率
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    # 余弦退火学习率调度
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    best_acc = 0.0

    print(f"\n  [{model_name}] 开始训练 {epochs} epochs, lr={lr}")
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        test_loss, test_acc = evaluate(
            model, test_loader, criterion, device
        )
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)

        if test_acc > best_acc:
            best_acc = test_acc

        elapsed = time.time() - t0
        print(f"    Epoch {epoch:2d}/{epochs} | "
              f"Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Test Acc: {test_acc:.2f}% | "
              f"Time: {elapsed:.0f}s")

    print(f"  [{model_name}] 完成!最佳测试准确率: {best_acc:.2f}%")
    return history, best_acc


def quick_finetune(model, train_loader, test_loader, epochs, lr, device,
                   model_name="Pretrained"):
    """
    快速微调预训练模型:主要训练分类头,fine-tune 最后几层

    适用于已在大数据集上预训练的 ViT-B/16。使用较小的学习率避免破坏
    预训练学到的良好特征表示。
    """
    criterion = nn.CrossEntropyLoss()

    # 对 backbone 使用更小的学习率
    # 分类头参数单独分组,使用较大学习率
    head_params = []
    body_params = []
    for name, param in model.named_parameters():
        if 'head' in name or 'fc' in name:
            head_params.append(param)
        else:
            body_params.append(param)

    optimizer = optim.AdamW([
        {'params': body_params, 'lr': lr * 0.1},   # backbone 慢速更新
        {'params': head_params, 'lr': lr},          # 分类头正常更新
    ], weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    best_acc = 0.0

    print(f"\n  [{model_name}] 微调 {epochs} epochs, head_lr={lr}, body_lr={lr*0.1:.1e}")
    t0 = time.time()

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        test_loss, test_acc = evaluate(
            model, test_loader, criterion, device
        )
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)

        if test_acc > best_acc:
            best_acc = test_acc

        elapsed = time.time() - t0
        print(f"    Epoch {epoch:2d}/{epochs} | "
              f"Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Test Acc: {test_acc:.2f}% | "
              f"Time: {elapsed:.0f}s")

    print(f"  [{model_name}] 完成!最佳测试准确率: {best_acc:.2f}%")
    return history, best_acc


# ============================================================
# 第 5 部分:可视化
# ============================================================

def plot_accuracy_comparison(results: dict, save_path: str):
    """
    绘制模型准确率对比柱状图

    参数:
        results: {model_name: accuracy}
        save_path: 图片保存路径
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    names = list(results.keys())
    accs = list(results.values())

    # 使用不同颜色区分从零训练 vs 预训练
    colors = []
    for name in names:
        if '预训练' in name or 'Pretrained' in name:
            colors.append('#2196F3')  # 蓝色:预训练模型
        elif '微调' in name or 'Finetuned' in name:
            colors.append('#4CAF50')  # 绿色:微调模型
        else:
            colors.append('#FF9800')  # 橙色:从零训练

    bars = ax.bar(range(len(names)), accs, color=colors, alpha=0.85, edgecolor='white')

    # 在柱子上方标注数值
    for bar, acc in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.5,
                f'{acc:.2f}%', ha='center', va='bottom', fontweight='bold', fontsize=11)

    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, fontsize=10)
    ax.set_ylabel('Test Accuracy (%)', fontsize=12)
    ax.set_title('Imagenette Classification Accuracy: ViT vs CNN', fontsize=14, fontweight='bold')
    ax.set_ylim(0, max(accs) * 1.15)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  [图表] 准确率对比已保存到 {save_path}")


def plot_training_curves(histories: dict, save_path: str):
    """
    绘制训练曲线对比图(Loss + Accuracy)

    参数:
        histories: {model_name: {'train_loss': [...], 'test_acc': [...]}}
        save_path: 图片保存路径
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    colors = ['#FF9800', '#2196F3', '#4CAF50', '#F44336']

    # 子图 1: 训练 Loss
    ax = axes[0]
    for idx, (name, hist) in enumerate(histories.items()):
        epochs = range(1, len(hist['train_loss']) + 1)
        ax.plot(epochs, hist['train_loss'], marker='o', markersize=3,
                color=colors[idx % len(colors)], label=name, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Training Loss', fontsize=11)
    ax.set_title('Training Loss Comparison', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

    # 子图 2: 测试准确率
    ax = axes[1]
    for idx, (name, hist) in enumerate(histories.items()):
        epochs = range(1, len(hist['test_acc']) + 1)
        ax.plot(epochs, hist['test_acc'], marker='s', markersize=3,
                color=colors[idx % len(colors)], label=name, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Test Accuracy (%)', fontsize=11)
    ax.set_title('Test Accuracy Comparison', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  [图表] 训练曲线已保存到 {save_path}")


# ============================================================
# 第 6 部分:主流程
# ============================================================

def main():
    """主函数:训练并对比 ViT (从零)、预训练 ViT、ResNet-18"""
    print("=" * 65)
    print("  s11b Vision Transformer Demo")
    print("  ViT (从零) vs 预训练 ViT-B/16 vs ResNet-18")
    print("=" * 65)

    cfg = CONFIG
    print(f"\n配置:")
    print(f"  设备: {DEVICE}")
    print(f"  输入尺寸: {cfg['img_size']}x{cfg['img_size']}")
    print(f"  Patch 大小: {cfg['patch_size']}x{cfg['patch_size']}")
    print(f"  ViT 嵌入维度: {cfg['embed_dim']}")
    print(f"  ViT 层数: {cfg['depth']}, 注意力头数: {cfg['num_heads']}")
    print(f"  Epochs: {cfg['epochs']}, Batch Size: {cfg['batch_size']}")
    print(f"  使用预训练模型: {cfg['use_pretrained']}")

    # ---- 1. 加载数据 ----
    print("\n[1/5] 加载 Imagenette 数据集 (resize 到 {}x{})..."
          .format(cfg['img_size'], cfg['img_size']))
    train_loader, test_loader, classes = get_imagenette_loaders(
        cfg['batch_size'], cfg['img_size']
    )
    print(f"  训练集: {len(train_loader.dataset)} 张")
    print(f"  测试集: {len(test_loader.dataset)} 张")
    print(f"  类别: {classes}")

    # ---- 2. 构建模型 ----
    print("\n[2/5] 构建模型...")

    # 2a. 从零实现的 SimpleViT
    simple_vit = SimpleViT(
        img_size=cfg['img_size'],
        patch_size=cfg['patch_size'],
        in_chans=3,
        num_classes=NUM_CLASSES,
        embed_dim=cfg['embed_dim'],
        depth=cfg['depth'],
        num_heads=cfg['num_heads'],
        mlp_ratio=cfg['mlp_ratio'],
    ).to(DEVICE)
    print(f"  [从零] SimpleViT 参数量: {count_parameters(simple_vit):,}")

    # 2b. ResNet-18 (从零训练)
    resnet = load_pretrained_resnet(num_classes=NUM_CLASSES)
    # 不加载预训练权重,从零训练以公平对比
    from torchvision.models import resnet18 as _resnet18_raw
    resnet = _resnet18_raw(weights=None)
    resnet.fc = nn.Linear(resnet.fc.in_features, NUM_CLASSES)
    resnet = resnet.to(DEVICE)
    print(f"  [从零] ResNet-18 参数量: {count_parameters(resnet):,}")

    # 2c. 预训练 ViT-B/16 (可选)
    pretrained_vit = None
    if cfg['use_pretrained'] and DEVICE.type != 'cpu':
        pretrained_vit = load_pretrained_vit(num_classes=NUM_CLASSES)
        if pretrained_vit is not None:
            pretrained_vit = pretrained_vit.to(DEVICE)
            print(f"  [预训练] ViT-B/16 参数量: {count_parameters(pretrained_vit):,}")

    # ---- 3. 训练模型 ----
    print("\n[3/5] 训练从零实现的 SimpleViT...")
    vit_history, vit_acc = train_full(
        simple_vit, train_loader, test_loader,
        epochs=cfg['epochs'], lr=cfg['lr'], device=DEVICE,
        model_name="SimpleViT (从零)"
    )

    print("\n[4/5] 训练 ResNet-18 (从零)...")
    # ResNet-18 使用稍小的学习率 (SGD 标准)
    rn_history, rn_acc = train_full(
        resnet, train_loader, test_loader,
        epochs=cfg['epochs'], lr=1e-3, device=DEVICE,
        model_name="ResNet-18 (从零)"
    )

    # 收集训练曲线用于绘图(只看从零训练的模型)
    training_histories = {
        'SimpleViT (从零)': vit_history,
        'ResNet-18 (从零)': rn_history,
    }

    # 收集最终准确率
    final_results = {
        'SimpleViT\n(从零训练)': vit_acc,
        'ResNet-18\n(从零训练)': rn_acc,
    }

    # ---- 4. 微调预训练模型 ----
    if pretrained_vit is not None:
        print("\n[4b/5] 微调预训练 ViT-B/16...")
        # 微调 epoch 数减半(预训练模型收敛更快)
        ft_epochs = max(2, cfg['epochs'] // 2)
        pt_history, pt_acc = quick_finetune(
            pretrained_vit, train_loader, test_loader,
            epochs=ft_epochs, lr=cfg['lr'] * 0.5, device=DEVICE,
            model_name="ViT-B/16 (预训练+微调)"
        )
        final_results['ViT-B/16\n(预训练+微调)'] = pt_acc
        training_histories['ViT-B/16 (预训练+微调)'] = pt_history

    # ---- 5. 可视化 ----
    print("\n[5/5] 生成对比图表...")

    # 图表 1: 准确率对比柱状图
    plot_accuracy_comparison(
        final_results,
        os.path.join(_IMAGES_DIR, 'vit_accuracy_comparison.png')
    )

    # 图表 2: 训练曲线
    plot_training_curves(
        training_histories,
        os.path.join(_IMAGES_DIR, 'vit_training_curves.png')
    )

    # ---- 总结 ----
    print("\n" + "=" * 65)
    print("  训练总结")
    print("=" * 65)
    for name, acc in final_results.items():
        name_clean = name.replace('\n', ' ')
        print(f"  {name_clean:<35} {acc:.2f}%")

    if pretrained_vit is not None:
        print(f"\n  关键发现:")
        print(f"  - ViT 从零训练在 Imagenette 上效果有限(需要更多数据)")
        print(f"  - 预训练 ViT 通过微调大幅超越从零训练的模型")
        print(f"  - ResNet 的卷积归纳偏置在小数据集上仍有优势")
    else:
        print(f"\n  提示: 在 GPU 上运行可启用预训练 ViT 对比,观察迁移学习效果")

    print(f"\n  图表已保存到: {_IMAGES_DIR}")
    print("=" * 65)
    print("  Demo 完成!")
    print("=" * 65)


if __name__ == '__main__':
    main()