Skip to content

s17 预训练范式 — demo.py 代码详解

Download demo.py

运行方式

bash
cd s17_pretrained_models/code
python demo.py

依赖torch, transformers, matplotlib

首次运行:会自动从 HuggingFace Hub 下载模型文件。使用的模型包括:

  • prajjwal1/bert-tiny(最小 BERT,约 4MB,2 层 128 维)
  • bert-base-chinese(BERT 中文模型,用于 MLM 演示)
  • uer/gpt2-chinese-cluecorpussmall(中文 GPT-2)

代码逐段详解

第1步:核心导入 — transformers 库的关键类

python
from transformers import (
    AutoTokenizer,                    # 自动选择对应模型的 tokenizer
    AutoModelForSequenceClassification,  # BERT + 分类头
    AutoModelForMaskedLM,             # BERT MLM(掩码预测)
    AutoModelForCausalLM,             # GPT 自回归语言模型
    Trainer,                          # HuggingFace 的高层训练 API
    TrainingArguments,                # 训练超参数配置
    pipeline,                         # 一行代码完成推理的便捷 API
)

AutoTokenizer vs 专用 TokenizerAutoTokenizer.from_pretrained("bert-base-chinese") 会自动识别模型类型并加载对应的 tokenizer。不需要手动指定 BertTokenizerGPT2Tokenizer

AutoModelFor* 系列:HuggingFace 的 "Auto" 类会根据 checkpoint 名称自动选择正确的模型架构。AutoModelForSequenceClassification 会在预训练 BERT 顶部自动添加一个分类头(pooler + dropout + linear)。


第2步:BERT 文本分类微调

2.1 数据格式

python
train_data = [
    ("这个产品质量非常好,我很满意", 1),  # 正面评论 → 标签 1
    ("客服态度恶劣,完全不解决问题", 0),  # 负面评论 → 标签 0
    ...
]

每条数据是一个 (文本, 标签) 对。标签 1=正面,0=负面。训练集 24 条,验证集 6 条——这是典型的微调场景:标注数据很少,但预训练模型已经"懂"语言,只需少量数据即可适应特定任务

2.2 Tokenizer 编码

python
class SentimentDataset(Dataset):
    def __getitem__(self, idx):
        encoding = self.tokenizer(
            text,
            truncation=True,           # 超过 max_len 则截断
            padding='max_length',      # 不足 max_len 则填充到 max_len
            max_length=128,
            return_tensors='pt',       # 返回 PyTorch 张量
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long),
        }

tokenizer 输出的三个关键字段

字段形状含义
input_ids(batch, max_len)每个 token 的词汇表索引,包括 [CLS], [SEP], [PAD]
attention_mask(batch, max_len)1=真实 token,0=填充 token。Attention 计算时忽略 0 的位置
token_type_ids(batch, max_len)0=句子 A,1=句子 B(单句分类时全为 0,此处未使用)

truncation=True, padding='max_length':保证所有样本的输入长度一致(都是 max_len),这是批处理的要求。截断丢弃超出部分,填充补齐不足部分。

2.3 HuggingFace Trainer — 高层训练 API

python
training_args = TrainingArguments(
    output_dir="./bert_sentiment_checkpoints",
    num_train_epochs=4,                        # 微调只需少量 epoch
    per_device_train_batch_size=4,
    eval_strategy="epoch",                     # 每个 epoch 评估一次
    load_best_model_at_end=True,               # 训练结束后加载最佳模型
    metric_for_best_model="eval_loss",
    report_to="none",                          # 不上传到 wandb
)

trainer = Trainer(
    model=bert_cls,                            # 预训练 BERT + 分类头
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,           # 自定义评估函数
)
trainer.train()

为什么微调只需 2-4 个 epoch? 预训练模型已经学到了通用的语言知识(语法、语义、常识),微调只需将这些知识"调整"到特定任务。epoch 数过多反而会导致过拟合——模型会"忘记"预训练学到的通用知识(catastrophic forgetting)。

