1.自定义一个数据集
1.1加载依赖项
import torch
from torch.utils.data import Dataset
torch.utils.data 是 PyTorch 中提供 数据加载与处理工具 的模块,主要用来方便地管理和批量读取数据。
主要作用有3个:定义数据集 Dataset 的基类、数据加载、数据切分。
1.2自定义一个数据集
# 自定义数据集类
class MyDataset(Dataset):
def __init__(self, X_data, Y_data):
"""
初始化数据集,X_data 和 Y_data 是两个列表或数组
X_data: 输入特征
Y_data: 目标标签
"""
self.X_data = X_data
self.Y_data = Y_data
def __len__(self):
"""返回数据集的大小"""
return len(self.X_data)
def __getitem__(self, idx):
"""返回指定索引的数据"""
x = torch.tensor(self.X_data[idx], dtype=torch.float32) # 转换为 Tensor
y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
return x, y
1.2.1 创建一个类
class MyDataset(Dataset)
创建一个类,继承自Dataset类
1.2.2 初始化函数
def __init__(self, X_data, Y_data):
"""
初始化数据集,X_data 和 Y_data 是两个列表或数组
X_data: 输入特征
Y_data: 目标标签
"""
self.X_data = X_data
self.Y_data = Y_data
这边主要是进行数据的初始化,使用这个函数,把类实例化的对象进行一个初始化。
self.X_data是类MyDataset的属性,X_data是一个局部变量,变量外部传入。
MyDataset._init_(X_data,Y_data)
使用了这个语句以后,MyDataset.X_data的值与传入的X_data的值一致,
比如X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]的话,初始化以后,MyDataset.X_data的值也是为[[1, 2], [3, 4], [5, 6], [7, 8]]
1.2.3 其他函数
def __len__(self):
"""返回数据集的大小"""
return len(self.X_data)
返回X_data的长度,X_data = [[1, 2], [3, 4], [5, 6], [7, 8]],使用了这个函数,返回的结果就是4
def __getitem__(self, idx):
"""返回指定索引的数据"""
x = torch.tensor(self.X_data[idx], dtype=torch.float32) # 转换为 Tensor
y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
return x, y
返回对应的特征与标签
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]] # 输入特征
Y_data = [1, 0, 1, 0] # 目标标签
比如print(dataset.__getitem__(3)),返回的值为
(tensor([7., 8.]), tensor(0.))
对应的X_data索引为3的特征为[7,8]
对应的Y_data索引为3的标签为0
1.2.4 创建一个数据集实例
# 创建数据集实例
dataset = MyDataset(X_data, Y_data)
在执行这条的时候,会自动调用_init_函数,等效于下面,但是不建议这么写
dataset = MyDataset([], [])
dataset.__init__(X_data, Y_data)
加载数据
from torch.utils.data import DataLoader
# 创建 DataLoader 实例,batch_size 设置每次加载的样本数量
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 打印加载的数据
for epoch in range(1):
for batch_idx, (inputs, labels) in enumerate(dataloader):
print(f'Batch {batch_idx + 1}:')
print(f'Inputs: {inputs}')
print(f'Labels: {labels}')
数据加载模块
from torch.utils.data import DataLoader
导入 PyTorch 中的数据加载器 DataLoader。它可以按批次读取 Dataset 中的数据。
# 创建 DataLoader 实例,batch_size 设置每次加载的样本数量
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
dataset是数据集,batch_size=2可以一次提取两个样本,shuffle是这种在训练的时候是否需要打乱样本,如果不打乱,每次的结果会是一样的,通常需要打乱样本,获得更好的训练效果


4997

被折叠的 条评论
为什么被折叠?



