【训练与微调篇03】全量微调实践:从数据到部署的SFT完整指南

🎯 全量微调实践:从数据到部署的SFT完整指南

预训练给了模型"语言能力",微调给了模型"办事能力"。2026 年,全量微调已经从"全参数更新"演进为"分层 + 分阶段"的精细操作——日活千万的AI产品背后,都有一套精心设计的微调策略。

📑 目录


一、为什么需要全量微调?

1.1 全量微调 vs LoRA 的抉择

维度全量微调 (FFT)LoRA/QLoRA
更新参数全部(100%)少量(0.1-1%)
知识注入深度 — 改变模型内部表示低 — 学习"回答风格"
学习新知识✅ 可注入新事实❌ 几乎不可能
领域适应能力优秀 — 如法律、医疗⚠️ 有限
训练显存高(70B需~400GB)低(70B仅需~48GB)
训练时间周级小时级
模型质量上限更高受限于秩®
多任务鲁棒性✅ 好❌ 容易过拟合

决策规则

你的需求是?
├── 学习新知识/新格式 → 全量微调(FFT)
├── 适配回答风格/格式 → LoRA(高效)
├── 领域深度适应(医疗/法律/金融) → FFT(全量)
├── 已有7B模型,希望更快迭代 → LoRA
└── 用做基座,后续做RLHF → FFT(全量)

1.2 2026年共识:混合策略最佳

最先进的微调方案不再是"全量"或"LoRA"二选一,而是分阶段混合

Stage 1: 全量微调(新知识注入) → 8×H200, 3天
  学习新领域数据、新格式、新任务

Stage 2: LoRA 适配(风格对齐) → 1×RTX 5090, 2小时
  微调回答风格、格式偏好

Stage 3: DPO/GRPO(偏好对齐) → 8×H200, 1天
  人类偏好对齐


二、SFT的三种模式

2.1 标准 SFT(Next Token Prediction)

最经典的方法:输入"问题+答案",让模型预测下一个 token。

用户: 什么是注意力机制?
助手: 注意力机制是...

Loss = CrossEntropy(predicted_token, ground_truth_token)
     只计算"助手"部分的loss,不计算"用户"部分的loss

2.2 对话式 SFT(Multi-turn)

包含多轮对话,训练模型维持上下文:

用户: 什么是注意力机制?
助手: 注意力机制是...【计算loss】
用户: 那它和Transformer有什么关系?【不计算loss】
助手: Transformer的核心就是...【计算loss】

关键:只对助手回复计算 loss,用户输入部分 mask 掉。

2.3 多任务 SFT(Multi-task)

在同一个模型中训练多种任务类型:

任务1: [分类] 这封邮件是垃圾邮件吗?是【计算loss】
任务2: [摘要] 将以下文章总结为三句话...【计算loss】
任务3: [QA] 什么是...【计算loss】
任务4: [代码] 写一个快排...【计算loss】
任务5: [翻译] 将中文翻译成英文...【计算loss】


三、SFT数据构建方法论

3.1 高质量SFT数据的"黄金标准"

2026 年各实验室的 SFT 数据配方 [1][2]:

维度标准说明
多样性50+ 任务类型问答、写作、代码、翻译、摘要、分类…
难度渐进简单→中等→困难从"帮我写一封邮件"到"推导黎曼猜想"
格式统一一致的对话模板ChatML、Qwen、LLaMA 三种主流格式
质量优先DataFlow 验证10K 高质量 ≈ 50K 粗筛
答案长度128-2048 tokens短答案学格式,长答案学推理
拒绝回答5-10%“我无法回答这个问题”——防幻觉

3.2 数据格式标准化

"""三种主流对话模板"""
from typing import List, Dict

class ChatTemplate:
    """对话模板格式化"""
    
    @staticmethod
    def format_chatml(messages: List[Dict]) -> str:
        """ChatML 格式(OpenAI/DeepSeek 使用)"""
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
        return formatted
    
    @staticmethod
    def format_qwen(messages: List[Dict]) -> str:
        """Qwen 格式"""
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"<|system|>\n{content}\n"
            elif role == "user":
                formatted += f"<|user|>\n{content}\n"
            elif role == "assistant":
                formatted += f"<|assistant|>\n{content}\n"
        formatted += "<|assistant|>\n"
        return formatted
    
    @staticmethod
    def format_llama(messages: List[Dict]) -> str:
        """LLaMA 4 格式"""
        formatted = "<|begin_of_text|>"
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>"
            elif role == "user":
                formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
            elif role == "assistant":
                formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
        formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return formatted

