手把手用PyTorch实现ResNeXt:比ResNet更强的分组卷积实战
如果你已经用PyTorch实现过ResNet,并且对它的残差连接和瓶颈结构(Bottleneck)了如指掌,那么当你第一次看到ResNeXt的论文时,可能会觉得它只是ResNet的一个“微调版”。但恰恰是这个看似微小的改动——将单一的3x3卷积替换为分组卷积(Group Convolution),并引入“基数(Cardinality)”这一新维度——让ResNeXt在ImageNet等基准测试上,以几乎相同的参数量和计算量,实现了显著的精度提升。今天,我们就抛开复杂的理论推导,直接从代码层面切入,用PyTorch一步步构建一个ResNeXt-32x4d模型,并深入剖析其核心设计“基数”是如何通过torch.nn.Conv2d的groups参数高效实现的。我们还会在CIFAR-10数据集上,将它与标准的ResNet-50进行训练对比,用实际数据说话,看看这个“更强的ResNet”究竟强在哪里。
1. 从ResNet到ResNeXt:理解“基数”这个新维度
在ResNet中,一个标准的Bottleneck残差块由三个卷积层构成:1x1卷积降维 -> 3x3卷积进行空间特征提取 -> 1x1卷积升维。整个网络性能的提升,主要依赖于堆叠更多的层(增加深度)或增加每层的通道数(增加宽度)。然而,ResNeXt的作者何恺明等人提出了一个新颖的观点:除了深度和宽度,变换路径的数量(Cardinality) 是提升模型表达能力的另一个同等重要、甚至更高效的维度。
提示:基数(Cardinality)在ResNeXt论文中被定义为“变换集的大小”。简单理解,它就是一个残差块内部,并行执行的、结构相同的变换路径的数量。
ResNeXt的核心思想借鉴了Inception模块的“拆分-变换-合并(Split-Transform-Merge)”策略,但做了关键性的简化:所有并行路径的拓扑结构完全相同。这带来了两个巨大优势:一是设计极其模块化,超参数极少;二是可以通过分组卷积这一标准操作高效实现,无需复杂的多分支管理。
为了直观感受ResNeXt的设计,我们对比一下ResNet-50和ResNeXt-50 (32x4d) 的一个Bottleneck块:
| 组件 | ResNet-50 Bottleneck | ResNeXt-50 (32x4d) Bottleneck |
|---|---|---|
| 路径1 (降维) | 1x1 Conv, 64-d | 1x1 Conv, 128-d |
| 路径2 (特征提取) | 3x3 Conv, 64-d | 3x3 Group Conv, 128-d, groups=32 |
| 路径3 (升维) | 1x1 Conv, 256-d | 1x1 Conv, 256-d |
| 关键变化 | 单一变换路径 | 32条并行变换路径,每条路径处理4维特征(128/32=4) |
可以看到,ResNeXt将中间那个3x3卷积的输入/输出通道数从64增加到了128,但同时将这个卷积操作分成了32组(groups=32)独立进行。这意味着,128个输入通道被均分到32个组里,每组4个通道,组与组之间的卷积计算是相互独立的。这32组卷积的结果最终会在通道维度上拼接(Concatenate)起来,形成最终的128维输出。这种设计在几乎不增加参数量的前提下,极大地丰富了特征的变换方式。
2. 核心实现:用PyTorch构建ResNeXt Bottleneck模块
理解了设计理念,实现起来就清晰了。我们将构建一个名为Bottleneck的类,它将是ResNeXt网络的基本构建块。与标准ResNet的Bottleneck相比,关键区别就在于中间那个3x3卷积层。
首先,我们导入必要的库,并定义一些基础配置。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 基础配置:我们以ResNeXt-50 (32x4d)为例
CARDINALITY = 32 # 基数,即分组数
BASE_WIDTH = 4 # 每个分组的宽度(通道数),即“4d”中的4
接下来是Bottleneck类的完整实现。我会在代码中加入详细注释,解释每一行的作用。
class Bottleneck(nn.Module):
"""
ResNeXt的瓶颈残差块。
遵循 `1x1 conv -> 3x3 group conv -> 1x1 conv` 的结构,并带有快捷连接。
"""
expansion = 4 # 最后一个1x1卷积的输出通道扩展倍数
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
"""
参数:
in_channels: 输入特征图的通道数
out_channels: 第一个1x1卷积的输出通道数(也是中间3x3卷积的输入/输出通道数)
stride: 中间3x3卷积的步幅,用于下采样
downsample: 一个可选的nn.Module,用于调整快捷连接的维度以匹配残差分支
"""
super(Bottleneck, self).__init__()
# 计算分组卷积的实际宽度
width = int(o


6947

被折叠的 条评论
为什么被折叠?