Trainer 的设计哲学:HuggingFace 的 Trainer 封装了训练循环、梯度累积、混合精度训练、分布式训练、日志记录、模型保存等底层细节。对于标准的微调任务,只需配置参数即可,无需手写训练循环。

2.4 预测新样本

python
for text in test_texts:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = bert_cls(**inputs).logits          # (1, 2) — 正/负类的未归一化得分
    probs = F.softmax(logits, dim=-1)               # (1, 2) — 转为概率
    pred = torch.argmax(logits, dim=-1).item()      # 0 或 1

**inputs 字典解包:将 {'input_ids': ..., 'attention_mask': ...} 作为关键字参数传入模型。等价于 bert_cls(input_ids=..., attention_mask=...)

softmax 将 logits 转为概率probs = F.softmax(logits, dim=-1) 得到 [P(负面), P(正面)]


第3步:BERT MLM — 掩码预测演示

MLM(Masked Language Model)是 BERT 预训练的核心任务——随机遮盖部分 token,让模型从上下文预测被遮盖的词。

3.1 使用 Pipeline API

python
mlm_pipeline = pipeline(
    "fill-mask",                    # 任务类型:掩码填充
    model=mlm_model,                # 预训练的 BERT MLM 模型
    tokenizer=mlm_tokenizer,
    device=0 if DEVICE.type == 'cuda' else -1,
)

pipeline("fill-mask", ...) 将模型加载、tokenization、前向传播、结果解析封装为一个函数调用。输入带 [MASK] 的文本,输出最可能的填充词。

3.2 预测被遮盖的词

python
mlm_examples = [
    "今天天气真[MASK],适合出去郊游。",
    "这个手机拍照效果很[MASK],我非常满意。",
    "深度学习是人工智能的一个重要[MASK]。",
]
for text in mlm_examples:
    results = mlm_pipeline(text, top_k=3)
    # results[0] = {'score': 0.85, 'token_str': '好', 'sequence': '今天天气真好...'}

MLM 的威力:BERT 能准确预测不同上下文中的 [MASK]——在"天气真[MASK]"的上下文中填"好",在"效果很[MASK]"的上下文中填"好"或"棒"。这展示了 BERT 双向理解的能力——它能同时利用左右两侧的上下文信息。

注意 [MASK] token 的特殊性[MASK] 是 BERT 词表中一个特殊的 token(id=103)。预训练期间模型学会了:当看到 [MASK] 时,需要预测其原始词汇。但在微调阶段,输入中没有 [MASK]——BERT 使用了 80%-10%-10% 的替换策略来弥合这个 gap。


第4步:GPT-2 文本生成 — 对比 BERT

4.1 为什么 BERT 不能生成文本?

BERT 是 Encoder-only 架构,使用双向自注意力——每个 token 可以同时看到左右的 token。这意味着 BERT 无法按顺序一个接一个地生成 token(因为它天然需要"看到全部"才能做预测)。

GPT 是 Decoder-only 架构,使用因果自注意力(causal mask)——每个 token 只能看到它之前的 token。这让 GPT 天然支持自回归生成:给定前文,预测下一个 token,然后将其追加到序列中,重复此过程。

4.2 GPT-2 生成参数

python
outputs = gpt_model.generate(
    **inputs,
    max_new_tokens=30,            # 最多生成 30 个新 token
    temperature=0.8,              # 温度:<1 更确定,>1 更随机
    do_sample=True,               # 采样而非贪心解码
    top_p=0.9,                    # nucleus sampling:累积概率阈值
    repetition_penalty=1.1,       # >1 抑制重复,<1 鼓励重复
    pad_token_id=gpt_tokenizer.pad_token_id,
)

top_p(Nucleus Sampling):从概率最高的 token 开始累加,当累积概率达到 top_p(如 0.9)时停止,只从这组 token 中采样。与 top-k 采样相比,top-p 能根据概率分布动态调整候选集大小。

repetition_penalty:对已经出现过的 token 施加惩罚(logits 降低),防止模型陷入重复循环(如"我爱你我爱你我爱你...")。值 >1 表示惩罚重复。