3.3 数据构建流程

"""高质量SFT数据构建 Pipeline"""

import json
import random
from typing import List, Dict, Generator

class SFTDataBuilder:
    """SFT训练数据构建器"""
    
    def __init__(self, base_model: str = "Qwen3-72B"):
        self.base_model = base_model
        self.task_templates = self._load_task_templates()
    
    def _load_task_templates(self) -> Dict:
        """加载任务模板"""
        return {
            "qa": {
                "system": "你是一个知识渊博的助手,请准确回答用户问题。",
                "count": 0.3,  # 30%
                "difficulty": ["简单", "中等", "困难"]
            },
            "writing": {
                "system": "你是一个写作助手,请根据要求完成各类写作任务。",
                "count": 0.2,  # 20%
                "difficulty": ["简单", "中等"]
            },
            "code": {
                "system": "你是一个编程助手,请写出正确高效的代码。",
                "count": 0.2,  # 20%
                "difficulty": ["中等", "困难"]
            },
            "reasoning": {
                "system": "请一步步思考,给出详细的推理过程。",
                "count": 0.15,  # 15%
                "difficulty": ["中等", "困难"]
            },
            "translation": {
                "system": "你是一个翻译助手,请准确翻译以下内容。",
                "count": 0.1,  # 10%
                "difficulty": ["简单", "中等"]
            },
            "safe_refusal": {
                "system": "你是一个负责任的AI助手。",
                "count": 0.05,  # 5% - 拒绝回答
                "difficulty": ["简单"]
            }
        }
    
    def build_dataset(
        self, 
        total_count: int = 50000,
        output_path: str = "sft_data.jsonl"
    ):
        """构建完整数据集"""
        with open(output_path, 'w', encoding='utf-8') as f:
            for task, config in self.task_templates.items():
                n = int(total_count * config["count"])
                for _ in range(n):
                    sample = self._generate_sample(task, config)
                    f.write(json.dumps(sample, ensure_ascii=False) + '\n')
        
        print(f"数据集已生成: {output_path}")
        print(f"总样本数: {total_count}")
        for task, config in self.task_templates.items():
            n = int(total_count * config["count"])
            print(f"  {task}: {n} ({config['count']*100:.0f}%)")
    
    def _generate_sample(self, task: str, config: Dict) -> Dict:
        """生成单个训练样本"""
        # 从真实数据或合成数据中采样
        # 实际中应该用高质量的标注数据
        return {
            "messages": [
                {"role": "system", "content": config["system"]},
                {"role": "user", "content": f"这是一个{task}类型的{random.choice(config['difficulty'])}问题"},
                {"role": "assistant", "content": "这是一个高质量的回复。"}
            ],
            "task": task,
            "difficulty": random.choice(config["difficulty"])
        }


# ===== 数据质量检查 =====
class SFTDataValidator:
    """SFT数据质量检查器"""
    
    @staticmethod
    def validate_sample(sample: Dict) -> List[str]:
        """检查单个样本的质量"""
        issues = []
        
        messages = sample.get("messages", [])
        
        # 检查对话结构
        if len(messages) < 2:
            issues.append("对话轮次不足")
        
        if messages[-1]["role"] != "assistant":
            issues.append("最后一条必须是assistant回复")
        
        # 检查内容质量
        for msg in messages:
            content = msg.get("content", "")
            if len(content) < 5:
                issues.append(f"{msg['role']}回复过短: {content}")
            
            # 检查重复
            if len(set(content.split())) / max(len(content.split()), 1) < 0.3:
                issues.append(f"{msg['role']}回复重复过多")
            
            # 检查无用字符
            if content.count('\u0000') > 0:
                issues.append("包含空字符")
        
        return issues
    
    @staticmethod
    def dataset_stats(dataset: List[Dict]) -> Dict:
        """统计数据集"""
        stats = {
            "total": len(dataset),
            "avg_user_len": 0,
            "avg_assistant_len": 0,
            "task_distribution": {},
            "issues": []
        }
        
        user_lens = []
        assistant_lens = []
        
        for sample in dataset:
            for msg in sample["messages"]:
                if msg["role"] == "user":
                    user_lens.append(len(msg["content"]))
                elif msg["role"] == "assistant":
                    assistant_lens.append(len(msg["content"]))
            
            # 任务分布
            task = sample.get("task", "unknown")
            stats["task_distribution"][task] = stats["task_distribution"].get(task, 0) + 1
            
            # 质量检查
            issues = SFTDataValidator.validate_sample(sample)
            if issues:
                stats["issues"].extend(issues)
        
        stats["avg_user_len"] = sum(user_lens) / max(len(user_lens), 1)
        stats["avg_assistant_len"] = sum(assistant_lens) / max(len(assistant_lens), 1)
        
        return stats

