Skip to content

s16 Attention与Transformer — demo.py 代码详解

Download demo.py

运行方式

bash
cd s16_attention_transformer/code
python demo.py

依赖numpy, torch, matplotlib, seaborn


代码逐段详解

第1步:导入库

python
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns       # 注意力热力图可视化

第2步:缩放点积注意力 — 一切的基础

核心公式(Vaswani et al., 2017):

Attention(Q,K,V)=softmax(QKdk)V

这是所有 Transformer 架构的基石。无论 BERT、GPT 还是 ViT,其自注意力计算都是这个公式的某种变体。

python
def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)                                          # 每个头的维度
    scores = torch.matmul(Q, K.transpose(-2, -1))             # QK^T: (..., seq_q, seq_k)
    scores = scores / math.sqrt(d_k)                          # 除以 √d_k 缩放
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))      # 掩码位置设为 -∞
    attn_weights = F.softmax(scores, dim=-1)                  # softmax 归一化
    output = torch.matmul(attn_weights, V)                    # 加权求和
    return output

逐行解释

  1. K.transpose(-2, -1):将 K 的最后两维转置。如果 Q 是 (..., seq_q, d_k),K 转置后为 (..., d_k, seq_k),那么 Q @ K^T 的结果是 (..., seq_q, seq_k) —— 每个查询位置对所有键位置的得分。

  2. / math.sqrt(d_k):缩放因子的核心作用——当 dk 较大时(如 64 或 128),点积 qk=qiki 的方差约为 dk。如果不缩放,点积值过大,softmax 会输出几乎 one-hot 的分布(梯度接近 0),模型将无法学习。除以 dk 将方差控制在 1,保持 softmax 在梯度良好的区域。

  3. masked_fill(mask, float('-inf')):掩码位置填入 ,经过 softmax 后权重为 0。用在因果掩码(屏蔽未来位置)和 padding 掩码(屏蔽填充 token)。

  4. F.softmax(scores, dim=-1):沿最后一个维度(即 Key 维度)做 softmax。这保证了对于每个 Query 位置,它对所有 Key 位置的注意力权重之和为 1。

  5. torch.matmul(attn_weights, V):用注意力权重对 Value 做加权求和。直观上,这是"根据每个位置的关联程度,从 Value 中提取信息"。


第3步:多头自注意力 — MultiHeadSelfAttention

为什么要多头? 单头注意力只能捕捉一种关系模式(如句法依赖),但语言中有多种类型的关系需要同时建模——共指关系、语义关联、局部短语结构等。多头注意力通过并行运行 h 组独立的 Q/K/V 投影,让不同头关注不同类型的模式。

核心公式

headi=Attention(QWiQ,KWiK,VWiV)MultiHead(Q,K,V)=Concat(head1,,headh)WO

其中 WiQ,WiK,WiVRdmodel×dkWORhdv×dmodel,通常 dk=dv=dmodel/h

3.1 初始化:Q/K/V 投影矩阵

python
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads          # 每个头的维度

        # Q、K、V 的线性投影(所有头合并在一个矩阵中)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)   # 输出投影

设计选择

  • bias=False:原始 Transformer 中 QKV 投影不带偏置,简化了计算
  • d_k = d_model // num_heads:总维度 d_model 被均匀分配给 h 个头,保证多头和单头的计算量大致相当
  • W_O(输出投影):将所有头的输出拼接后,通过一个线性层融合,学习如何组合不同头的信息

3.2 前向传播:拆分-计算-合并

python
def forward(self, x, mask=None):
    batch_size, seq_len, _ = x.shape

    # 1. 线性投影:每个 token 的向量分别通过 Q、K、V 投影
    Q = self.W_Q(x)    # (batch, seq_len, d_model)
    K = self.W_K(x)
    V = self.W_V(x)

    # 2. 拆分为多头:reshape → transpose
    #    (batch, seq_len, d_model) → (batch, seq_len, num_heads, d_k) → (batch, num_heads, seq_len, d_k)
    Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    # 3. 缩放点积注意力(每个头独立计算)
    attn_output = scaled_dot_product_attention(Q, K, V, mask)
    # (batch, num_heads, seq_len, d_k)

    # 4. 合并多头:transpose → reshape
    attn_output = attn_output.transpose(1, 2).contiguous().view(
        batch_size, seq_len, self.d_model
    )
    # 5. 最终输出投影
    output = self.W_O(attn_output)
    return output, attn_weights

