创建一个数据集,并加载数据集,一篇文章让小白也能看懂

1.自定义一个数据集

1.1加载依赖项

import torch
from torch.utils.data import Dataset

torch.utils.data 是 PyTorch 中提供 数据加载与处理工具 的模块,主要用来方便地管理和批量读取数据。

主要作用有3个:定义数据集 Dataset 的基类、数据加载、数据切分。

1.2自定义一个数据集

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, X_data, Y_data):
        """
        初始化数据集,X_data 和 Y_data 是两个列表或数组
        X_data: 输入特征
        Y_data: 目标标签
        """
        self.X_data = X_data
        self.Y_data = Y_data

    def __len__(self):
        """返回数据集的大小"""
        return len(self.X_data)

    def __getitem__(self, idx):
        """返回指定索引的数据"""
        x = torch.tensor(self.X_data[idx], dtype=torch.float32)  # 转换为 Tensor
        y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
        return x, y

1.2.1 创建一个类

class MyDataset(Dataset)

创建一个类,继承自Dataset类

1.2.2 初始化函数

def __init__(self, X_data, Y_data):

        """

        初始化数据集,X_data 和 Y_data 是两个列表或数组

        X_data: 输入特征

        Y_data: 目标标签

        """

        self.X_data = X_data

        self.Y_data = Y_data

这边主要是进行数据的初始化,使用这个函数,把类实例化的对象进行一个初始化。

self.X_data是类MyDataset的属性,X_data是一个局部变量,变量外部传入。

MyDataset._init_(X_data,Y_data)

使用了这个语句以后,MyDataset.X_data的值与传入的X_data的值一致,

比如X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]的话,初始化以后,MyDataset.X_data的值也是为[[1, 2], [3, 4], [5, 6], [7, 8]]

1.2.3 其他函数

    def __len__(self):

        """返回数据集的大小"""

        return len(self.X_data)

返回X_data的长度,X_data = [[1, 2], [3, 4], [5, 6], [7, 8]],使用了这个函数,返回的结果就是4

    def __getitem__(self, idx):

        """返回指定索引的数据"""

        x = torch.tensor(self.X_data[idx], dtype=torch.float32)  # 转换为 Tensor

        y = torch.tensor(self.Y_data[idx], dtype=torch.float32)

        return x, y

返回对应的特征与标签

X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 输入特征

Y_data = [1, 0, 1, 0]  # 目标标签

比如print(dataset.__getitem__(3)),返回的值为

(tensor([7., 8.]), tensor(0.))

对应的X_data索引为3的特征为[7,8]

对应的Y_data索引为3的标签为0

1.2.4 创建一个数据集实例

# 创建数据集实例
dataset = MyDataset(X_data, Y_data)

在执行这条的时候,会自动调用_init_函数,等效于下面,但是不建议这么写

dataset = MyDataset([], [])

dataset.__init__(X_data, Y_data)

加载数据

from torch.utils.data import DataLoader

# 创建 DataLoader 实例,batch_size 设置每次加载的样本数量
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 打印加载的数据
for epoch in range(1):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f'Batch {batch_idx + 1}:')
        print(f'Inputs: {inputs}')
        print(f'Labels: {labels}')

数据加载模块

from torch.utils.data import DataLoader

导入 PyTorch 中的数据加载器 DataLoader。它可以按批次读取 Dataset 中的数据。

# 创建 DataLoader 实例,batch_size 设置每次加载的样本数量

dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

dataset是数据集,batch_size=2可以一次提取两个样本,shuffle是这种在训练的时候是否需要打乱样本,如果不打乱,每次的结果会是一样的,通常需要打乱样本,获得更好的训练效果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值