3.4 数据质量 vs 模型性能

DataFlow 实验的关键发现 [1]:

数据集规模模型数学得分代码得分备注
原始 Alpaca52K7B基线基线原始
DataFlow SFT15K7B+9.3+5.115K > 52K
DataFlow 数学10K32B55.7超过 Open-R1
DataFlow 代码10K7B46.210K 代码数据
DataFlow Chat15K7BAlpacaEval 7.05→10.11

核心结论:15K 精心构造的 DataFlow 数据 > 52K 原始 Alpaca 数据。数据质量比数量重要 3 倍以上


四、损失函数与训练策略

4.1 标准 Cross-Entropy Loss

"""SFT 损失函数详解"""
import torch
import torch.nn.functional as F

def sft_loss(
    logits: torch.Tensor,       # (batch, seq_len, vocab_size)
    labels: torch.Tensor,       # (batch, seq_len)  -100 表示忽略
    attention_mask: torch.Tensor = None
) -> torch.Tensor:
    """
    SFT 交叉熵损失
    - 只计算 assistant 部分(labels 中 != -100 的位置)
    - 忽略 user 部分和 padding 部分
    """
    batch_size, seq_len, vocab_size = logits.shape
    
    # 移除最后一个 logit(不需要预测最后一个 token 之后的内容)
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    
    # Flatten
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    
    # 计算 loss,忽略 -100 的位置
    loss = F.cross_entropy(
        shift_logits,
        shift_labels,
        ignore_index=-100,  # 关键!忽略 user 输入和 padding
        reduction='mean'
    )
    
    return loss

4.2 Loss Masking 详解

![Loss Masking 示意图]

输入: "<|user|>\n什么是AI?<|end|>\n<|assistant|>\n人工智能是..."

labels: [-100, -100, -100, ..., -100,     |     "人", "工", "智", "能", ...]
         ↑ user部分全部ignore               ↑ 助手部分正常计算loss

为什么这么做?
- user输入是"指令",模型不应该从"提问方式"中学习
- 只从"助手的回答"中学习
- 防止模型学会"如何提问",只学会"如何回答"

"""构建正确的 labels(mask user 部分)"""
from transformers import DataCollatorForSeq2Seq

class SFTDataCollator(DataCollatorForSeq2Seq):
    """SFT专用数据整理器:自动 mask user 输入"""
    
    def __call__(self, features):
        batch = super().__call__(features)
        
        # labels 中 tokenizer.pad_token_id 的位置设为 -100
        batch["labels"] = torch.where(
            batch["labels"] == self.tokenizer.pad_token_id,
            -100,
            batch["labels"]
        )
        
        return batch


# 更精细的 mask:根据 message 来源
def mask_user_tokens(
    input_ids: torch.Tensor,
    tokenizer,
    assistant_token_id: int = 77091  # Qwen3 的 <|assistant|> token
) -> torch.Tensor:
    """
    将 assistant 之前的所有 token 都 mask 掉
    只保留 assistant 回复部分的 loss
    """
    labels = input_ids.clone()
    
    for i in range(input_ids.shape[0]):
        # 找到第一个 assistant token 的位置
        assistant_positions = (input_ids[i] == assistant_token_id).nonzero()
        
        if len(assistant_positions) > 0:
            # 找到最后一个 assistant token(多轮对话中每个助手回复都保留")
            last_assistant = assistant_positions[-1].item()
            # 之前的所有内容 mask 掉
            labels[i, :last_assistant] = -100
    
    return labels

4.3 辅助 Loss:提高训练稳定性

"""SFT 训练辅助损失"""

