本章将详细介绍如何使用libtorch自带的数据加载模块。
自定义数据集
简介
要自定义数据加载模块,需要继承torch::data::Dataset这个基类实现派生类。
与pytorch中需要实现初始化函数init,获取函数getitem以及数据集大小函数len类似的是,在libtorch中同样需要处理好初始化函数,get()函数和size()函数。
例程的代码结构
例程中使用了一个图像分类任务来进行介绍,使用pytorch官网提供的昆虫分类数据集
遍历图像文件
例程中使用了io.h来遍历文件夹。
首先实现遍历文件夹的函数:
接受数据集文件夹路径image_dir和图片类型image_type,将遍历到的图片路径和其类别分别存储到list_images和list_labels,最后lable变量用于表示类别计数。
通过该函数,会得到所有图像的绝对地址,通过这些地址就可以获得图像。
#include <io.h>
void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);
void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label)
{
/*
* path:文件夹地址
* type:图片类型
* list_images:所有图片的名称
* list_label:各个图片的标签,也就是所属的类
* label:类别的个数
*/
long long hFile = 0; //句柄
struct _finddata_t fileInfo;// 记录读取到文件的信息
std::string pathName;
// 调用_findfirst函数,其第一个参数为遍历的文件夹路径,*代表任意文件。注意路径最后,需要添加通配符
// 如果失败,返回-1,否则,就会返回文件句柄,并且将找到的第一个文件信息放在_finddata_t结构体变量中
if ((hFile = _findfirst(pathName.assign(path).

:数据加载模块&spm=1001.2101.3001.5002&articleId=140994030&d=1&t=3&u=9d67b14bb749447aa7000bd9321f10d8)
4255

被折叠的 条评论
为什么被折叠?