view + transpose 的奥秘:这是多头注意力的关键实现技巧。

  • Q.view(batch, seq_len, num_heads, d_k):将 d_model=512 的向量"折叠"为 num_heads=8d_k=64 的小向量。
  • .transpose(1, 2):交换 seq_len 和 num_heads 维度,使得后续矩阵乘法在每个头上独立进行
  • 合并时:.transpose(1, 2).contiguous().view(batch, seq_len, d_model)——将 8 个头的 64 维输出拼接回 512 维。

为什么这样设计? 如果不拆分维度而用 for 循环对每个头分别计算,代码更直观但速度慢数十倍。将"多头"信息编码到张量维度中,可以利用 GPU 的批量矩阵乘法一次性完成所有头的计算。


第4步:Feed-Forward Network(FFN)

Transformer Block 的第二个子层是 Position-wise FFN——对每个位置独立应用同一个两层全连接网络:

FFN(x)=GELU(xW1+b1)W2+b2

维度变化:dmodel4×dmodeldmodel

python
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        self.linear1 = nn.Linear(d_model, d_ff)       # 先升维 4x
        self.linear2 = nn.Linear(d_ff, d_model)       # 再降维回去
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))

为什么先升维再降维? 4x 扩展比提供了更大的容量来存储"知识"。FFN 被比喻为 Transformer 的"知识存储"——注意力负责"查找"相关信息,FFN 负责对查找结果进行非线性变换。升维给了 FFN 足够的表达能力来学习复杂的特征变换。

GELU vs ReLU:原始 Transformer 用 ReLU,但 GPT-2 及之后的模型普遍用 GELU(Gaussian Error Linear Unit)。GELU 是平滑版的 ReLU,在零点附近不是硬截断而是平滑过渡,梯度流动更好。


第5步:Transformer Encoder Block — Pre-LN 风格

输入 x

LayerNorm(x)           ← Pre-LN: 归一化在注意力之前

Multi-Head Self-Attention

Dropout → Add(x)       ← 残差连接

LayerNorm(...)

FFN

Dropout → Add(...)     ← 残差连接

输出
python
class TransformerEncoderBlock(nn.Module):
    def forward(self, x, mask=None):
        # 子层 1: 自注意力 + 残差
        residual = x
        x_norm = self.norm1(x)                        # Pre-LN
        attn_out, attn_weights = self.self_attn(x_norm, mask)
        x = residual + self.dropout(attn_out)          # 残差连接

        # 子层 2: FFN + 残差
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))
        return x, attn_weights

Post-LN vs Pre-LN:原始 Transformer 论文用的是 Post-LN(LayerNorm 在加法之后),但后来的实践(GPT-2+)发现 Pre-LN(LayerNorm 在子层之前)训练更稳定,梯度流动更顺畅。这是因为 Pre-LN 下残差路径没有经过 LayerNorm,梯度的反向传播路径更"干净"。

残差连接为什么关键? 残差连接(x+Sublayer(x))让梯度可以通过恒等路径直通底层。没有残差连接,训练几十层的 Transformer 几乎不可能——梯度会随着层数消散。有了残差连接,即使注意力或 FFN 部分的梯度消失,恒等路径仍然可以向底层传递梯度信号。


第6步:正弦位置编码 — SinusoidalPositionEncoding

自注意力天然对位置不敏感——打乱输入序列中词的顺序,注意力输出只会在顺序上不同,但值的集合完全相同。位置编码将位置信息注入序列。

公式

PE(pos,2i)=sin(pos100002i/dmodel)PE(pos,2i+1)=cos(pos100002i/dmodel)
python
class SinusoidalPositionEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)          # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)    # 偶数维度: sin
        pe[:, 1::2] = torch.cos(position * div_term)    # 奇数维度: cos
        self.register_buffer('pe', pe.unsqueeze(0))      # (1, max_len, d_model)

关键设计分析

  • 不同频率的正弦波100002i/dmodel 产生从 110000 的波长范围。低维(小 i)对应高频(短波长),高维(大 i)对应低频(长波长)。这让模型能从多个粒度感知位置——高频维度区分相邻位置,低频维度捕捉远距离位置关系。

  • 为什么用 sin/cos 而不是可学习嵌入? 正弦编码可以外推到训练时未见过的序列长度(因为函数是确定性的),且相邻位置的编码具有线性关系:PE(pos+k) 可以表示为 PE(pos) 的线性函数,有助于模型学习相对位置。

  • register_buffer:将张量注册为 buffer(而非 Parameter),意味着它随模型保存/加载但不参与梯度更新。