class SFTLossWithAux:
    """带辅助损失的全量微调"""
    
    def __init__(self, alpha_ce=1.0, alpha_kl=0.1, alpha_embed=0.01):
        self.alpha_ce = alpha_ce    # 主损失权重
        self.alpha_kl = alpha_kl    # KL散度(防止偏离原始模型)
        self.alpha_embed = alpha_embed  # Embedding 正则
    
    def compute_loss(
        self,
        model,
        student_logits: torch.Tensor,    # 微调模型输出
        teacher_logits: torch.Tensor,    # 原始模型输出(冻结)
        labels: torch.Tensor,            # 标签
    ) -> torch.Tensor:
        """主损失 + KL散度 + Embedding正则"""
        
        # 1. 主损失:Cross-Entropy
        loss_ce = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
            reduction='mean'
        )
        
        # 2. KL散度损失:防止模型偏离原始模型过远
        # 只在非 -100 的位置计算
        mask = (labels != -100).unsqueeze(-1).expand_as(student_logits)
        
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        
        kl_div = F.kl_div(
            student_log_probs[mask].view(-1, student_logits.size(-1)),
            teacher_probs[mask].view(-1, student_logits.size(-1)),
            reduction='batchmean',
            log_target=False
        )
        
        # 3. Embedding 正则(可选)
        embed_norm = sum(
            p.norm(2) for name, p in model.named_parameters()
            if 'embed' in name
        )
        
        total_loss = (
            self.alpha_ce * loss_ce +
            self.alpha_kl * kl_div +
            self.alpha_embed * embed_norm
        )
        
        return total_loss


五、超参数调优指南

5.1 核心超参数

参数推荐值范围说明
learning_rate1e-55e-6 ~ 5e-5全量微调比预训练低10-100倍
batch_size12864 ~ 512尽量大,不足靠梯度累积
num_epochs2-31 ~ 5多了过拟合,少了学不够
warmup_ratio0.030.01 ~ 0.1前3%步预热
weight_decay0.10.01 ~ 0.1防止过拟合
max_grad_norm1.00.5 ~ 2.0梯度裁剪
lr_schedulercosinecosine, linearcosine更平滑
optimAdamWAdamW, AdamW-fusedfused版本快30%

5.2 学习率策略详解

"""全量微调学习率调度策略"""
import math

def get_scheduler(
    optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.03,
    scheduler_type: str = "cosine"
):
    """SFT 学习率调度器"""
    
    warmup_steps = int(num_training_steps * warmup_ratio)
    
    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # 线性预热
            return float(current_step) / float(max(1, warmup_steps))
        
        progress = float(current_step - warmup_steps) / float(
            max(1, num_training_steps - warmup_steps)
        )
        
        if scheduler_type == "cosine":
            # 余弦退火
            return 0.5 * (1.0 + math.cos(math.pi * progress))
        elif scheduler_type == "linear":
            # 线性衰减
            return 1.0 - progress
        elif scheduler_type == "constant":
            # 恒定
            return 1.0
        else:
            return 1.0
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

不同策略效果对比

学习率
  ^
  |   [预热] [------训练过程------]
1e-5 |  /\
     | /  \______   cosine (推荐)
1e-5 | /        \__
     |
 0   +-------------------------------> steps
     |                                      |    
     |                                      |
1e-5 |  /\___________   constant (简单,但收敛不佳)
1e-5 | /            \__
     |
 0   +------------------------------->

5.3 Batch Size 选择

"""自动计算最优 batch size"""
import torch.cuda as cuda

def auto_batch_size(model_size_b: int = 7, target_memory_gb: float = 80):
    """
    根据显存自动计算最大 batch size
    """
    # 估算每 token 的显存需求
    param_memory = model_size_b * 2  # FP16: 2 bytes per param
    optimizer_memory = param_memory * 3  # Adam: 动量+方差+参数
    gradient_memory = param_memory
    
    # 激活内存(主要消耗)
    activation_memory_per_token = model_size_b * 0.5  # 估算
    
    # 剩余显存
    total_memory = target_memory_gb * 1024  # MB
    model_memory = (param_memory + optimizer_memory + gradient_memory) * 1024 / 8  # 转MB
    
    available = total_memory - model_memory
    
    # 每 token 激活约 X MB
    max_tokens = available / (activation_memory_per_token * 1024 / 8)
    
    # 假设 seq_len = 4096
    seq_len = 4096
    max_batch = max_tokens / seq_len
    
    return max(1, int(max_batch))

经验法则

  • 7B 模型 + 8×H200:batch_size=128(8×16 grad_accum)
  • 13B 模型 + 8×H200:batch_size=64
  • 70B 模型 + 16×H200:batch_size=32-64

六、全量微调实战:完整代码

6.1 完整训练脚本

