使用pytorch构建resnet50-v2

本文介绍了ResNetV2相对于原版的改进,包括将BN层和ReLU激活函数前置,以及在残差结构内的不同设计。作者在CIFAR-10数据集上对比了两种结构,结果显示V2模型的测试错误率显著降低。文章还展示了数据预处理、数据加载、模型构建和训练过程。

resnet-v2改进点以及和v1差别

在这里插入图片描述
🧲 改进点:

(a)original表示原始的ResNet的残差结构,(b)proposed表示新的ResNet的残差结构。

主要差别就是

(a)结构先卷积后进行BN和激活函数计算,最后执行addition后再进行ReLU计算;(b)结构先进性BN和激活函数计算后卷积,把addition后的ReLU计算放到了残差结构内部。

📌 改进结果:作者使用这两种不同的结构再CIFAR-10数据集上做测试,模型用的是1001层的ResNet模型。从图中的结果我们可以看出,(b)proposed的测试集错误率明显更低一些,达到了4.92%的错误率。(a)original的测试集错误率是7.61%

在相同的数据集来实现(之前是resnet50)

导入需要的包

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os, PIL
import numpy as np
from torch.utils.data import DataLoader,Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

遍历并展示图片数量

path = "./data/bird_photos"
f = []
for root, dirs, files in os.walk(path):
    for name in files:
        f.append(os.path.join(root, name))
print("图片总数:",len(f))

图片总数: 565

导入数据

transform = transforms.Compose([
    transforms.Resize(224), #统一图片大小
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]) #标准化
])
data = ImageFolder(path, transform = transform)
print(data.class_to_idx)
print(data.classes)

{‘Bananaquit’: 0, ‘Black Skimmer’: 1, ‘Black Throated Bushtiti’: 2, ‘Cockatoo’: 3}
[‘Bananaquit’, ‘Black Skimmer’, ‘Black Throated Bushtiti’, ‘Cockatoo’]

可视化数据

def imageshow(data, idx, norm = None, label = False):
    plt.figure(dpi=100,figsize=(12,4))
    for i in range(15):
        plt.subplot(3, 5, i + 1)
        img = data[idx[i]][0].numpy().transpose((1, 2, 0))
        if norm is not None:
            mean = norm[0]
            std = norm[1]
            img = img * std + mean
        img = np.clip(img, a_min = 0, a_max=1)
        plt.imshow(img)
    
        if label:
            plt.title(data.classes[data[idx[i]][1]])
        plt.axis('off')
        plt.tight_layout(pad=0.5)
    plt.show()

norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]
np.random.seed(22)
demo_img_ids = np.random.randint(564,size = 15)
imageshow(data, demo_img_ids, norm = norm
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱挠静香的下巴

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值