第7步:Mini-GPT — Decoder-only Transformer

python
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model=128, num_heads=4, num_layers=4,
                 d_ff=512, max_len=128, dropout=0.1):
        # 词嵌入 + 位置编码
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionEncoding(d_model, max_len)

        # 堆叠多个 Transformer Block
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # 最终 LayerNorm + 语言模型头
        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # 权重绑定(Weight Tying)
        self.lm_head.weight = self.token_embed.weight

        # 缓存因果掩码
        self.register_buffer(
            'causal_mask',
            torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        )

7.1 因果掩码 — causal_mask

因果掩码是一个上三角矩阵(对角线以上为 True/1),用于防止位置 i 关注位置 j(当 j>i 时):

对于 seq_len=4:
[[False, True,  True,  True],    ← 位置0只能看到自己
 [False, False, True,  True],    ← 位置1能看到0和1
 [False, False, False, True],    ← 位置2能看到0,1,2
 [False, False, False, False]]   ← 位置3能看到全部

torch.triu(..., diagonal=1) 保留对角线以上(不含对角线)的元素,对角线及以下置 0。mask 位置为 True 时会被 masked_fill 设为

7.2 前向传播

python
def forward(self, x, return_attn=False):
    batch_size, seq_len = x.shape
    # 获取因果掩码,塑形为 (1, 1, seq_len, seq_len) 以便广播到 (batch, num_heads, seq_len, seq_len)
    mask = self.causal_mask[:seq_len, :seq_len].view(1, 1, seq_len, seq_len)

    # 词嵌入(缩放)+位置编码
    x_emb = self.token_embed(x) * math.sqrt(self.d_model)
    x_emb = self.pos_encoding(x_emb)

    # 通过所有 Transformer Block
    hidden = x_emb
    for block in self.blocks:
        hidden, attn_weights = block(hidden, mask)

    # 最终 LayerNorm + LM Head
    hidden = self.final_norm(hidden)
    logits = self.lm_head(hidden)         # (batch, seq_len, vocab_size)
    return logits, attn_maps

* math.sqrt(self.d_model):这是原始 Transformer 论文中的一个小技巧。嵌入向量的初始方差很小(通常从 N(0,1) 初始化),乘以 dmodel 让嵌入的尺度与位置编码的尺度匹配,防止位置编码"淹没"了语义信息。

权重绑定 self.lm_head.weight = self.token_embed.weight:输入嵌入层的权重和输出投影层的权重共享。这在 GPT 系列中是标准做法——输入时将一个 token 映射为向量,输出时将向量映射回 token 概率分布,共享权重减少了参数量,且两边的语义空间是一致的。

7.3 自回归文本生成

python
def generate(self, seed_tokens, max_new_tokens=50, temperature=0.8):
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # 取最后 max_len 个 token 作为输入(序列过长时截断)
            input_seq = torch.tensor([tokens[-128:]], device=device)
            logits, _ = self.forward(input_seq)
            # 取最后一个位置的 logits,除以温度
            next_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()
            tokens.append(next_token)
    return tokens

逐 token 生成:每次生成一个新 token 后,将其追加到序列末尾,然后用更新后的序列预测下一个。这是标准的自回归(autoregressive)生成方式。

因果注意力的效率:虽然每次只预测最后一个位置,但前向传播时仍需计算所有位置的自注意力。不过因果掩码确保位置 i 只能看到 ji 的位置,这恰好符合自回归的需求。


第8步:dk 缩放实验

代码的最后一部分是一个关键的对照实验:对比不同 dk 下有无缩放的 softmax 分布差异。

核心指标

指标含义理想值
平均熵softmax 分布的均匀程度,熵越高分布越均匀有缩放时熵较高
最大注意力权重最受关注的位置权重,越高越集中(饱和)有缩放时较低

实验发现:当 dk=256 时,无缩放的 softmax 几乎完全饱和(某个位置的权重 1,其余 0),平均最大注意力权重接近 1.0。有缩放后,分布更均匀,多个位置都能获得有意义的注意力权重。

这个实验直接验证了 dk 缩放的数学原理:Var(qk)=dk,因此标准差 σ=dk,除以 dk 将标准差归一化为 1。


第9步:注意力热力图可视化

