🎯 全量微调实践:从数据到部署的SFT完整指南
预训练给了模型"语言能力",微调给了模型"办事能力"。2026 年,全量微调已经从"全参数更新"演进为"分层 + 分阶段"的精细操作——日活千万的AI产品背后,都有一套精心设计的微调策略。
📑 目录
一、为什么需要全量微调?
1.1 全量微调 vs LoRA 的抉择
| 维度 | 全量微调 (FFT) | LoRA/QLoRA |
|---|---|---|
| 更新参数 | 全部(100%) | 少量(0.1-1%) |
| 知识注入深度 | 高 — 改变模型内部表示 | 低 — 学习"回答风格" |
| 学习新知识 | ✅ 可注入新事实 | ❌ 几乎不可能 |
| 领域适应能力 | ✅ 优秀 — 如法律、医疗 | ⚠️ 有限 |
| 训练显存 | 高(70B需~400GB) | 低(70B仅需~48GB) |
| 训练时间 | 周级 | 小时级 |
| 模型质量上限 | 更高 | 受限于秩® |
| 多任务鲁棒性 | ✅ 好 | ❌ 容易过拟合 |
决策规则:
你的需求是?
├── 学习新知识/新格式 → 全量微调(FFT)
├── 适配回答风格/格式 → LoRA(高效)
├── 领域深度适应(医疗/法律/金融) → FFT(全量)
├── 已有7B模型,希望更快迭代 → LoRA
└── 用做基座,后续做RLHF → FFT(全量)
1.2 2026年共识:混合策略最佳
最先进的微调方案不再是"全量"或"LoRA"二选一,而是分阶段混合:
Stage 1: 全量微调(新知识注入) → 8×H200, 3天
学习新领域数据、新格式、新任务
Stage 2: LoRA 适配(风格对齐) → 1×RTX 5090, 2小时
微调回答风格、格式偏好
Stage 3: DPO/GRPO(偏好对齐) → 8×H200, 1天
人类偏好对齐
二、SFT的三种模式
2.1 标准 SFT(Next Token Prediction)
最经典的方法:输入"问题+答案",让模型预测下一个 token。
用户: 什么是注意力机制?
助手: 注意力机制是...
Loss = CrossEntropy(predicted_token, ground_truth_token)
只计算"助手"部分的loss,不计算"用户"部分的loss
2.2 对话式 SFT(Multi-turn)
包含多轮对话,训练模型维持上下文:
用户: 什么是注意力机制?
助手: 注意力机制是...【计算loss】
用户: 那它和Transformer有什么关系?【不计算loss】
助手: Transformer的核心就是...【计算loss】
关键:只对助手回复计算 loss,用户输入部分 mask 掉。
2.3 多任务 SFT(Multi-task)
在同一个模型中训练多种任务类型:
任务1: [分类] 这封邮件是垃圾邮件吗?是【计算loss】
任务2: [摘要] 将以下文章总结为三句话...【计算loss】
任务3: [QA] 什么是...【计算loss】
任务4: [代码] 写一个快排...【计算loss】
任务5: [翻译] 将中文翻译成英文...【计算loss】
三、SFT数据构建方法论
3.1 高质量SFT数据的"黄金标准"
2026 年各实验室的 SFT 数据配方 [1][2]:
| 维度 | 标准 | 说明 |
|---|---|---|
| 多样性 | 50+ 任务类型 | 问答、写作、代码、翻译、摘要、分类… |
| 难度渐进 | 简单→中等→困难 | 从"帮我写一封邮件"到"推导黎曼猜想" |
| 格式统一 | 一致的对话模板 | ChatML、Qwen、LLaMA 三种主流格式 |
| 质量优先 | DataFlow 验证 | 10K 高质量 ≈ 50K 粗筛 |
| 答案长度 | 128-2048 tokens | 短答案学格式,长答案学推理 |
| 拒绝回答 | 5-10% | “我无法回答这个问题”——防幻觉 |
3.2 数据格式标准化
"""三种主流对话模板"""
from typing import List, Dict
class ChatTemplate:
"""对话模板格式化"""
@staticmethod
def format_chatml(messages: List[Dict]) -> str:
"""ChatML 格式(OpenAI/DeepSeek 使用)"""
formatted = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
return formatted
@staticmethod
def format_qwen(messages: List[Dict]) -> str:
"""Qwen 格式"""
formatted = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
formatted += f"<|system|>\n{content}\n"
elif role == "user":
formatted += f"<|user|>\n{content}\n"
elif role == "assistant":
formatted += f"<|assistant|>\n{content}\n"
formatted += "<|assistant|>\n"
return formatted
@staticmethod
def format_llama(messages: List[Dict]) -> str:
"""LLaMA 4 格式"""
formatted = "<|begin_of_text|>"
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
formatted += f"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>"
elif role == "user":
formatted += f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
elif role == "assistant":
formatted += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
return formatted
3.3 数据构建流程
"""高质量SFT数据构建 Pipeline"""
import json
import random
from typing import List, Dict, Generator
class SFTDataBuilder:
"""SFT训练数据构建器"""
def __init__(self, base_model: str = "Qwen3-72B"):
self.base_model = base_model
self.task_templates = self._load_task_templates()
def _load_task_templates(self) -> Dict:
"""加载任务模板"""
return {
"qa": {
"system": "你是一个知识渊博的助手,请准确回答用户问题。",
"count": 0.3, # 30%
"difficulty": ["简单", "中等", "困难"]
},
"writing": {
"system": "你是一个写作助手,请根据要求完成各类写作任务。",
"count": 0.2, # 20%
"difficulty": ["简单", "中等"]
},
"code": {
"system": "你是一个编程助手,请写出正确高效的代码。",
"count": 0.2, # 20%
"difficulty": ["中等", "困难"]
},
"reasoning": {
"system": "请一步步思考,给出详细的推理过程。",
"count": 0.15, # 15%
"difficulty": ["中等", "困难"]
},
"translation": {
"system": "你是一个翻译助手,请准确翻译以下内容。",
"count": 0.1, # 10%
"difficulty": ["简单", "中等"]
},
"safe_refusal": {
"system": "你是一个负责任的AI助手。",
"count": 0.05, # 5% - 拒绝回答
"difficulty": ["简单"]
}
}
def build_dataset(
self,
total_count: int = 50000,
output_path: str = "sft_data.jsonl"
):
"""构建完整数据集"""
with open(output_path, 'w', encoding='utf-8') as f:
for task, config in self.task_templates.items():
n = int(total_count * config["count"])
for _ in range(n):
sample = self._generate_sample(task, config)
f.write(json.dumps(sample, ensure_ascii=False) + '\n')
print(f"数据集已生成: {output_path}")
print(f"总样本数: {total_count}")
for task, config in self.task_templates.items():
n = int(total_count * config["count"])
print(f" {task}: {n} ({config['count']*100:.0f}%)")
def _generate_sample(self, task: str, config: Dict) -> Dict:
"""生成单个训练样本"""
# 从真实数据或合成数据中采样
# 实际中应该用高质量的标注数据
return {
"messages": [
{"role": "system", "content": config["system"]},
{"role": "user", "content": f"这是一个{task}类型的{random.choice(config['difficulty'])}问题"},
{"role": "assistant", "content": "这是一个高质量的回复。"}
],
"task": task,
"difficulty": random.choice(config["difficulty"])
}
# ===== 数据质量检查 =====
class SFTDataValidator:
"""SFT数据质量检查器"""
@staticmethod
def validate_sample(sample: Dict) -> List[str]:
"""检查单个样本的质量"""
issues = []
messages = sample.get("messages", [])
# 检查对话结构
if len(messages) < 2:
issues.append("对话轮次不足")
if messages[-1]["role"] != "assistant":
issues.append("最后一条必须是assistant回复")
# 检查内容质量
for msg in messages:
content = msg.get("content", "")
if len(content) < 5:
issues.append(f"{msg['role']}回复过短: {content}")
# 检查重复
if len(set(content.split())) / max(len(content.split()), 1) < 0.3:
issues.append(f"{msg['role']}回复重复过多")
# 检查无用字符
if content.count('\u0000') > 0:
issues.append("包含空字符")
return issues
@staticmethod
def dataset_stats(dataset: List[Dict]) -> Dict:
"""统计数据集"""
stats = {
"total": len(dataset),
"avg_user_len": 0,
"avg_assistant_len": 0,
"task_distribution": {},
"issues": []
}
user_lens = []
assistant_lens = []
for sample in dataset:
for msg in sample["messages"]:
if msg["role"] == "user":
user_lens.append(len(msg["content"]))
elif msg["role"] == "assistant":
assistant_lens.append(len(msg["content"]))
# 任务分布
task = sample.get("task", "unknown")
stats["task_distribution"][task] = stats["task_distribution"].get(task, 0) + 1
# 质量检查
issues = SFTDataValidator.validate_sample(sample)
if issues:
stats["issues"].extend(issues)
stats["avg_user_len"] = sum(user_lens) / max(len(user_lens), 1)
stats["avg_assistant_len"] = sum(assistant_lens) / max(len(assistant_lens), 1)
return stats
3.4 数据质量 vs 模型性能
DataFlow 实验的关键发现 [1]:
| 数据集 | 规模 | 模型 | 数学得分 | 代码得分 | 备注 |
|---|---|---|---|---|---|
| 原始 Alpaca | 52K | 7B | 基线 | 基线 | 原始 |
| DataFlow SFT | 15K | 7B | +9.3 | +5.1 | 15K > 52K |
| DataFlow 数学 | 10K | 32B | 55.7 | — | 超过 Open-R1 |
| DataFlow 代码 | 10K | 7B | — | 46.2 | 10K 代码数据 |
| DataFlow Chat | 15K | 7B | — | — | AlpacaEval 7.05→10.11 |
核心结论:15K 精心构造的 DataFlow 数据 > 52K 原始 Alpaca 数据。数据质量比数量重要 3 倍以上。
四、损失函数与训练策略
4.1 标准 Cross-Entropy Loss
"""SFT 损失函数详解"""
import torch
import torch.nn.functional as F
def sft_loss(
logits: torch.Tensor, # (batch, seq_len, vocab_size)
labels: torch.Tensor, # (batch, seq_len) -100 表示忽略
attention_mask: torch.Tensor = None
) -> torch.Tensor:
"""
SFT 交叉熵损失
- 只计算 assistant 部分(labels 中 != -100 的位置)
- 忽略 user 部分和 padding 部分
"""
batch_size, seq_len, vocab_size = logits.shape
# 移除最后一个 logit(不需要预测最后一个 token 之后的内容)
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
# Flatten
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# 计算 loss,忽略 -100 的位置
loss = F.cross_entropy(
shift_logits,
shift_labels,
ignore_index=-100, # 关键!忽略 user 输入和 padding
reduction='mean'
)
return loss
4.2 Loss Masking 详解
![Loss Masking 示意图]
输入: "<|user|>\n什么是AI?<|end|>\n<|assistant|>\n人工智能是..."
labels: [-100, -100, -100, ..., -100, | "人", "工", "智", "能", ...]
↑ user部分全部ignore ↑ 助手部分正常计算loss
为什么这么做?
- user输入是"指令",模型不应该从"提问方式"中学习
- 只从"助手的回答"中学习
- 防止模型学会"如何提问",只学会"如何回答"
"""构建正确的 labels(mask user 部分)"""
from transformers import DataCollatorForSeq2Seq
class SFTDataCollator(DataCollatorForSeq2Seq):
"""SFT专用数据整理器:自动 mask user 输入"""
def __call__(self, features):
batch = super().__call__(features)
# labels 中 tokenizer.pad_token_id 的位置设为 -100
batch["labels"] = torch.where(
batch["labels"] == self.tokenizer.pad_token_id,
-100,
batch["labels"]
)
return batch
# 更精细的 mask:根据 message 来源
def mask_user_tokens(
input_ids: torch.Tensor,
tokenizer,
assistant_token_id: int = 77091 # Qwen3 的 <|assistant|> token
) -> torch.Tensor:
"""
将 assistant 之前的所有 token 都 mask 掉
只保留 assistant 回复部分的 loss
"""
labels = input_ids.clone()
for i in range(input_ids.shape[0]):
# 找到第一个 assistant token 的位置
assistant_positions = (input_ids[i] == assistant_token_id).nonzero()
if len(assistant_positions) > 0:
# 找到最后一个 assistant token(多轮对话中每个助手回复都保留")
last_assistant = assistant_positions[-1].item()
# 之前的所有内容 mask 掉
labels[i, :last_assistant] = -100
return labels
4.3 辅助 Loss:提高训练稳定性
"""SFT 训练辅助损失"""
class SFTLossWithAux:
"""带辅助损失的全量微调"""
def __init__(self, alpha_ce=1.0, alpha_kl=0.1, alpha_embed=0.01):
self.alpha_ce = alpha_ce # 主损失权重
self.alpha_kl = alpha_kl # KL散度(防止偏离原始模型)
self.alpha_embed = alpha_embed # Embedding 正则
def compute_loss(
self,
model,
student_logits: torch.Tensor, # 微调模型输出
teacher_logits: torch.Tensor, # 原始模型输出(冻结)
labels: torch.Tensor, # 标签
) -> torch.Tensor:
"""主损失 + KL散度 + Embedding正则"""
# 1. 主损失:Cross-Entropy
loss_ce = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
labels.view(-1),
ignore_index=-100,
reduction='mean'
)
# 2. KL散度损失:防止模型偏离原始模型过远
# 只在非 -100 的位置计算
mask = (labels != -100).unsqueeze(-1).expand_as(student_logits)
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_probs = F.softmax(teacher_logits, dim=-1)
kl_div = F.kl_div(
student_log_probs[mask].view(-1, student_logits.size(-1)),
teacher_probs[mask].view(-1, student_logits.size(-1)),
reduction='batchmean',
log_target=False
)
# 3. Embedding 正则(可选)
embed_norm = sum(
p.norm(2) for name, p in model.named_parameters()
if 'embed' in name
)
total_loss = (
self.alpha_ce * loss_ce +
self.alpha_kl * kl_div +
self.alpha_embed * embed_norm
)
return total_loss
五、超参数调优指南
5.1 核心超参数
| 参数 | 推荐值 | 范围 | 说明 |
|---|---|---|---|
| learning_rate | 1e-5 | 5e-6 ~ 5e-5 | 全量微调比预训练低10-100倍 |
| batch_size | 128 | 64 ~ 512 | 尽量大,不足靠梯度累积 |
| num_epochs | 2-3 | 1 ~ 5 | 多了过拟合,少了学不够 |
| warmup_ratio | 0.03 | 0.01 ~ 0.1 | 前3%步预热 |
| weight_decay | 0.1 | 0.01 ~ 0.1 | 防止过拟合 |
| max_grad_norm | 1.0 | 0.5 ~ 2.0 | 梯度裁剪 |
| lr_scheduler | cosine | cosine, linear | cosine更平滑 |
| optim | AdamW | AdamW, AdamW-fused | fused版本快30% |
5.2 学习率策略详解
"""全量微调学习率调度策略"""
import math
def get_scheduler(
optimizer,
num_training_steps: int,
warmup_ratio: float = 0.03,
scheduler_type: str = "cosine"
):
"""SFT 学习率调度器"""
warmup_steps = int(num_training_steps * warmup_ratio)
def lr_lambda(current_step: int) -> float:
if current_step < warmup_steps:
# 线性预热
return float(current_step) / float(max(1, warmup_steps))
progress = float(current_step - warmup_steps) / float(
max(1, num_training_steps - warmup_steps)
)
if scheduler_type == "cosine":
# 余弦退火
return 0.5 * (1.0 + math.cos(math.pi * progress))
elif scheduler_type == "linear":
# 线性衰减
return 1.0 - progress
elif scheduler_type == "constant":
# 恒定
return 1.0
else:
return 1.0
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
不同策略效果对比:
学习率
^
| [预热] [------训练过程------]
1e-5 | /\
| / \______ cosine (推荐)
1e-5 | / \__
|
0 +-------------------------------> steps
| |
| |
1e-5 | /\___________ constant (简单,但收敛不佳)
1e-5 | / \__
|
0 +------------------------------->
5.3 Batch Size 选择
"""自动计算最优 batch size"""
import torch.cuda as cuda
def auto_batch_size(model_size_b: int = 7, target_memory_gb: float = 80):
"""
根据显存自动计算最大 batch size
"""
# 估算每 token 的显存需求
param_memory = model_size_b * 2 # FP16: 2 bytes per param
optimizer_memory = param_memory * 3 # Adam: 动量+方差+参数
gradient_memory = param_memory
# 激活内存(主要消耗)
activation_memory_per_token = model_size_b * 0.5 # 估算
# 剩余显存
total_memory = target_memory_gb * 1024 # MB
model_memory = (param_memory + optimizer_memory + gradient_memory) * 1024 / 8 # 转MB
available = total_memory - model_memory
# 每 token 激活约 X MB
max_tokens = available / (activation_memory_per_token * 1024 / 8)
# 假设 seq_len = 4096
seq_len = 4096
max_batch = max_tokens / seq_len
return max(1, int(max_batch))
经验法则:
- 7B 模型 + 8×H200:batch_size=128(8×16 grad_accum)
- 13B 模型 + 8×H200:batch_size=64
- 70B 模型 + 16×H200:batch_size=32-64
六、全量微调实战:完整代码
6.1 完整训练脚本
"""Qwen3-7B 全量微调完整脚本"""
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
Qwen2ForCausalLM, # Qwen3 也基于 Qwen2 架构
)
from trl import SFTTrainer # 2026年更推荐用 TRL
import wandb
# ===== 0. 配置 =====
MODEL_NAME = "Qwen/Qwen3-7B"
DATA_PATH = "sft_data.jsonl"
OUTPUT_DIR = "qwen3-7b-sft"
# ===== 1. 加载模型和 tokenizer =====
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
attn_implementation="flash_attention_2", # Flash Attention 加速
device_map="auto", # 自动分配到多GPU
)
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
padding_side="right",
)
# 设置 pad_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# ===== 2. 准备数据集 =====
dataset = load_dataset("json", data_files=DATA_PATH, split="train")
def apply_chat_template(examples):
"""应用对话模板"""
texts = []
for messages in examples["messages"]:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
texts.append(text)
return {"text": texts}
dataset = dataset.map(
apply_chat_template,
batched=True,
remove_columns=dataset.column_names,
)
def tokenize_fn(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=4096,
padding=False,
return_tensors=None,
)
tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
# 划分训练/验证
split_dataset = tokenized_dataset.train_test_split(test_size=0.01, seed=42)
# ===== 3. 训练参数 =====
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
# 核心训练参数
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=8, # 等效 batch = 2×8×GPU = 128 (8GPU)
num_train_epochs=3,
learning_rate=1e-5,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
# 优化
optim="adamw_torch_fused", # fused AdamW,快30%
bf16=True, # BF16 比 FP16 更稳定
gradient_checkpointing=True, # 减少显存,增加计算
gradient_checkpointing_kwargs={"use_reentrant": False},
# 分布式
ddp_find_unused_parameters=False,
ddp_bucket_cap_mb=100,
# 保存与评估
save_strategy="steps",
save_steps=200,
save_total_limit=3,
evaluation_strategy="steps",
eval_steps=200,
logging_steps=10,
report_to=["wandb"],
# 其他
max_grad_norm=1.0,
weight_decay=0.1,
seed=42,
data_seed=42,
dataloader_num_workers=4,
remove_unused_columns=False,
)
# ===== 4. 数据整理器 =====
collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding=True,
max_length=4096,
)
# ===== 5. 启动训练 =====
trainer = Trainer(
model=model,
args=training_args,
train_dataset=split_dataset["train"],
eval_dataset=split_dataset["test"],
data_collator=collator,
tokenizer=tokenizer,
)
# 开始训练
trainer.train()
# ===== 6. 保存 =====
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"模型已保存到: {OUTPUT_DIR}")
6.2 使用 TRL SFTTrainer(2026推荐)
"""使用 TRL 的 SFTTrainer — 更简洁的方式"""
from trl import SFTTrainer
trainer = SFTTrainER(
model=MODEL_NAME,
train_dataset=dataset,
args=TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
num_train_epochs=3,
learning_rate=1e-5,
bf16=True,
logging_steps=10,
save_steps=200,
report_to=["wandb"],
),
tokenizer=tokenizer,
max_seq_length=4096,
dataset_text_field="text", # 直接从text字段训练
packing=False, # 是否packing多个样本
)
trainer.train()
TRL SFTTrainer vs Trainer:TRL 自动处理了 loss masking、chat template 等 SFT 专属逻辑,2026 年 SFT 首选 TRL [3]。
6.3 分布式启动命令
# 8卡全量微调 Qwen3-7B
deepspeed --num_gpus=8 train_sft.py
# 16卡跨节点
deepspeed --num_gpus=8 --num_nodes=2 \
--master_addr=master_node \
train_sft.py
# DeepSpeed ZeRO-3 配置
cat << 'EOF' > ds_config.json
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true
},
"bf16": {"enabled": true},
"gradient_accumulation_steps": 8,
"gradient_clipping": 1.0
}
EOF
deepspeed --num_gpus=8 train_sft.py \
--deepspeed ds_config.json
七、分阶段训练策略
7.1 三阶段 SFT
第一阶段:知识注入(Epoch 1)
lr=5e-5, unfreeze all
目标:让模型学习新领域的知识
风险:有可能灾难性遗忘
第二阶段:格式对齐(Epoch 2)
lr=1e-5, unfreeze all
目标:让模型学会目标格式和风格
注意:降低学习率以免破坏原始能力
第三阶段:精调(Epoch 3)
lr=5e-6, layer-wise decay
目标:精细调整,提升输出质量
注意:只调整顶层
7.2 Layer-wise Learning Rate Decay
"""分层学习率:底层小、顶层大"""
def layer_wise_lr(model, base_lr=5e-5, decay_rate=0.9):
"""
底层(embedding, early layers)→ 低学习率
顶层(last layers, head)→ 高学习率
"""
param_groups = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# 计算层级
if "embed" in name or "norm" in name:
lr_scale = 0.3 # 底层
elif "layers.0" in name or "layers.1" in name:
lr_scale = 0.5
elif "layers" in name:
# 提取层号
import re
match = re.search(r'layers\.(\d+)', name)
if match:
layer_idx = int(match.group(1))
total_layers = model.config.num_hidden_layers or 32
lr_scale = decay_rate ** (total_layers - layer_idx)
else:
lr_scale = 1.0
elif "lm_head" in name or "embed_tokens" in name:
lr_scale = 2.0 # 输出层
else:
lr_scale = 1.0
param_groups.append({
"params": param,
"lr": base_lr * lr_scale,
})
return torch.optim.AdamW(param_groups, lr=base_lr, betas=(0.9, 0.95))
7.3 Curriculum Learning(课程学习)
"""课程学习:从简单到困难"""
class CurriculumDataset(torch.utils.data.Dataset):
"""课程学习数据集"""
def __init__(self, dataset, difficulty_fn):
self.dataset = dataset
self.difficulty_fn = difficulty_fn
self.current_step = 0
self.total_steps = 10000
# 按难度排序
self.sorted_indices = sorted(
range(len(dataset)),
key=lambda i: difficulty_fn(dataset[i])
)
def update_step(self, step: int):
self.current_step = step
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
# 根据训练进度采样不同难度
progress = self.current_step / self.total_steps
if progress < 0.3:
# 简单:前30% 的数据
max_idx = int(len(self.dataset) * 0.3)
real_idx = self.sorted_indices[
np.random.randint(0, max_idx)
]
elif progress < 0.7:
# 中等:30%-70%
start = int(len(self.dataset) * 0.3)
end = int(len(self.dataset) * 0.7)
real_idx = self.sorted_indices[
np.random.randint(start, end)
]
else:
# 困难:最后30%
start = int(len(self.dataset) * 0.7)
real_idx = self.sorted_indices[
np.random.randint(start, len(self.dataset))
]
return self.dataset[real_idx]
八、灾难性遗忘与解决方案
8.1 遗忘的本质
全量微调的本质:
- 更新模型的所有参数
- 学习新数据 → 修改权重
- 修改权重 → 可能覆盖原始知识
灾难性遗忘:模型"忘记"了预训练中学会的知识
→ MMLU 得分下降 5-15%
→ 通用能力退化
8.2 六大防遗忘策略
| 策略 | 实现 | 效果 | 额外成本 |
|---|---|---|---|
| 1. EWC | 对重要参数加正则 | 好 | 2x 显存 |
| 2. Replay | 混合5-10%原始数据 | 最好 | 1x 训练时间 |
| 3. KL散度 | 与原始模型做KL约束 | 好 | 2x 显存 |
| 4. 分层训练 | 冻结底层,只更新顶层 | 中等 | 无 |
| 5. 渐进式LR | 底层LR小,顶层LR大 | 中等 | 无 |
| 6. 混合数据 | 10% 通用+90% 领域 | 很好 | 1x 数据 |
"""防遗忘:混合训练数据"""
def build_anti_forget_dataset(
domain_data: List[Dict], # 领域SFT数据
general_data: List[Dict], # 通用SFT数据
domain_ratio: float = 0.9 # 领域数据比例
) -> List[Dict]:
"""
混合领域 + 通用数据
领域:通用 = 90:10 ~ 既能学习领域,又不遗忘通用
"""
n_domain = len(domain_data)
n_general = int(n_domain * (1 - domain_ratio) / domain_ratio)
if len(general_data) > n_general:
general_data = random.sample(general_data, n_general)
combined = domain_data + general_data
random.shuffle(combined)
print(f"领域数据: {len(domain_data)}")
print(f"通用数据: {len(general_data)}")
print(f"总计: {len(combined)}")
return combined
8.3 遗忘检测
def evaluate_forgetting(
sft_model_path: str,
base_model_name: str = "Qwen/Qwen3-7B",
eval_tasks: List[str] = None
):
"""评估 SFT 后的遗忘程度"""
from lm_eval import simple_evaluate # 2026 lm-eval
if eval_tasks is None:
eval_tasks = [
"mmlu", # 世界知识
"gsm8k", # 数学
"humaneval", # 代码
"bbh", # 推理
"ifeval", # 指令遵循
]
results = {}
for eval_type, model_path in [
("sft", sft_model_path),
("base", base_model_name),
]:
scores = simple_evaluate(
model=f"hf:{model_path}",
tasks=eval_tasks,
num_fewshot=0,
batch_size=4,
)
results[eval_type] = scores["results"]
# 对比
for task in eval_tasks:
base_score = results["base"].get(task, {}).get("acc,none", 0)
sft_score = results["sft"].get(task, {}).get("acc,none", 0)
diff = sft_score - base_score
emoji = "🟢" if diff >= 0 else "🔴"
print(f"{emoji} {task}: base={base_score:.1%}, sft={sft_score:.1%} ({diff:+.1%})")
return results
九、评估与迭代
9.1 SFT 评估金字塔
Level 1: Loss 监控(训练中)
├── train loss: 持续下降 → 正常
├── eval loss: 先降后升 → 过拟合
└── gradient norm: 异常peak → 数据问题
Level 2: 自动化评估(训练后)
├── MMLU, GSM8K, HumanEval → 通用能力
├── 领域特定benchmark → 领域能力
└── IFEval → 指令遵循能力
Level 3: 人工评估(最关键)
├── 抽样 200-500 条
├── 对比 base vs sft 输出
└── 统计:更好/持平/更差
9.2 快速评估
"""SFT 质量快速评估"""
class SFTEvaluator:
"""SFT 模型质量快速评估"""
def __init__(self, sft_model, tokenizer, base_model=None):
self.sft_model = sft_model
self.base_model = base_model
self.tokenizer = tokenizer
def generate(self, prompt: str, **kwargs) -> str:
"""生成回复"""
inputs = self.tokenizer(
prompt, return_tensors="pt"
).to(self.sft_model.device)
with torch.no_grad():
outputs = self.sft_model.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
top_p=0.9,
do_sample=True,
**kwargs
)
return self.tokenizer.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
def compare_with_base(self, prompts: List[str]):
"""对比 SFT 前后"""
print("=" * 80)
for prompt in prompts:
sft_resp = self.generate(prompt)
if self.base_model:
base_resp = self.generate_with_base(prompt)
print(f"[Prompt]: {prompt[:50]}...")
print(f"[Base]: {base_resp[:100]}...")
print(f"[SFT]: {sft_resp[:100]}...")
print("-" * 40)
def evaluate_format(self, eval_samples: List[Dict]) -> Dict:
"""评估格式遵循能力"""
stats = {"total": 0, "correct_format": 0}
for sample in eval_samples:
stats["total"] += 1
prompt = sample["prompt"]
expected_format = sample["expected_format"]
response = self.generate(prompt)
# 格式检查
if expected_format == "json":
import json
try:
json.loads(response)
stats["correct_format"] += 1
except:
pass
elif expected_format == "list":
if response.strip().startswith("-") or response.strip().startswith("1."):
stats["correct_format"] += 1
stats["accuracy"] = stats["correct_format"] / max(stats["total"], 1)
return stats
9.3 迭代策略
Round 1: 基础SFT
- 5K 高质量数据
- 评估:通用能力 + 领域能力
- 结果:如果通用能力下降 → 加入replay
Round 2: 针对性增强
- AI评估:哪里不好→生成补充数据
- 加入 2K 问题领域的数据
- 评估:目标领域的top-k accuracy
Round 3: 格式对齐
- 500 条格式样本(JSON/代码/tables)
- 降低学习率 5e-6
- 评估:IFEval score
十、2026年微调新趋势
10.1 趋势一览
| 趋势 | 2024 | 2026 | 说明 |
|---|---|---|---|
| 主力框架 | Transformers Trainer | TRL SFTTrainer | TRL 接管SFT |
| 损失函数 | CE | CE + KL + Aux | 多损失组合 |
| 训练策略 | 统一LR | Layer-wise Decay | 分层学习率 |
| 数据规模 | 50K+ | 15K 高质量 | DataFlow效应 |
| 防遗忘 | 无 | Replay + KL | 标配 |
| 自动化 | 手动 | AI-in-the-loop | LLM参与评估/增强 |
| 硬件效率 | BF16 | BF16 + FP8 | FP8 正在推广 |
| 评估 | Human Only | Auto + Human | AI + 人工 |
10.2 FP8 训练(2026最新)
2026 年,NVIDIA H200/B200 开始支持 FP8 训练:
# FP8 训练配置(Transformers 4.50+)
training_args = TrainingArguments(
fp8=True, # 新增!
fp8_amax_history_length=1024,
fp8_amax_compute_algo="most_recent",
)
FP8 的优势:显存减半,训练速度提升 1.5-2x
FP8 的风险:精度损失,需要 careful scaling
10.3 Mathematical Framework for SFT
class MathematicalSFT:
"""SFT 的数学框架"""
@staticmethod
def supervised_fine_tuning_loss(
D: List[Tuple[str, str]], # (x, y): instruction, response
theta: Dict, # model parameters
theta_0: Dict = None # original parameters (for KL)
) -> torch.Tensor:
"""
L(θ) = -E_{(x,y)~D}[log P_θ(y|x)] + λ·KL(P_θ||P_θ₀)
First term: 最大化目标回复的概率 (MLE)
Second term: 限制偏离原始模型的程度 (KL regularization)
"""
# MLE term
nll = compute_nll(D, theta) # -log P(y|x)
# KL term (optional)
kl = 0
if theta_0 is not None:
kl = compute_kl(theta, theta_0)
return nll + 0.1 * kl
@staticmethod
def information_bottleneck_view(
data_ratio: float = 0.9,
forgetting_tradeoff: float = 0.95
):
"""
Information Bottleneck view of SFT:
min I(Z; X_old) - β·I(Z; X_new)
Z = fine-tuned representations
X_old = pre-training data
X_new = SFT data
β > 1: 更关注新任务 → 可能遗忘
β < 1: 保持原始能力 → 学习不足
β ≈ 0.95: 推荐
"""
pass
总结
核心要点
- 全量微调 vs LoRA:FFT 适合注入新知识/新领域,LoRA 适合风格/格式适配。2026 最佳实践:FFT + LoRA + DPO 三阶段
- 10K 高质量 > 50K 粗筛:DataFlow 验证,数据质量是第一优先级
- Loss masking 是核心:只计算 assistant 部分的 loss,user 部分 mask 掉
- 防遗忘是最大挑战:Replay + KL + Layer-wise LR 三管齐下
- TRL SFTTrainer 成为2026首选:自动处理 chat template、loss masking
- 评估三角形:Loss + AutoEval + Human Eval 三者缺一不可
最佳实践速查
| 需求 | 推荐方案 |
|---|---|
| 微调 7B | TRL SFTTrainer + DeepSpeed ZeRO-3 |
| 微调 70B | Megatron TP=4 + DeepSpeed ZeRO-1 |
| 数据量不足 | 合成数据 (DataFlow) + KL regularization |
| 通用能力下降 | 10% replay + 5e-6 LR |
| 格式要求高 | 第三阶段 500 条格式样本 + low LR |
| 多语言 | 按 language ratio 采样 + replay |
面试高频问题
Q: SFT 为什么只计算 assistant 部分的 loss?
A: user 部分是"指令",模型不需要学习"如何提问"。只学习"如何回答"可以防止模型学到不相关的模式,并使梯度集中在真正需要优化的部分。
Q: 全量微调后模型变笨了(MMLU 下降),怎么办?
A: 1) 加入 5-10% replay data 2) 使用 KL divergence loss 约束模型 3) 降低学习率 4) 使用 Layer-wise LR decay(底层低、顶层高)
Q: SFT 数据中是否需要包含"拒绝回答"样本?
A: 强烈建议!包含 5-10% 的"拒绝回答"样本可以有效降低幻觉,让模型在面对不知道的问题时说"不知道"而不是胡编乱造。
下一篇预告:训练与微调篇第 4 篇——指令微调与 SFT 数据构建技巧。
参考资料
[1] Peking University DCAI. DataFlow: Systematic Data Engineering for LLMs. arXiv:2512.16676. 2026
[2] HuggingFace TRL. Transformer Reinforcement Learning. 2024-2026
[3] Qwen Team. Qwen3 Technical Report. 2026
[4] DataComp-LM. DCLM: A Framework for LLM Data-Centric AI. 2024
[5] Kirkpatrick et al. Overcoming catastrophic forgetting in neural networks. PNAS 2017

1823

被折叠的 条评论
为什么被折叠?