"""Qwen3-7B 全量微调完整脚本"""
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    Qwen2ForCausalLM,  # Qwen3 也基于 Qwen2 架构
)
from trl import SFTTrainer  # 2026年更推荐用 TRL
import wandb

# ===== 0. 配置 =====
MODEL_NAME = "Qwen/Qwen3-7B"
DATA_PATH = "sft_data.jsonl"
OUTPUT_DIR = "qwen3-7b-sft"

# ===== 1. 加载模型和 tokenizer =====
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # Flash Attention 加速
    device_map="auto",                         # 自动分配到多GPU
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",
)

# 设置 pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ===== 2. 准备数据集 =====
dataset = load_dataset("json", data_files=DATA_PATH, split="train")

def apply_chat_template(examples):
    """应用对话模板"""
    texts = []
    for messages in examples["messages"]:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        texts.append(text)
    return {"text": texts}

dataset = dataset.map(
    apply_chat_template,
    batched=True,
    remove_columns=dataset.column_names,
)

def tokenize_fn(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=4096,
        padding=False,
        return_tensors=None,
    )

tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# 划分训练/验证
split_dataset = tokenized_dataset.train_test_split(test_size=0.01, seed=42)

# ===== 3. 训练参数 =====
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    
    # 核心训练参数
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,         # 等效 batch = 2×8×GPU = 128 (8GPU)
    num_train_epochs=3,
    learning_rate=1e-5,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    
    # 优化
    optim="adamw_torch_fused",            # fused AdamW,快30%
    bf16=True,                             # BF16 比 FP16 更稳定
    gradient_checkpointing=True,           # 减少显存,增加计算
    gradient_checkpointing_kwargs={"use_reentrant": False},
    
    # 分布式
    ddp_find_unused_parameters=False,
    ddp_bucket_cap_mb=100,
    
    # 保存与评估
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    evaluation_strategy="steps",
    eval_steps=200,
    logging_steps=10,
    report_to=["wandb"],
    
    # 其他
    max_grad_norm=1.0,
    weight_decay=0.1,
    seed=42,
    data_seed=42,
    dataloader_num_workers=4,
    remove_unused_columns=False,
)

# ===== 4. 数据整理器 =====
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,
    max_length=4096,
)

# ===== 5. 启动训练 =====
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=collator,
    tokenizer=tokenizer,
)

# 开始训练
trainer.train()

# ===== 6. 保存 =====
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"模型已保存到: {OUTPUT_DIR}")

6.2 使用 TRL SFTTrainer(2026推荐)

"""使用 TRL 的 SFTTrainer — 更简洁的方式"""
from trl import SFTTrainer

trainer = SFTTrainER(
    model=MODEL_NAME,
    train_dataset=dataset,
    args=TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=1e-5,
        bf16=True,
        logging_steps=10,
        save_steps=200,
        report_to=["wandb"],
    ),
    tokenizer=tokenizer,
    max_seq_length=4096,
    dataset_text_field="text",        # 直接从text字段训练
    packing=False,                     # 是否packing多个样本
)

trainer.train()

TRL SFTTrainer vs Trainer:TRL 自动处理了 loss masking、chat template 等 SFT 专属逻辑,2026 年 SFT 首选 TRL [3]。

6.3 分布式启动命令

# 8卡全量微调 Qwen3-7B
deepspeed --num_gpus=8 train_sft.py

# 16卡跨节点
deepspeed --num_gpus=8 --num_nodes=2 \
    --master_addr=master_node \
    train_sft.py

# DeepSpeed ZeRO-3 配置
cat << 'EOF' > ds_config.json
{
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true
    },
    "bf16": {"enabled": true},
    "gradient_accumulation_steps": 8,
    "gradient_clipping": 1.0
}
EOF

deepspeed --num_gpus=8 train_sft.py \
    --deepspeed ds_config.json


七、分阶段训练策略

7.1 三阶段 SFT

第一阶段:知识注入(Epoch 1)
  lr=5e-5,  unfreeze all
  目标:让模型学习新领域的知识
  风险:有可能灾难性遗忘

第二阶段:格式对齐(Epoch 2)  
  lr=1e-5,  unfreeze all  
  目标:让模型学会目标格式和风格
  注意:降低学习率以免破坏原始能力

第三阶段:精调(Epoch 3)
  lr=5e-6,  layer-wise decay  
  目标:精细调整,提升输出质量
  注意:只调整顶层

7.2 Layer-wise Learning Rate Decay