do_sample=True:使用概率采样而非贪心解码。如果 do_sample=False,等价于 temperature=0(每次选概率最高的 token),生成结果是确定性的,缺乏多样性。


第5步:上下文嵌入 — BERT vs Word2Vec

代码通过一个巧妙的实验展示 BERT 的核心优势——上下文相关的嵌入

python
test_sentences = [
    "我喜欢吃苹果,特别是红富士苹果",       # 两句中的"苹果"都是水果→相似
    "苹果公司发布了最新的iPhone手机",       # 这句中的"苹果"是公司→与水果不同
    "我在超市买了三个苹果",
    "苹果的股价今天上涨了百分之五",
]

Word2Vec 的问题:无论"苹果"出现在什么上下文中,其词向量完全相同。模型无法区分"吃苹果(水果)"和"苹果公司(科技公司)"。

BERT 的解决方案:BERT 的嵌入是上下文相关的——同一个词在不同句子(或同一句的不同位置)中有不同的向量表示。代码通过计算同一句中两个"苹果"的余弦相似度来验证:当其处于相同语义上下文时(都是水果),嵌入相似度高;跨语义上下文时,嵌入会不同。

计算余弦相似度

cosine_similarity(v1,v2)=v1v2v1v2

F.cosine_similarity(v1.unsqueeze(0), v2.unsqueeze(0)) 计算两个向量的夹角余弦,值域 [-1, 1]。1 表示完全相同,0 表示正交,-1 表示完全相反。


第6步:回退机制 — 当模型下载失败时

代码中包含了健壮的回退逻辑:如果 HuggingFace 模型下载失败(如无网络),会创建一个 TinyFallbackClassifier(微型 Transformer 分类器),确保 demo 在任何环境下都能运行。

python
class TinyFallbackClassifier(nn.Module):
    def __init__(self, vocab_size=1000, num_labels=2):
        self.embedding = nn.Embedding(vocab_size, 32)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=32, nhead=2, ...),
            num_layers=2
        )
        self.classifier = nn.Linear(32, num_labels)

这是一个简单的 Encoder-only 模型(类似于微型 BERT),用于演示微调流程。虽然效果远不如真正的 BERT,但它能让你理解"预训练模型 + 分类头 → 微调"的完整 pipeline。


关键概念速查表

概念公式/描述关键点
MLM (掩码语言模型)预测 [MASK] 位置的原词BERT 的双向理解能力来源
CLM (因果语言模型)$P(x_tx_{<t})$
Tokenizer文本→input_ids, attention_mask分词+编码+填充+截断
[CLS] token句子级别的聚合表示BERT 分类任务用它的向量
[SEP] token句子分隔符分隔不同的句子/段落
[MASK] token被遮盖的 tokenMLM 的预测目标
attention_mask1=真实 token, 0=padding让注意力忽略填充位置
微调 epoch 数通常 2-4预训练模型只需少量调整
top-p sampling累积概率阈值动态确定候选 token 集合
repetition_penalty降低重复 token 的 logits防止生成循环重复文本
上下文嵌入同一词在不同上下文中向量不同BERT 解决多义词问题
Pipeline APIpipeline("fill-mask", model)一行代码完成推理

完整代码

py
# -*- coding: utf-8 -*-
"""
s17 预训练范式 demo:BERT 微调 + GPT-2 生成对比
================================================
本文件使用 HuggingFace transformers 库,展示预训练模型的核心用法:
  任务1:BERT 文本分类微调(中文情感分析)
  任务2:BERT 掩码预测(MLM 能力展示)
  任务3:GPT-2 文本生成(对比 BERT 无法生成)

运行方式:在 s17_pretrained_models 目录下执行 `python code/demo.py`
依赖:torch, transformers, datasets, matplotlib
注意:首次运行会在 models/ 目录下下载模型文件(约400MB)
"""

import numpy as np
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# GPU 自动检测
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"使用设备: {DEVICE}")
if DEVICE.type == 'cpu':
    print("(未检测到 GPU,使用 CPU 运行。如有 GPU,请安装 CUDA 版 PyTorch 以获得加速)")

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForMaskedLM,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    pipeline,
)

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['axes.unicode_minus'] = False