代码绘制了两类热力图:

  1. 各层平均注意力(所有头的平均):展示信息流如何随层加深而变化——浅层倾向于关注局部邻域,深层倾向于关注全局或语义相关的 token。
  2. 最后一层的多头对比:展示不同注意力头关注的不同模式——有的头可能关注相邻词,有的头可能关注语法结构(如主语-谓语)。

关键概念速查表

概念公式/描述作用
Scaled Dot-Product Attentionsoftmax(QK/dk)V全局信息聚合
Q (Query)Q=XWQ"我在找什么?"
K (Key)K=XWK"作为信息源,我是什么?"
V (Value)V=XWV"如果被选中,我提供什么信息?"
多头注意力Concat(head1,...,headh)WO并行捕捉多种关系模式
dk 缩放除以 dk防止 softmax 饱和
因果掩码上三角 防止看到未来
位置编码sin/cos 或可学习嵌入注入位置信息
Pre-LNLayerNorm 在子层之前训练更稳定
残差连接x+Sublayer(x)梯度直通底层
FFNGELU(xW1+b1)W2+b2位置独立非线性变换
权重绑定token_embed.weight=lm_head.weight减少参数,语义一致

完整代码

py
# -*- coding: utf-8 -*-
"""
s16 Attention 与 Transformer demo:从零实现注意力与 Mini-GPT
============================================================
本文件从零实现了 Transformer 的核心组件:
  1. 缩放点积注意力(Scaled Dot-Product Attention)
  2. 多头自注意力(Multi-Head Self-Attention)
  3. Transformer 编码器 Block(Attention + FFN + LayerNorm + 残差)
  4. Mini-GPT(decoder-only)用于字符级文本生成
  5. 注意力可视化(热力图)
  6. √d_k 缩放效果对比

运行方式:在 s16_attention_transformer 目录下执行 `python code/demo.py`
依赖:numpy, torch, matplotlib, seaborn
"""

import numpy as np
import math
from typing import Tuple, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# GPU 自动检测
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"使用设备: {DEVICE}")
if DEVICE.type == 'cpu':
    print("(未检测到 GPU,使用 CPU 运行。如有 GPU,请安装 CUDA 版 PyTorch 以获得加速)")

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['axes.unicode_minus'] = False
import seaborn as sns

import os
_HERE = os.path.dirname(os.path.abspath(__file__))
_IMAGES = os.path.join(_HERE, '..', 'images')
os.makedirs(_IMAGES, exist_ok=True)

# ============================================================
# 第一部分:从零实现 Attention 核心组件
# ============================================================

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    缩放点积注意力:Attention(Q, K, V) = softmax(QK^T / √d_k) V

    参数:
        Q: Query 矩阵 (..., seq_len_q, d_k)
        K: Key 矩阵 (..., seq_len_k, d_k)
        V: Value 矩阵 (..., seq_len_v, d_v),seq_len_v = seq_len_k
        mask: 注意力掩码 (..., seq_len_q, seq_len_k),True 的位置将被设为 -inf
    返回:
        output: 加权后的 Value (..., seq_len_q, d_v)
    """
    d_k = Q.size(-1)
    # QK^T: 计算 Query 和 Key 的点积
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)
    # 缩放:除以 √d_k,防止点积过大导致 softmax 饱和
    scores = scores / math.sqrt(d_k)
    # 应用掩码(如因果掩码):掩码位置设为负无穷
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    # softmax 归一化
    attn_weights = F.softmax(scores, dim=-1)  # (..., seq_len_q, seq_len_k)
    # 加权求和 Value
    output = torch.matmul(attn_weights, V)    # (..., seq_len_q, d_v)
    return output


class MultiHeadSelfAttention(nn.Module):
    """
    多头自注意力(Multi-Head Self-Attention)。
    不使用 nn.MultiheadAttention,完全从零实现。

    参数:
        d_model: 模型维度(输入/输出维度)
        num_heads: 注意力头数
        dropout: Dropout 概率
    """

    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # Q, K, V 的线性投影(合并所有头到一个矩阵中)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)  # (d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        # 输出投影
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        前向传播。

        参数:
            x: 输入序列 (batch, seq_len, d_model)
            mask: 注意力掩码 (batch, 1, seq_len, seq_len) 或 broadcastable
        返回:
            output: 多头注意力输出 (batch, seq_len, d_model)
            attn_weights: 注意力权重(用于可视化,第一个头的平均)
        """
        batch_size, seq_len, _ = x.shape

        # 线性投影得到 Q, K, V
        Q = self.W_Q(x)  # (batch, seq_len, d_model)
        K = self.W_K(x)
        V = self.W_V(x)

        # 拆分为多头:reshape 为 (batch, seq_len, num_heads, d_k) → transpose
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # 现在形状: (batch, num_heads, seq_len, d_k)

        # 缩放点积注意力(mask 已在 MiniGPT.forward 中塑形为 (1, 1, seq_len, seq_len),可直接广播)
        attn_output = scaled_dot_product_attention(Q, K, V, mask)
        # attn_output: (batch, num_heads, seq_len, d_k)

        # 合并多头输出
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        # 最终线性投影
        output = self.W_O(attn_output)
        output = self.dropout(output)

        # 为可视化保存第一个头的注意力权重
        # 计算用于可视化的注意力权重
        with torch.no_grad():
            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
            if mask is not None:
                scores = scores.masked_fill(mask, float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)  # (batch, num_heads, seq_len, seq_len)

        return output, attn_weights


