一、综述
Dataset :对数据进行抽象,将数据包装为Dataset类。
DataLoader:在 Dataset之上对数据进行进一步处理,包括进行乱序处理,获取一个batch_size的数据等。

二、Dataset
在Dataset类中必须重新 getitem()、len()两个方法。
- 创建数据
ss=np.linspace(1,100,100)
np.savetxt("sample_data.txt", ss.reshape(-1,4))
数据格式如下所示:

2. 创建自定义Dataset
import numpy as np
import torch as t
from torch.utils.data import Dataset
class MyDataSet(Dataset):
def __init__(self):
#使用numy读取数据
txt_data = np.loadtxt('sample_data.txt')
#取数据前三列为x
self._x = t.from_numpy(txt_data[:,:3])
#取数据最后一列为target值
self._y = t.from_numpy(txt_data[:,-1])
#获取数据的长度
self._len = len(txt_data)
def __getitem__(self,item):
#item对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
return self._x[item],self._y[item]
def __len__(self):
#带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回
return self._len
dataset = MyDataSet()
print(len(dataset))
data =next(iter(dataset))
print(data)

三、 DataLoader

关键参数:
- dataset :数据集
- batch_size : 一个批次的大小
- shuffle : 是否乱序处理
- sampler:非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.
- drop_last:如果数据集大小不能整除batch_size的话,是否删除最后一个batch
from torch.utils.data import DataLoader
data = MyDataSet()
dataloader = DataLoader(data,batch_size=4,shuffle=True,drop_last=True,num_workers=0)
for i,data in enumerate(dataloader):
print('batch---->',i+1)
inputs,labels=data
print(inputs)
print(labels)
print("*"*30)

四、random_split
pytorch中 random_split类似于 sklearn中的train_test_split类似的功能,将数据切分为训练集、测试集、验证集。
from torch.utils.data import random_split
all_length =len(dataset)
train_size =int(0.8*all_length)
test_size = all_length - train_size
#切分数据集
train_dataset,test_dataset = random_split(dataset,[train_size,test_size])
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)
for i,curr_data in enumerate(train_loader):
print('batch---->',i+1)
inputs,labels=curr_data
print(inputs)
print(labels)
print("*"*30)
```

本文介绍如何使用PyTorch中的Dataset和DataLoader类来加载和处理数据,包括自定义数据集、批量加载、数据乱序及多线程预读等功能,并演示如何通过random_split方法划分训练集和测试集。

886

被折叠的 条评论
为什么被折叠?



