实践 | 基于ResNet-50实现CIFAR10数据集分类

本文通过PaddlePaddle库实现CIFAR10数据集上的猫狗分类任务,涉及自定义数据集、ResNet18网络配置、基础API训练与验证,以及高层API的模型训练和预测。

在这里插入图片描述
在这里插入图片描述
图像分类是根据图像的语义信息将不同类别图像区分开来,是计算机视觉中重要的基本问题

猫狗分类属于图像分类中的粗粒度分类问题

Step1:准备数据

自定义数据集
(1)数据集介绍

我们使用CIFAR10数据集。CIFAR10数据集包含60,000张32x32的彩色图片,10个类别,每个类包含6,0000张。其中50,000张图片作为训练集,10000张作为验证集。这次我们只对其中的猫和狗两类进行预测。

在这里插入图片描述
(2)train_dataset和eval_dataset

自定义读取器处理训练集和测试集

paddle.reader.shuffle()表示每次缓存BUF_SIZE个数据项,并进行打乱

paddle.batch()表示每BATCH_SIZE组成一个batch

# 导入需要的包
import paddle
# import os 
# import numpy as np
# from PIL import Image
# import matplotlib.pyplot as plt
# import sys
# import pickle
# from paddle.vision.transforms import ToTensor
import paddle.nn as nn
import paddle.nn.functional as F
print("本教程基于Paddle的版本号为:"+paddle.__version__)
'''
参数配置
'''
train_parameters = {
   
   
    "input_size": [3, 32, 32],                           #输入图片的shape
    "src_path":"/home/aistudio/data/data9154/cifar-10-python.tar.gz",       #原始数据集路径
    "target_path":"/home/aistudio/cifar-10-batches-py",        #要解压的路径 
    "num_epochs": 1,                                    #训练轮数
    "train_batch_size": 64,                             #批次的大小
    "learning_strategy": {
   
                                 #优化函数相关的配置
        "lr": 0.001                                     #超参数学习率
    } 
}
def unzip_data(src_path,target_path):

    '''
    解压原始数据集,将src_path路径下的zip包解压至/home/aistudio/目录下
    '''

    if(not os.path.isdir(target_path)):    
        import tarfile
        tar = tarfile.open(src_path,'r')
        tar.extractall(PATH=target_path)
        tar.close()
    else:
        print("文件已解压")
'''
参数初始化
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
batch_size=train_parameters['train_batch_size']
image_size=train_parameters['input_size']
epoch_num=train_parameters['num_epochs']
lr=train_parameters['learning_strategy']['lr']
'''
解压原始数据到指定路径
'''
unzip_data(src_path,target_path)
#定义数据序列化函数
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

print(unpickle("cifar-10-batches-py/data_batch_1").keys())
print(unpickle("cifar-10-batches-py/test_batch").keys())

dict_keys([b’batch_label’, b’labels’, b’data’, b’filenames’])
dict_keys([b’batch_label’, b’labels’, b’data’, b’filenames’])

自定义数据集

'''
自定义数据集
'''
from paddle.io import Dataset
class MyDataset(paddle.io.Dataset)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值