class FeedForward(nn.Module):
    """
    Transformer 的 Position-wise Feed-Forward Network。
    FFN(x) = GELU(x @ W_1 + b_1) @ W_2 + b_2
    维度变化:d_model → 4*d_model → d_model
    """

    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """前向传播"""
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class TransformerEncoderBlock(nn.Module):
    """
    一个完整的 Transformer Encoder Block:
    Input → LayerNorm → Multi-Head Self-Attention → Add (残差)
    → LayerNorm → FFN → Add (残差) → Output
    """

    def __init__(self, d_model: int = 512, num_heads: int = 8, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        # LayerNorm(Pre-LN 风格)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        前向传播。

        参数:
            x: 输入 (batch, seq_len, d_model)
            mask: 注意力掩码
        返回:
            output: 输出 (batch, seq_len, d_model)
            attn_weights: 注意力权重(用于可视化)
        """
        # 子层 1: 自注意力 + 残差(Pre-LN: LN 在注意力之前)
        residual = x
        x_norm = self.norm1(x)
        attn_out, attn_weights = self.self_attn(x_norm, mask)
        x = residual + self.dropout(attn_out)

        # 子层 2: FFN + 残差
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))
        return x, attn_weights


# ============================================================
# 第二部分:位置编码
# ============================================================

class SinusoidalPositionEncoding(nn.Module):
    """
    正弦位置编码。
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 计算分母: 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # 偶数维度用 sin,奇数维度用 cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # 注册为 buffer(不参与训练,但随模型保存)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """将位置编码加到输入上"""
        return x + self.pe[:, :x.size(1), :]


# ============================================================
# 第三部分:Mini-GPT(Decoder-only Transformer)
# ============================================================

class MiniGPT(nn.Module):
    """
    一个微型的 GPT 风格模型(Decoder-only Transformer)。
    用于字符级文本生成。

    参数:
        vocab_size: 词汇表大小
        d_model: 模型维度
        num_heads: 注意力头数
        num_layers: Transformer Block 层数
        max_len: 最大序列长度
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 4,
        d_ff: int = 512,
        max_len: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        # 词嵌入 + 位置编码
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionEncoding(d_model, max_len)
        self.dropout_embed = nn.Dropout(dropout)
        # 堆叠 Transformer Block
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # 最终 LayerNorm + 输出投影
        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # 权重绑定:将 lm_head 的权重与 token_embed 共享
        self.lm_head.weight = self.token_embed.weight

        # 缓存因果掩码,避免每次前向都重新计算
        self.register_buffer(
            'causal_mask',
            torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)
        )

    def forward(
        self, x: torch.Tensor, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[List]]:
        """
        前向传播。

        参数:
            x: 输入 token 序列 (batch, seq_len)
            return_attn: 是否返回注意力权重(用于可视化)
        返回:
            logits: (batch, seq_len, vocab_size)
            attn_maps: 各层的注意力权重列表(可选)
        """
        batch_size, seq_len = x.shape
        # 获取因果掩码 — 塑形为 (1, 1, seq_len, seq_len) 以便与 (batch, num_heads, seq_len, seq_len) 广播
        mask = self.causal_mask[:seq_len, :seq_len].view(1, 1, seq_len, seq_len)
        # 词嵌入 + 位置编码
        x_emb = self.token_embed(x) * math.sqrt(self.d_model)  # 缩放嵌入
        x_emb = self.pos_encoding(x_emb)
        x_emb = self.dropout_embed(x_emb)
        # 通过所有 Transformer Block
        attn_maps = [] if return_attn else None
        hidden = x_emb
        for block in self.blocks:
            hidden, attn_weights = block(hidden, mask)
            if return_attn:
                attn_maps.append(attn_weights)
        # 最终 LayerNorm + LM Head
        hidden = self.final_norm(hidden)
        logits = self.lm_head(hidden)  # (batch, seq_len, vocab_size)
        return logits, attn_maps

    def generate(
        self, seed_tokens: torch.Tensor, max_new_tokens: int = 50, temperature: float = 0.8
    ) -> List[int]:
        """
        自回归生成文本。

        参数:
            seed_tokens: 初始 token 序列 (1, seed_len)
            max_new_tokens: 生成的最大 token 数
            temperature: 温度参数
        返回:
            generated: 生成的 token 序列
        """
        self.eval()
        device = next(self.parameters()).device
        tokens = seed_tokens.tolist()[0] if seed_tokens.dim() > 1 else seed_tokens.tolist()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # 取最后 max_len 个 token 作为输入
                input_seq = torch.tensor([tokens[-128:]], dtype=torch.long, device=device)
                logits, _ = self.forward(input_seq)
                # 取最后一个位置的 logits
                next_logits = logits[0, -1, :] / temperature
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                tokens.append(next_token)
        return tokens


# ============================================================
# 第四部分:训练与可视化
# ============================================================

# 使用一个简单的中文语料训练字符级模型
TEXT = """
深度学习的核心是神经网络,它通过多层非线性变换来学习数据的层次化特征表示。
反向传播算法是训练神经网络的关键技术,它利用链式法则高效地计算损失函数对每个参数的梯度。
卷积神经网络擅长处理图像数据,它通过局部连接和权重共享来捕捉空间结构信息。
循环神经网络适合处理序列数据,它通过隐藏状态来记忆历史信息。
注意力机制是Transformer架构的核心,它允许模型在处理序列时动态地关注不同位置的信息。
Transformer架构完全基于注意力机制,抛弃了传统的循环和卷积结构。
自然语言处理是人工智能的重要领域,它致力于让计算机理解和生成人类语言。
计算机视觉使机器能够像人类一样看懂图像和视频。
强化学习通过试错和奖励机制让智能体学会最优决策策略。
生成对抗网络通过生成器和判别器的对抗训练来生成逼真的数据样本。
知识图谱以结构化的方式表示实体之间的语义关系。
迁移学习利用在大规模数据上预训练的模型来提升小样本任务的性能。
大语言模型通过在海量文本上预训练,展现出惊人的语言理解和生成能力。
"""


def build_dataset(text: str, seq_len: int = 30):
    """构建字符级训练数据集"""
    chars = sorted(set(text))
    char2idx = {ch: i for i, ch in enumerate(chars)}
    idx2char = {i: ch for i, ch in enumerate(chars)}
    data = [char2idx[ch] for ch in text]
    # 构建 (input_seq, target_seq) 对
    samples = []
    for i in range(0, len(data) - seq_len):
        input_seq = data[i:i + seq_len]
        target_seq = data[i + 1:i + seq_len + 1]  # 目标序列是输入右移一位
        samples.append((input_seq, target_seq))
    return samples, char2idx, idx2char, len(chars)


class MiniGPTDataset(Dataset):
    """Mini-GPT 训练数据集"""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        inp, tgt = self.samples[idx]
        return torch.tensor(inp, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)


# 构建数据集
samples, char2idx, idx2char, vocab_size = build_dataset(TEXT, seq_len=30)
dataset = MiniGPTDataset(samples)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
print(f"[数据集] 词汇表大小: {vocab_size}, 训练样本数: {len(dataset)}")

# 创建模型
model = MiniGPT(
    vocab_size=vocab_size,
    d_model=64,
    num_heads=4,
    num_layers=3,
    d_ff=256,
    max_len=128,
    dropout=0.1,
)
print(f"[模型] 参数量: {sum(p.numel() for p in model.parameters()):,}")
device = DEVICE
model = model.to(device)
print(f"[设备] {device}")

# 训练
print("\n[训练] 开始训练 Mini-GPT...")
optimizer = optim.AdamW(model.parameters(), lr=0.003, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
loss_history = []

for epoch in range(80):
    model.train()
    total_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits, _ = model(inputs)
        # logits: (batch, seq_len, vocab_size), targets: (batch, seq_len)
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    loss_history.append(avg_loss)
    if (epoch + 1) % 20 == 0:
        print(f"  Epoch {epoch+1}/80, Loss: {avg_loss:.4f}")

# 绘制训练损失
plt.figure(figsize=(8, 4))
plt.plot(loss_history, color='#1E88E5', linewidth=1.5)
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Loss", fontsize=12)
plt.title("Mini-GPT Character Language Model Training Loss", fontsize=13, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(_IMAGES, 'minigpt_loss_curve.png'), dpi=150, bbox_inches='tight')
plt.close()
print("[可视化] 训练损失图已保存")

# ============================================================
# 第五部分:注意力可视化
# ============================================================

# 取一个样本输入,可视化注意力权重
sample_text = "深度学习是人工智能的重要领域"
sample_tokens = torch.tensor([[char2idx.get(ch, 0) for ch in sample_text]], dtype=torch.long, device=device)

model.eval()
with torch.no_grad():
    _, attn_maps = model(sample_tokens, return_attn=True)

# 绘制各层的注意力热力图(取第一个头)
fig, axes = plt.subplots(1, len(attn_maps), figsize=(5 * len(attn_maps), 4))
if len(attn_maps) == 1:
    axes = [axes]

for layer_idx, attn in enumerate(attn_maps):
    # attn: (batch, num_heads, seq_len, seq_len)
    # 取第一个样本、所有头的平均值
    avg_attn = attn[0].mean(dim=0).cpu().numpy()  # (seq_len, seq_len)
    ax = axes[layer_idx]
    im = ax.imshow(avg_attn, cmap='YlOrRd', aspect='auto')
    # 标注因果掩码的下三角区域
    ax.set_xticks(range(len(sample_text)))
    ax.set_xticklabels(list(sample_text), fontsize=9, rotation=45)
    ax.set_yticks(range(len(sample_text)))
    ax.set_yticklabels(list(sample_text), fontsize=9)
    ax.set_title(f"Layer {layer_idx+1} Average Attention", fontsize=11)
plt.suptitle("Mini-GPT Attention Weight Heatmaps by Layer (Causal Mask Lower Triangle)", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(_IMAGES, 'attention_heatmap.png'), dpi=150, bbox_inches='tight')
plt.close()
print("[可视化] 注意力热力图已保存至 images/attention_heatmap.png")

# 多头注意力对比(取最后一层)
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
last_attn = attn_maps[-1][0].cpu().numpy()  # (num_heads, seq_len, seq_len)
for head_idx in range(min(4, last_attn.shape[0])):
    ax = axes[head_idx]
    im = ax.imshow(last_attn[head_idx], cmap='YlOrRd', aspect='auto')
    ax.set_xticks(range(len(sample_text)))
    ax.set_xticklabels(list(sample_text), fontsize=8, rotation=45)
    ax.set_yticks(range(len(sample_text)))
    ax.set_yticklabels(list(sample_text), fontsize=8)
    ax.set_title(f"Attention Head {head_idx+1}", fontsize=11)
plt.suptitle("Last Layer: 4 Attention Heads Showing Different Focus Patterns", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(_IMAGES, 'multihead_attention_patterns.png'), dpi=150, bbox_inches='tight')
plt.close()
print("[可视化] 多头注意力模式图已保存")

# ============================================================
# 第六部分:√d_k 缩放效果对比
# ============================================================

print("\n" + "=" * 60)
print("[实验] √d_k 缩放对 softmax 的影响")
print("=" * 60)

# 模拟不同 d_k 下的点积得分分布
torch.manual_seed(42)
seq_len = 16
for d_k in [4, 16, 64, 256]:
    Q_test = torch.randn(1, seq_len, d_k)
    K_test = torch.randn(1, seq_len, d_k)
    # 不缩放
    scores_no_scale = torch.matmul(Q_test, K_test.transpose(-2, -1))
    # 缩放
    scores_scaled = scores_no_scale / math.sqrt(d_k)

    # 计算 softmax 分布的熵(熵越高,分布越均匀)
    probs_no_scale = F.softmax(scores_no_scale, dim=-1)
    probs_scaled = F.softmax(scores_scaled, dim=-1)

    # 计算每行的平均熵
    entropy_no = -(probs_no_scale * torch.log(probs_no_scale + 1e-9)).sum(dim=-1).mean().item()
    entropy_scaled = -(probs_scaled * torch.log(probs_scaled + 1e-9)).sum(dim=-1).mean().item()
    # 最大注意力权重(越大越集中)
    max_attn_no = probs_no_scale.max(dim=-1)[0].mean().item()
    max_attn_scaled = probs_scaled.max(dim=-1)[0].mean().item()

    print(f"\nd_k = {d_k:3d}:")
    print(f"  无缩放: 平均熵 = {entropy_no:.4f}, 最大注意力 = {max_attn_no:.4f}")
    print(f"  有缩放: 平均熵 = {entropy_scaled:.4f}, 最大注意力 = {max_attn_scaled:.4f}")

# 大 d_k 下的对比柱状图
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
d_k_values = [4, 16, 64, 256]
entropy_no_list, entropy_scaled_list = [], []
max_no_list, max_scaled_list = [], []

for d_k in d_k_values:
    Q_test = torch.randn(1, seq_len, d_k)
    K_test = torch.randn(1, seq_len, d_k)
    scores_no = torch.matmul(Q_test, K_test.transpose(-2, -1))
    scores_sc = scores_no / math.sqrt(d_k)
    probs_no = F.softmax(scores_no, dim=-1)
    probs_sc = F.softmax(scores_sc, dim=-1)
    entropy_no_list.append(-(probs_no * torch.log(probs_no + 1e-9)).sum(dim=-1).mean().item())
    entropy_scaled_list.append(-(probs_sc * torch.log(probs_sc + 1e-9)).sum(dim=-1).mean().item())
    max_no_list.append(probs_no.max(dim=-1)[0].mean().item())
    max_scaled_list.append(probs_sc.max(dim=-1)[0].mean().item())

x = np.arange(len(d_k_values))
w = 0.35
axes[0].bar(x - w/2, entropy_no_list, w, label='No Scaling', color='#E53935', alpha=0.8)
axes[0].bar(x + w/2, entropy_scaled_list, w, label='With Scaling 1/sqrt(d_k)', color='#1E88E5', alpha=0.8)
axes[0].set_xticks(x)
axes[0].set_xticklabels([f"d_k={v}" for v in d_k_values])
axes[0].set_ylabel("Softmax Average Entropy", fontsize=11)
axes[0].set_title("Entropy (Higher = More Uniform)", fontsize=12)
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')

axes[1].bar(x - w/2, max_no_list, w, label='No Scaling', color='#E53935', alpha=0.8)
axes[1].bar(x + w/2, max_scaled_list, w, label='With Scaling 1/sqrt(d_k)', color='#1E88E5', alpha=0.8)
axes[1].set_xticks(x)
axes[1].set_xticklabels([f"d_k={v}" for v in d_k_values])
axes[1].set_ylabel("Average Max Attention Weight", fontsize=11)
axes[1].set_title("Max Attention (Lower = Less Saturated)", fontsize=12)
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis='y')

plt.suptitle("Effect of 1/sqrt(d_k) Scaling on Softmax Saturation: Without Scaling, Small d_k is More Uniform, Large d_k is Extremely Concentrated", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(_IMAGES, 'dk_scaling_effect.png'), dpi=150, bbox_inches='tight')
plt.close()
print("\n[可视化] √d_k 缩放效果对比图已保存")

# ============================================================
# 第七部分:文本生成
# ============================================================

print("\n" + "=" * 60)
print("[文本生成] Mini-GPT 生成演示")
print("=" * 60)

seeds = ["深度", "自然", "注意", "计算"]
for seed in seeds:
    seed_tokens = torch.tensor([[char2idx.get(ch, 0) for ch in seed]], dtype=torch.long, device=device)
    generated = model.generate(seed_tokens, max_new_tokens=30, temperature=0.8)
    gen_text = ''.join([idx2char.get(t, '?') for t in generated])
    print(f"  种子「{seed}」→ {gen_text}")

print("\n所有 demo 运行完成!图表已保存至 images/ 目录。")
print("\n" + "=" * 60)
print("[核心要点]")
print("=" * 60)
print("""
  1. Self-Attention = softmax(QK^T / √d_k) V
  2. QKV: Query(查)、Key(被查)、Value(值) —— 字典查询范式
  3. Multi-Head: 并行多组注意力,捕捉不同类型的语言关系
  4. 因果掩码: 上三角为 -inf,防止模型'偷看'未来
  5. Pre-LN: LayerNorm 在子层之前,训练更稳定
  6. √d_k 缩放: 没有它,大 d_k 时 softmax 会饱和→梯度消失
  7. 残差连接: 让深层 Transformer 梯度直通底层
""")