批量读取 CIFAR10 数据集
本文重点介绍通过拆分原始数据集来加载和训练神经网络模型。 当整个数据集对于本地 RAM 来说太大并且必须在使用“model.fit”训练模型之前拆分
背景 Background
我们通常这样加载 CIFAR10 图像数据集
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
(train_data, train_labels), (test_data, test_labels) = datasets.cifar10.load_data()
load_data() 函数会自动下载数据集并将其存储在 ~/.keras/datasets/cifar-10-batches-py/ 中,这就是您可以在文件夹中找到的内容。 查看 load_data() 的源代码,它会加载所有批次并让我们将数据很好地返回到 train_data, train_labels), (test_data, test_labels),但是如果您想一一阅读它们怎么办? 你会怎么做,这就是我们今天要探讨的。

Pickle
我们用python的pickle模块来加载数据。首先定义这个函数
def load_pickle(filename):
""" load correct version of pickle """
version = platform.python_version_tuple()
if version[0] == '2':
return pickle.load(filename)
elif version[0] == '3':
return pickle.load(filename, encoding='latin1')
raise ValueError("invalid python version: {}".format(version))

本文介绍了如何批量读取CIFAR10数据集,以适应内存限制,构建并训练神经网络模型。通过Pickle模块处理数据,建立简单的卷积神经网络模型,进行训练和评估。这种方法适用于处理大型图像数据集,避免一次性加载所有数据导致内存溢出。

1730

被折叠的 条评论
为什么被折叠?