"""分层学习率:底层小、顶层大"""
def layer_wise_lr(model, base_lr=5e-5, decay_rate=0.9):
    """
    底层(embedding, early layers)→ 低学习率
    顶层(last layers, head)→ 高学习率
    """
    param_groups = []
    
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        
        # 计算层级
        if "embed" in name or "norm" in name:
            lr_scale = 0.3  # 底层
        elif "layers.0" in name or "layers.1" in name:
            lr_scale = 0.5
        elif "layers" in name:
            # 提取层号
            import re
            match = re.search(r'layers\.(\d+)', name)
            if match:
                layer_idx = int(match.group(1))
                total_layers = model.config.num_hidden_layers or 32
                lr_scale = decay_rate ** (total_layers - layer_idx)
            else:
                lr_scale = 1.0
        elif "lm_head" in name or "embed_tokens" in name:
            lr_scale = 2.0  # 输出层
        else:
            lr_scale = 1.0
        
        param_groups.append({
            "params": param,
            "lr": base_lr * lr_scale,
        })
    
    return torch.optim.AdamW(param_groups, lr=base_lr, betas=(0.9, 0.95))

7.3 Curriculum Learning(课程学习)

"""课程学习:从简单到困难"""
class CurriculumDataset(torch.utils.data.Dataset):
    """课程学习数据集"""
    
    def __init__(self, dataset, difficulty_fn):
        self.dataset = dataset
        self.difficulty_fn = difficulty_fn
        self.current_step = 0
        self.total_steps = 10000
        
        # 按难度排序
        self.sorted_indices = sorted(
            range(len(dataset)),
            key=lambda i: difficulty_fn(dataset[i])
        )
        
    def update_step(self, step: int):
        self.current_step = step
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        # 根据训练进度采样不同难度
        progress = self.current_step / self.total_steps
        
        if progress < 0.3:
            # 简单:前30% 的数据
            max_idx = int(len(self.dataset) * 0.3)
            real_idx = self.sorted_indices[
                np.random.randint(0, max_idx)
            ]
        elif progress < 0.7:
            # 中等:30%-70%
            start = int(len(self.dataset) * 0.3)
            end = int(len(self.dataset) * 0.7)
            real_idx = self.sorted_indices[
                np.random.randint(start, end)
            ]
        else:
            # 困难:最后30%
            start = int(len(self.dataset) * 0.7)
            real_idx = self.sorted_indices[
                np.random.randint(start, len(self.dataset))
            ]
        
        return self.dataset[real_idx]


八、灾难性遗忘与解决方案

8.1 遗忘的本质

全量微调的本质:
- 更新模型的所有参数
- 学习新数据 → 修改权重
- 修改权重 → 可能覆盖原始知识

灾难性遗忘:模型"忘记"了预训练中学会的知识
  → MMLU 得分下降 5-15%
  → 通用能力退化

8.2 六大防遗忘策略

策略实现效果额外成本
1. EWC对重要参数加正则2x 显存
2. Replay混合5-10%原始数据最好1x 训练时间
3. KL散度与原始模型做KL约束2x 显存
4. 分层训练冻结底层,只更新顶层中等
5. 渐进式LR底层LR小,顶层LR大中等
6. 混合数据10% 通用+90% 领域很好1x 数据
"""防遗忘:混合训练数据"""
def build_anti_forget_dataset(
    domain_data: List[Dict],      # 领域SFT数据
    general_data: List[Dict],     # 通用SFT数据  
    domain_ratio: float = 0.9     # 领域数据比例
) -> List[Dict]:
    """
    混合领域 + 通用数据
    领域:通用 = 90:10  ~ 既能学习领域,又不遗忘通用
    """
    n_domain = len(domain_data)
    n_general = int(n_domain * (1 - domain_ratio) / domain_ratio)
    
    if len(general_data) > n_general:
        general_data = random.sample(general_data, n_general)
    
    combined = domain_data + general_data
    random.shuffle(combined)
    
    print(f"领域数据: {len(domain_data)}")
    print(f"通用数据: {len(general_data)}")
    print(f"总计: {len(combined)}")
    
    return combined

8.3 遗忘检测

