libtorch学习笔记(7)- VGG网络训练和测试

本文详细介绍VGG网络模型的构建、训练与测试过程。VGG网络是一种深度卷积神经网络,能有效提取图像特征,适用于大规模图像识别任务。文中通过实例展示了如何使用VGG16进行猫狗分类,包括数据集加载、图像预处理、网络训练及验证。

VGG网络训练和测试

简单介绍

VGG是卷积网络里面比较常见的网络模型,相比LeNet要复杂一些,但是都属于拓补结构简单直接的前置反馈网络,详细信息可参考论文VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION,VGG网络能够提取更多的图像特征,最后输出的特诊向量信息量更丰富,所以可以进行更大规模的分类,前面介绍的LeNet5可以产生10个分类,分别对应0~9, VGG可以产生上万个分类,识别更多的类型。VGG也是Faster RCNN的基础,Faster RCNN在现实当中实用性更强,能在任意图像内进行目标定位,然后再进行目标识别。
下图是从论文中截取的一张网络配置图,并加上代码中对应的层:
在这里插入图片描述
这张表后面结合代码再做详细描述,与前面笔记中提到的LeNet5相比:

Network 网络层数 权重层数 参数个数
LeNet5 7 5 61706
VGG16(D) 39 16 138357544
可想而知VGG要比LeNet5复杂很多,运算量也大很多,训练时间更长,训练的网络状态所占空间也越大。
在我的机器上(MacBook Pro 2017), 用CPU训练,60000张MNIST训练图片(1x28x28)2轮学习花了10分钟左右,10000张测试图片花了10秒,但是8000张左右猫狗训练集(3x可变长宽)2轮学习花了6.7个小时, 2000张测试图片识别花了11分钟左右。GPU可能快很多,目前没试过。
从上表中也能看出一般网络模型命名规律:网络模型名 + 权重层数,所以有LeNet-5, VGG-11, VGG-16和VGG-19这些名称。

网络构建

根据上述论文,选择ConvNet Configuration D,也称作VGG16,基于c++ libtorch库用如下代码创建了它,在上图中也标出了每层对应的module名称,这些网络层的命令是,模型名称缩写+所在第几层,如C29,就是卷积层(Convolutional network, C)在本网络中位于第29层, FC38就是全连接层(FullConnection, FC)在此网络中位于第38层。
另外有些网络层就是做一个简单操作,比如RELU, MaxPool等,就不注册网络层,具体就在forward中当作function来in-place处理。

VGGNet::VGGNet(int num_classes)
	: C1  (register_module("C1",  Conv2d(Conv2dOptions(  3,  64, 3).padding(1))))
	, C3  (register_module("C3",  Conv2d(Conv2dOptions( 64,  64, 3).padding(1))))
	, C6  (register_module("C6",  Conv2d(Conv2dOptions( 64, 128, 3).padding(1))))
	, C8  (register_module("C8",  Conv2d(Conv2dOptions(128, 128, 3).padding(1))))
	, C11 (register_module("C11", Conv2d(Conv2dOptions(128, 256, 3).padding(1))))
	, C13 (register_module("C13", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
	, C15 (register_module("C15", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
	, C18 (register_module("C18", Conv2d(Conv2dOptions(256, 512, 3).padding(1))))
	, C20 (register_module("C20", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
	, C22 (register_module("C22", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
	, C25 (register_module("C25", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
	, C27 (register_module("C27", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
	, C29 (register_module("C29", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
	, FC32(register_module("FC32",Linear(512 * 7 * 7, 4096)))
	, FC35(register_module("FC35",Linear(4096, 4096)))
	, FC38(register_module("FC38",Linear(4096, num_classes)))
{
   
   
...
}

torch::Tensor VGGNet::forward(torch::Tensor input)
{
   
   
	namespace F = torch::nn::functional;
	// block#1
	auto x = F::max_pool2d(F::relu(C3(F::relu(C1(input)))), F::MaxPool2dFuncOptions(2));
	// block#2
	x = F::max_pool2d(F::relu(C8(F::relu(C6(x)))), F::MaxPool2dFuncOptions(2));
	// block#3
	x = F::max_pool2d(F::relu(C15(F::relu(C13(F::relu(C11(x)))))), F::MaxPool2dFuncOptions(2));
	// block#4
	x = F::max_pool2d(F::relu(C22(F::relu(C20(F
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值