神经网络/深度学习
第二章 Python机器学习入门之EfficientNetV2的使用
前言
本文主要是复现efficientnetv2网络代码,训练自己的材质分类模型,学习记录下来。
大佬文章:https://blog.csdn.net/qq_37541097/article/details/116933569
大佬的讲解视频:https://www.bilibili.com/video/BV1Xy4y1g74u/?spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=b9a1a486cbe5d7fe623135210f75aca8
论文下载地址:https://arxiv.org/abs/2104.00298
原论文提供代码:https://github.com/google/automl/tree/master/efficientnetv2

提示:以下是本篇文章正文内容,下面案例可供参考
一、EfficientNetV2是什么?
EfficientNetV2是由谷歌提出的一种新型神经网络架构,用于图像分类任务。它在EfficientNet的基础上进行了改进,通过优化模型的结构和训练过程,提高了模型的效率和性能。
EffNetV2-S(21k)(红色曲线)是一个EfficientNetV2家族的模型,使用21k个类别的数据进行预训练。该模型在较短的训练时间内(约0.5TPU天)达到了85%准确率,其准确率之高,模型大小之小,选为这次训练的基础模型(自己的小笔记本是4060labtap,感觉没啥问题)

模型可以去大佬的文章中找到代码链接,再从链接中找到百度网盘的模型下载链接。
二、EfficientNetV2代码的复现
首先给大家看一下整体的目录

我这里材质分类分了六种,分别是3D,玻璃,镜面 ,金属,平滑,纹理(当然这个是我人工定义的,大家可以根据自己的需求进行更改)
1.准备工作
在train文件夹下面设置好你所设定的种类,我这里六种,我就设置了六个文件夹并且以种类的名字命名,里面添加好各个种类的图片(图片根据自己的需求添加就行)
2.训练模型
train.py代码
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(args)
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
tb_writer = SummaryWriter()
if os.path.exists("./weights") is False:
os.makedirs("./weights")
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
img_size = {
"s": [300, 384], # train_size, val_size
"m": [384, 480],
"l": [384, 480]}
num_model = "s"
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
"val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
transforms.CenterCrop(img_size[num_model][1]),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]


1362

被折叠的 条评论
为什么被折叠?