def evaluate_forgetting(
    sft_model_path: str,
    base_model_name: str = "Qwen/Qwen3-7B",
    eval_tasks: List[str] = None
):
    """评估 SFT 后的遗忘程度"""
    
    from lm_eval import simple_evaluate  # 2026 lm-eval
    
    if eval_tasks is None:
        eval_tasks = [
            "mmlu",              # 世界知识
            "gsm8k",             # 数学
            "humaneval",         # 代码
            "bbh",               # 推理
            "ifeval",            # 指令遵循
        ]
    
    results = {}
    for eval_type, model_path in [
        ("sft", sft_model_path),
        ("base", base_model_name),
    ]:
        scores = simple_evaluate(
            model=f"hf:{model_path}",
            tasks=eval_tasks,
            num_fewshot=0,
            batch_size=4,
        )
        results[eval_type] = scores["results"]
    
    # 对比
    for task in eval_tasks:
        base_score = results["base"].get(task, {}).get("acc,none", 0)
        sft_score = results["sft"].get(task, {}).get("acc,none", 0)
        diff = sft_score - base_score
        
        emoji = "🟢" if diff >= 0 else "🔴"
        print(f"{emoji} {task}: base={base_score:.1%}, sft={sft_score:.1%} ({diff:+.1%})")
    
    return results


九、评估与迭代

9.1 SFT 评估金字塔

Level 1: Loss 监控(训练中)
  ├── train loss: 持续下降 → 正常
  ├── eval loss: 先降后升 → 过拟合
  └── gradient norm: 异常peak → 数据问题

Level 2: 自动化评估(训练后)
  ├── MMLU, GSM8K, HumanEval → 通用能力
  ├── 领域特定benchmark → 领域能力
  └── IFEval → 指令遵循能力

Level 3: 人工评估(最关键)
  ├── 抽样 200-500 条
  ├── 对比 base vs sft 输出
  └── 统计:更好/持平/更差

9.2 快速评估

"""SFT 质量快速评估"""

class SFTEvaluator:
    """SFT 模型质量快速评估"""
    
    def __init__(self, sft_model, tokenizer, base_model=None):
        self.sft_model = sft_model
        self.base_model = base_model
        self.tokenizer = tokenizer
    
    def generate(self, prompt: str, **kwargs) -> str:
        """生成回复"""
        inputs = self.tokenizer(
            prompt, return_tensors="pt"
        ).to(self.sft_model.device)
        
        with torch.no_grad():
            outputs = self.sft_model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                **kwargs
            )
        
        return self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
    
    def compare_with_base(self, prompts: List[str]):
        """对比 SFT 前后"""
        print("=" * 80)
        for prompt in prompts:
            sft_resp = self.generate(prompt)
            
            if self.base_model:
                base_resp = self.generate_with_base(prompt)
                print(f"[Prompt]: {prompt[:50]}...")
                print(f"[Base]:   {base_resp[:100]}...")
                print(f"[SFT]:    {sft_resp[:100]}...")
                print("-" * 40)
    
    def evaluate_format(self, eval_samples: List[Dict]) -> Dict:
        """评估格式遵循能力"""
        stats = {"total": 0, "correct_format": 0}
        
        for sample in eval_samples:
            stats["total"] += 1
            prompt = sample["prompt"]
            expected_format = sample["expected_format"]
            
            response = self.generate(prompt)
            
            # 格式检查
            if expected_format == "json":
                import json
                try:
                    json.loads(response)
                    stats["correct_format"] += 1
                except:
                    pass
            elif expected_format == "list":
                if response.strip().startswith("-") or response.strip().startswith("1."):
                    stats["correct_format"] += 1
        
        stats["accuracy"] = stats["correct_format"] / max(stats["total"], 1)
        return stats

9.3 迭代策略

Round 1: 基础SFT
  - 5K 高质量数据
  - 评估:通用能力 + 领域能力
  - 结果:如果通用能力下降 → 加入replay

Round 2: 针对性增强
  - AI评估:哪里不好→生成补充数据
  - 加入 2K 问题领域的数据
  - 评估:目标领域的top-k accuracy

Round 3: 格式对齐
  - 500 条格式样本(JSON/代码/tables)
  - 降低学习率 5e-6
  - 评估:IFEval score


十、2026年微调新趋势

10.1 趋势一览

趋势20242026说明
主力框架Transformers TrainerTRL SFTTrainerTRL 接管SFT
损失函数CECE + KL + Aux多损失组合
训练策略统一LRLayer-wise Decay分层学习率
数据规模50K+15K 高质量DataFlow效应
防遗忘Replay + KL标配
自动化手动AI-in-the-loopLLM参与评估/增强
硬件效率BF16BF16 + FP8FP8 正在推广
评估Human OnlyAuto + HumanAI + 人工

