Vision Transformer 从零实现:理解 ViT 的核心机制

Vision Transformer 从零实现:理解 ViT 的核心机制

1. 引言

Vision Transformer (ViT) 在 2020 年由 Google 提出,首次证明纯 Transformer 架构可以在图像分类任务上超越 CNN。ViT 的核心思想是将图像切分为固定大小的 patch,每个 patch 视为一个 “token”,然后用标准 Transformer Encoder 处理。

本文目标: 用 PyTorch 从零实现 ViT,并在 CIFAR-10 上训练验证。

2. ViT 架构总览

输入图像 (224×224×3)
    ↓
Patch Embedding (16×16 patches → 196 tokens × 768 维)
    ↓
[CLS] Token + Position Embedding
    ↓
Transformer Encoder × 12
    ├── LayerNorm → Multi-Head Self-Attention → Residual
    └── LayerNorm → FFN (MLP) → Residual
    ↓
MLP Head → 分类输出

3. 核心实现

3.1 Patch Embedding

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """将图像切分为 patches 并映射到嵌入空间"""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # 用卷积实现 patch 切分 + 线性投影(等价操作,但更高效)
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.projection(x)       # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)             # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)        # (B, num_patches, embed_dim)
        return x

3.2 Multi-Head Self-Attention

class MultiHeadAttention(nn.Module):
    """多头自注意力机制"""

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        # 生成 Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        # 缩放点积注意力
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # 加权聚合
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

3.3 Transformer Encoder Block

class TransformerBlock(nn.Module):
    """单个 Transformer Encoder 块"""

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # Pre-Norm + Residual
        x = x + self.mlp(self.norm2(x))
        return x

3.4 完整 ViT 模型

class VisionTransformer(nn.Module):
    """完整的 Vision Transformer"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=10,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()

        # Patch Embedding
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # CLS Token 和 Position Embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(dropout)

        # Transformer Encoder 堆叠
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # 初始化
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]

        # Patch Embedding
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # 拼接 CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, embed_dim)

        # 加入位置编码
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer Encoder
        for block in self.blocks:
            x = block(x)

        # 分类:取 CLS Token 的输出
        x = self.norm(x)
        cls_output = x[:, 0]
        logits = self.head(cls_output)
        return logits

4. CIFAR-10 训练

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 数据预处理(ViT 需要 224×224 输入)
transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)

# 创建模型(小版本 ViT,适合 CIFAR-10)
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    num_classes=10,
    embed_dim=384,    # 缩小嵌入维度
    depth=6,          # 减少层数
    num_heads=6,
    mlp_ratio=4.0,
    dropout=0.1,
)

# 训练配置
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 训练循环
for epoch in range(100):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    scheduler.step()
    acc = 100. * correct / total
    print(f"Epoch {epoch+1}/100 | Loss: {total_loss/len(trainloader):.4f} | Acc: {acc:.2f}%")

5. 注意力可视化

import matplotlib.pyplot as plt
import numpy as np

def visualize_attention(model, image_tensor, device):
    """可视化 ViT 的注意力图"""
    model.eval()
    hooks = []
    attn_weights = []

    # 注册 hook 提取注意力权重
    def hook_fn(module, input, output):
        # 重新计算注意力权重
        B, N, C = input[0].shape
        qkv = module.qkv(input[0]).reshape(B, N, 3, module.num_heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, _ = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * module.scale
        attn_weights.append(attn.softmax(dim=-1).detach().cpu())

    # 注册到最后一个注意力层
    last_attn = model.blocks[-1].attn
    hooks.append(last_attn.register_forward_hook(hook_fn))

    with torch.no_grad():
        model(image_tensor.unsqueeze(0).to(device))

    # 可视化 [CLS] token 对所有 patch 的注意力
    attn = attn_weights[0][0]  # (heads, N, N)
    attn = attn.mean(dim=0)     # 平均所有头
    cls_attn = attn[0, 1:]      # CLS 对各 patch 的注意力
    grid_size = int(cls_attn.shape[0] ** 0.5)
    attn_map = cls_attn.reshape(grid_size, grid_size).numpy()

    plt.imshow(attn_map, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title("CLS Token Attention Map")
    plt.savefig("attention_map.png", dpi=150)

    for h in hooks:
        h.remove()

6. ViT 变体对比

模型参数量嵌入维度深度头数ImageNet Top-1
ViT-Ti5.7M19212372.7%
ViT-S22M38412679.4%
ViT-B86M768121281.8%
ViT-L307M1024241685.2%
ViT-H632M1280321688.6%

7. ViT vs CNN 深度对比

特性CNNViT
归纳偏置局部性、平移不变性无(需大数据学习)
全局建模需要深层堆叠第一层就能全局交互
计算复杂度O(n·k²·c)O(n²·d)(n 为 token 数)
小数据表现优秀较差(需预训练)
大数据表现饱和持续提升
可解释性较弱注意力图可可视化

8. 总结

ViT 的核心创新在于将图像 patch 化后直接用 Transformer 处理,打破了 CNN 在视觉领域的垄断。关键理解:

  1. Patch Embedding = 切图 + 线性投影,等价于无重叠卷积
  2. CLS Token 是一个可学习的分类标记,聚合全局信息
  3. Position Embedding 对 ViT 至关重要(否则丧失空间信息)
  4. Pre-Norm(先 LayerNorm 再 Attention)比 Post-Norm 训练更稳定
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大鱼>

一分也是爱

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值