import os
_HERE = os.path.dirname(os.path.abspath(__file__))
_IMAGES = os.path.join(_HERE, '..', 'images')
os.makedirs(_IMAGES, exist_ok=True)

# 设置模型下载目录
os.environ.setdefault('HF_HOME', os.path.join(os.path.dirname(__file__), '..', 'models'))

# ============================================================
# 第一部分:BERT 文本分类微调
# ============================================================

# 中文情感分析数据集(少量示例数据,用于演示微调流程)
train_data = [
    ("这个产品质量非常好,我很满意", 1),
    ("客服态度特别好,问题很快就解决了", 1),
    ("物流很快,包装很精美,好评", 1),
    ("性价比超级高,推荐大家购买", 1),
    ("已经用了好几个月了,质量很稳定", 1),
    ("颜色很好看,大小也合适,非常满意", 1),
    ("功能强大,操作也很简单方便", 1),
    ("比实体店便宜多了,正品无疑,会回购", 1),
    ("材质很好,手感不错,下次还会来买", 1),
    ("发货很快,商品完好无损,很满意", 1),
    ("款式新颖,做工也很好,值得购买", 1),
    ("味道很好,生产日期很新鲜,好评", 1),
    ("这个产品质量太差了,用了两天就坏了", 0),
    ("客服态度恶劣,完全不解决问题", 0),
    ("发货特别慢,等了一周才到,非常失望", 0),
    ("和描述完全不符,图片严重美化,上当了", 0),
    ("用了不到一个月就出问题,质量太差", 0),
    ("包装简陋,收到的时候已经破损了", 0),
    ("一点都不好用,操作复杂,不值这个价钱", 0),
    ("有异味,不敢用,退货还特别麻烦", 0),
    ("刚买完就降价了,还不给保价差评", 0),
    ("做工粗糙,细节处理不到位,不满意", 0),
    ("安装说明书写得太烂,完全看不懂", 0),
    ("噪音很大,严重影响使用体验,差评", 0),
]

# 测试数据
eval_data = [
    ("这个东西真的很好用,我太喜欢了", 1),
    ("物流太慢了,商品还有破损,不推荐", 0),
    ("整体还不错,价格合理,推荐购买", 1),
    ("完全不如描述的那么好,不建议买", 0),
    ("质量没话说,会推荐给朋友的", 1),
    ("售后态度太差了,再也不会买了", 0),
]


class SentimentDataset(Dataset):
    """情感分析 PyTorch 数据集"""

    def __init__(self, data: List[Tuple[str, int]], tokenizer, max_len: int = 128):
        """
        参数:
            data: (文本, 标签) 列表
            tokenizer: BERT tokenizer
            max_len: 最大序列长度
        """
        self.texts = [item[0] for item in data]
        self.labels = [item[1] for item in data]
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # 使用 tokenizer 编码文本
        encoding = self.tokenizer(
            text,
            truncation=True,           # 超过 max_len 则截断
            padding='max_length',      # 不足 max_len 则填充
            max_length=self.max_len,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),       # (max_len,)
            'attention_mask': encoding['attention_mask'].squeeze(0),  # (max_len,)
            'labels': torch.tensor(label, dtype=torch.long),
        }


print("=" * 60)
print("[BERT 文本分类] 加载 bert-base-chinese 模型...")
print("=" * 60)

# 使用最小的可用模型以加速演示(CPU 友好)
model_name = "prajjwal1/bert-tiny"  # 最小 BERT 变体(约 4MB,2层,128维)
_HAS_REAL_MODEL = False

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_cls = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    bert_cls = bert_cls.to(DEVICE)
    _HAS_REAL_MODEL = True
    print(f"[模型] 成功加载: {model_name}")