10.2 FP8 训练(2026最新)

2026 年,NVIDIA H200/B200 开始支持 FP8 训练:

# FP8 训练配置(Transformers 4.50+)
training_args = TrainingArguments(
    fp8=True,                     # 新增!
    fp8_amax_history_length=1024,
    fp8_amax_compute_algo="most_recent",
)

FP8 的优势:显存减半,训练速度提升 1.5-2x
FP8 的风险:精度损失,需要 careful scaling

10.3 Mathematical Framework for SFT

class MathematicalSFT:
    """SFT 的数学框架"""
    
    @staticmethod
    def supervised_fine_tuning_loss(
        D: List[Tuple[str, str]],     # (x, y): instruction, response
        theta: Dict,                   # model parameters
        theta_0: Dict = None          # original parameters (for KL)
    ) -> torch.Tensor:
        """
        L(θ) = -E_{(x,y)~D}[log P_θ(y|x)] + λ·KL(P_θ||P_θ₀)
        
        First term: 最大化目标回复的概率 (MLE)
        Second term: 限制偏离原始模型的程度 (KL regularization)
        """
        # MLE term
        nll = compute_nll(D, theta)  # -log P(y|x)
        
        # KL term (optional)
        kl = 0
        if theta_0 is not None:
            kl = compute_kl(theta, theta_0)
        
        return nll + 0.1 * kl
    
    @staticmethod
    def information_bottleneck_view(
        data_ratio: float = 0.9,
        forgetting_tradeoff: float = 0.95
    ):
        """
        Information Bottleneck view of SFT:
        min I(Z; X_old) - β·I(Z; X_new)
        
        Z = fine-tuned representations
        X_old = pre-training data
        X_new = SFT data
        
        β > 1: 更关注新任务 → 可能遗忘
        β < 1: 保持原始能力 → 学习不足
        β ≈ 0.95: 推荐
        """
        pass


总结

核心要点

  1. 全量微调 vs LoRA:FFT 适合注入新知识/新领域,LoRA 适合风格/格式适配。2026 最佳实践:FFT + LoRA + DPO 三阶段
  2. 10K 高质量 > 50K 粗筛:DataFlow 验证,数据质量是第一优先级
  3. Loss masking 是核心:只计算 assistant 部分的 loss,user 部分 mask 掉
  4. 防遗忘是最大挑战:Replay + KL + Layer-wise LR 三管齐下
  5. TRL SFTTrainer 成为2026首选:自动处理 chat template、loss masking
  6. 评估三角形:Loss + AutoEval + Human Eval 三者缺一不可

最佳实践速查

需求推荐方案
微调 7BTRL SFTTrainer + DeepSpeed ZeRO-3
微调 70BMegatron TP=4 + DeepSpeed ZeRO-1
数据量不足合成数据 (DataFlow) + KL regularization
通用能力下降10% replay + 5e-6 LR
格式要求高第三阶段 500 条格式样本 + low LR
多语言按 language ratio 采样 + replay

面试高频问题

Q: SFT 为什么只计算 assistant 部分的 loss?
A: user 部分是"指令",模型不需要学习"如何提问"。只学习"如何回答"可以防止模型学到不相关的模式,并使梯度集中在真正需要优化的部分。

Q: 全量微调后模型变笨了(MMLU 下降),怎么办?
A: 1) 加入 5-10% replay data 2) 使用 KL divergence loss 约束模型 3) 降低学习率 4) 使用 Layer-wise LR decay(底层低、顶层高)

Q: SFT 数据中是否需要包含"拒绝回答"样本?
A: 强烈建议!包含 5-10% 的"拒绝回答"样本可以有效降低幻觉,让模型在面对不知道的问题时说"不知道"而不是胡编乱造。


下一篇预告:训练与微调篇第 4 篇——指令微调与 SFT 数据构建技巧。


参考资料

[1] Peking University DCAI. DataFlow: Systematic Data Engineering for LLMs. arXiv:2512.16676. 2026
[2] HuggingFace TRL. Transformer Reinforcement Learning. 2024-2026
[3] Qwen Team. Qwen3 Technical Report. 2026
[4] DataComp-LM. DCLM: A Framework for LLM Data-Centric AI. 2024
[5] Kirkpatrick et al. Overcoming catastrophic forgetting in neural networks. PNAS 2017

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值