except Exception as e:
    print(f"[警告] 加载 {model_name} 失败: {e}")
    # 尝试备用:bert-base-chinese
    try:
        model_name = "bert-base-chinese"
        print("[备用] 尝试加载 bert-base-chinese...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        bert_cls = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
        bert_cls = bert_cls.to(DEVICE)
        _HAS_REAL_MODEL = True
        print(f"[模型] 成功加载: {model_name}")
    except Exception as e2:
        print(f"[警告] 所有模型下载均失败 ({e2})")
        print("[回退] 使用本地简化模型进行演示...")
        # 回退:创建一个微型随机 BERT 风格模型
        class TinyFallbackClassifier(nn.Module):
            def __init__(self, vocab_size=1000, num_labels=2):
                super().__init__()
                self.embedding = nn.Embedding(vocab_size, 32)
                self.encoder = nn.TransformerEncoder(
                    nn.TransformerEncoderLayer(d_model=32, nhead=2, dim_feedforward=64, batch_first=True),
                    num_layers=2
                )
                self.classifier = nn.Linear(32, num_labels)
                self.config = type('obj', (object,), {'hidden_size': 32})()
                self.device = DEVICE

            def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
                x = self.embedding(input_ids)
                x = self.encoder(x)
                x = x.mean(dim=1)
                logits = self.classifier(x)
                loss = F.cross_entropy(logits, labels) if labels is not None else None
                return type('obj', (object,), {'loss': loss, 'logits': logits})()

        try:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
        except Exception:
            tokenizer = None
        if tokenizer is None:
            # 无 tokenizer 可用,创建一个最简的字符级编码器
            class SimpleCharTokenizer:
                def __init__(self):
                    self.vocab = {c: i for i, c in enumerate('abcdefghijklmnopqrstuvwxyz0123456789')}
                def __call__(self, text, truncation=True, padding='max_length', max_length=128, return_tensors='pt', **kwargs):
                    ids = [self.vocab.get(c, 0) for c in text.lower()[:max_length]]
                    ids = ids + [0] * (max_length - len(ids))
                    return {'input_ids': torch.tensor([ids]), 'attention_mask': torch.tensor([[1]*len(ids)])}
            tokenizer = SimpleCharTokenizer()
        bert_cls = TinyFallbackClassifier().to(DEVICE)
        print(f"[回退] 使用 TinyFallbackClassifier(仅用于演示流程,不是真实 BERT)")

print(f"[模型] 参数量: {sum(p.numel() for p in bert_cls.parameters()):,}")
print(f"[数据] 训练样本: {len(train_data)}, 验证样本: {len(eval_data)}")

# 准备数据集
train_dataset = SentimentDataset(train_data, tokenizer)
eval_dataset = SentimentDataset(eval_data, tokenizer)

# 训练参数(少量 epoch 演示微调)
training_args = TrainingArguments(
    output_dir="./bert_sentiment_checkpoints",  # 模型保存路径
    num_train_epochs=4,                         # 微调 epoch 数(预训练模型只需少量 epoch)
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    eval_strategy="epoch",                # 每个 epoch 评估一次
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",                           # 不上报到 wandb 等平台
)

# 定义评估指标
def compute_metrics(eval_pred):
    """计算准确率"""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).mean()
    return {"accuracy": accuracy}


if _HAS_REAL_MODEL:
    trainer = Trainer(
        model=bert_cls,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    print("\n[微调] 开始训练...(使用预训练 BERT + 少量情感标注数据)")
    trainer.train()

    # 评估
    eval_results = trainer.evaluate()
    print(f"\n[评估] 验证集准确率: {eval_results.get('eval_accuracy', 'N/A'):.4f}")
else:
    print("\n[微调] 跳过(使用回退模型,无需训练)")
    # 使用回退模型进行简单训练
    if DEVICE.type == 'cpu':
        n_fallback_epochs = 1
        n_fallback_samples = 8  # CPU 模式:仅用 8 条样本快速演示
        train_dataset_small = SentimentDataset(train_data[:n_fallback_samples], tokenizer)
        print("[配置] CPU 模式:使用轻量参数快速演示(回退模型 1 epoch, 8 样本)。GPU 模式下将使用完整训练配置。")
    else:
        n_fallback_epochs = 3
        train_dataset_small = train_dataset
    bert_cls.train()
    import torch.optim as optim
    optimizer_ft = optim.Adam(bert_cls.parameters(), lr=0.01)
    for epoch in range(n_fallback_epochs):
        total_loss = 0.0
        for batch in DataLoader(train_dataset_small, batch_size=4, shuffle=True):
            input_ids = batch['input_ids'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            optimizer_ft.zero_grad()
            out = bert_cls(input_ids, labels=labels)
            out.loss.backward()
            optimizer_ft.step()
            total_loss += out.loss.item()
        print(f"  Epoch {epoch+1}: loss={total_loss/len(train_dataset_small)*4:.4f}")
    print("[回退训练] 完成(此训练无法达到预训练BERT的效果,仅用于演示流程)")

# 预测新样本
print("\n[预测] 对新样本进行情感分类:")
test_texts = [
    "这个东西简直太好用了,完全超出预期",
    "质量好坏不说,光是等了一周就不想买了",
    "性价比还不错,但也算不上特别惊艳",
]
for text in test_texts:
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
    device = bert_cls.device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = bert_cls(**inputs).logits
    probs = F.softmax(logits, dim=-1)
    pred = torch.argmax(logits, dim=-1).item()
    label_str = "正面 👍" if pred == 1 else "负面 👎"
    print(f"  文本: {text}")
    print(f"  预测: {label_str} (正面概率: {probs[0][1].item():.3f}, 负面概率: {probs[0][0].item():.3f})")
    print()

# ============================================================
# 第二部分:BERT MLM 掩码预测
# ============================================================

print("=" * 60)
print("[BERT MLM] 掩码预测能力展示")
print("=" * 60)

# 加载专用的 MLM 模型
print("[模型] 加载 bert-base-chinese MLM...")
try:
    mlm_tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    mlm_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
    mlm_model = mlm_model.to(DEVICE)
    print("[模型] 加载成功")
except Exception as e:
    print(f"[警告] 加载失败: {e}, 跳过 MLM 演示")
    mlm_model = None

if mlm_model is not None:
    # 使用 pipeline 简化调用
    mlm_pipeline = pipeline(
        "fill-mask",
        model=mlm_model,
        tokenizer=mlm_tokenizer,
        device=0 if DEVICE.type == 'cuda' else -1,
    )

    # 测试 MLM:在不同上下文中预测 [MASK] 的词
    mlm_examples = [
        "今天天气真[MASK],适合出去郊游。",
        "这个手机拍照效果很[MASK],我非常满意。",
        "深度学习是人工智能的一个重要[MASK]。",
        "他每天坚持[MASK]身体,所以非常健康。",
        "这家餐厅的菜品味道[MASK],价格也合理。",
    ]

    for text in mlm_examples:
        results = mlm_pipeline(text, top_k=3)
        print(f"\n  原文: {text}")
        print(f"  Top-3 预测:")
        for r in results:
            print(f"    [{r['score']:.3f}] {r['token_str']}{r['sequence']}")


# ============================================================
# 第三部分:GPT-2 文本生成(对比 BERT)
# ============================================================

print("\n" + "=" * 60)
print("[GPT-2 文本生成] 对比 BERT 的生成能力")
print("=" * 60)

# 加载 GPT-2 中文模型
gpt_model_name = "uer/gpt2-chinese-cluecorpussmall"
print(f"[模型] 加载 {gpt_model_name}...")

try:
    gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
    gpt_model = AutoModelForCausalLM.from_pretrained(gpt_model_name)
    gpt_model = gpt_model.to(DEVICE)
    # 设置 pad_token
    if gpt_tokenizer.pad_token is None:
        gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
    print("[模型] GPT-2 加载成功")
    print(f"[模型] GPT-2 参数量: {sum(p.numel() for p in gpt_model.parameters()):,}")

    # 文本生成
    prompts = [
        "人工智能的发展历程可以追溯到",
        "今天天气很好,我决定去",
        "在深度学习中,神经网络的训练过程",
    ]

    for prompt in prompts:
        inputs = gpt_tokenizer(prompt, return_tensors='pt')
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        # 生成文本
        outputs = gpt_model.generate(
            **inputs,
            max_new_tokens=30,         # 最多生成 30 个新 token
            temperature=0.8,           # 温度参数控制随机性
            do_sample=True,             # 使用采样而非贪心解码
            top_p=0.9,                  # nucleus sampling
            repetition_penalty=1.1,     # 抑制重复
            pad_token_id=gpt_tokenizer.pad_token_id,
        )
        generated = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\n  提示: {prompt}")
        print(f"  GPT-2: {generated}")

    # 对比:尝试让 BERT 生成文本
    print("\n[对比] BERT 能否生成文本?")
    print("  BERT 是 Encoder-only 架构,使用双向注意力,无法自回归生成。")
    print("  尝试将 BERT 用于生成 → 缺乏因果注意力+自回归训练 → 输出无意义的乱码。")
    print("  这就是为什么:理解任务用 BERT,生成任务用 GPT。")

except Exception as e:
    print(f"[警告] GPT-2 加载失败: {e}")
    print("  (这是正常的,可以跳过此部分)")
    print("  核心理解:BERT=biderectional→理解,GPT=causal→生成")

# ============================================================
# 第四部分:嵌入可视化(BERT 上下文嵌入对比 word2vec)
# ============================================================

print("\n" + "=" * 60)
print("[上下文嵌入] BERT vs word2vec — 多义词对比")
print("=" * 60)

if mlm_model is not None:
    # 展示 BERT 的上下文相关嵌入
    # 同一词在不同上下文中会有不同的向量表示
    test_sentences = [
        "我喜欢吃苹果,特别是红富士苹果",
        "苹果公司发布了最新的iPhone手机",
        "我在超市买了三个苹果",
        "苹果的股价今天上涨了百分之五",
    ]

    for sentence in test_sentences:
        inputs = mlm_tokenizer(sentence, return_tensors='pt')
        with torch.no_grad():
            # 获取 BERT 的隐藏状态(取最后一层的 [CLS] 或特定 token)
            outputs = mlm_model.base_model(**inputs)
            # 取最后一个隐藏层的所有 token 向量
            last_hidden = outputs.last_hidden_state  # (1, seq_len, 768)
            # 找到"苹果"token 的位置
            tokens = mlm_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            # 打印句子和"苹果"的 token 位置
            apple_positions = [i for i, t in enumerate(tokens) if '苹' in t]
            print(f"\n  句子: {sentence}")
            print(f"  Token 序列: {tokens}")
            if len(apple_positions) >= 2:
                # 计算两个"苹果"嵌入的余弦相似度
                v1 = last_hidden[0, apple_positions[0]]
                v2 = last_hidden[0, apple_positions[1]]
                sim = F.cosine_similarity(v1.unsqueeze(0), v2.unsqueeze(0)).item()
                print(f'  句子内"苹果"相似度: {sim:.4f}')
    print("\n  关键观察:同一句中两个'苹果'的上下文嵌入高度相似")
    print("  (因为它们都在'水果'上下文中)")

# ============================================================
# 第五部分:总结
# ============================================================

print("\n" + "=" * 60)
print("[总结] BERT 与 GPT 的核心差异")
print("=" * 60)
print("""
  架构差异:
    BERT: Encoder-only, 双向注意力, 适合"理解"任务
    GPT:  Decoder-only, 因果注意力, 适合"生成"任务

  训练目标:
    BERT: MLM (Masked Language Model) — 预测被遮盖的词
    GPT:  CLM (Causal Language Model) — 预测下一个词

  预训练-微调范式:
    一次大规模预训练 → 添加简单任务头 → 快速微调 → 部署

  BERT 的弱点:
    无法生成连贯文本 — 因为它不是自回归的
    固定长度输入 — 最大512 tokens (原始版本)
    [MASK] 在微调时不存在 — 预训练与微调之间有gap

  GPT 的优势:
    天然适合生成文本 — 从左到右逐个输出
    随着规模增大涌现新能力 — In-context learning, CoT
    统一的输入输出格式 — 所有任务都是"续写文本"

  下一章 s18: 大语言模型的 Scaling Law, 涌现能力, 以及对齐技术
""")
print("所有 demo 运行完成!")