写在前面的话
这是我们研发的用于 消费决策的AI助理 ,我们会持续优化,欢迎体验与反馈。微信扫描二维码,添加即可。
官方链接:https://ailab.smzdm.com/
本教程将讲述 如何在Smithsonian Butterflies数据集的子集上,从头开始训练UNet2DModel,最终训练个【无条件图片生成模型】,就是不能进行文生图的啊,我觉得比较适合垂直领域的数据训练。
下载数据
训练的数据集在这个:https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset。可以使用代码进行下载。
完整的代码在最后,因为网络的原因,调代码花了一些时间(官网默认上传hugging face,我没上传),所以要运行的话,copy最后的全部代码。我的显卡是3050,8G显存。
from datasets import load_dataset
dataset = load_dataset("huggan/smithsonian_butterflies_subset")
代码运行完成后,它的默认下载路径在:
/Users/用户名/.cache/huggingface/datasets
进入该目录后,可以看见下载的文件夹。
模型配置文件
为了方便起见,训练一个包含超参数的配置文件:
from dataclasses import dataclass
@dataclass
class TrainingConfig:
image_size = 128 # the generated image resolution
train_batch_size = 16
eval_batch_size = 16 # how many images to sample during evaluation
num_epochs = 50
gradient_accumulation_steps = 1
learning_rate = 1e-4
lr_warmup_steps = 500
save_image_epochs = 10
save_model_epochs = 30
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub
push_to_hub = True # whether to upload the saved model to the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0
config = TrainingConfig()
加载数据
from datasets import load_dataset
config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")
大家也可以添加一下,Smithsonian Butterflies 数据集中一些其他数据(创建一个ImageFolder文件夹),但是在 配置文件中 要进行添加对应的变量 imagefolder。当然,也可以使用自己的数据。
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):
axs[i].imshow(image)
axs[i].set_axis_off()
fig.show()

不过,这些图像的大小都不一样,所以你需要先对它们进行预处理:
- 统一图像尺寸:缩放到配置文件中的指定尺寸;
- 数据增强:通过裁剪、翻转等方法
- 标准化:将像素值的范围控制在[-1, 1]
from torchvision import transforms
preprocess = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
对图像进行预处理,将图像通道转化为RGB
def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {
"images": images}
dataset.set_transform(transform)
可以再次可视化图像,以确认它们是否已经被调整。之后就可以将数据集打包到DataLoader中进行训练了!
import torch
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
创建一个UNet2DModel
from diffusers import UNet2DModel
model = UNet2DModel(
sample_size=config.image_size, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"DownBlock2D",
),
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
还有一个方法,快速检查样本图像的形状是否与模型输出形状匹配。
sample_image = dataset[0]["images"].unsqueeze(0)
print("Input shape:", sample_image.shape)
print("Output shape:", model(sample_image, timestep=0).sample.shape)
还需要一个调度器来为图像添加一些噪声。
创建一个调度器
调度器的作用在不同的场景下会生成不同的作用,这取决于您是使用模型进行训练还是推理。
在推理过程中,调度器从噪声中生成图像。
在训练过程中,调度器从图像上生成噪声。
可以看下DDPMScheduler调度器给图像增加噪声的效果:
import torch
from PIL import Image
from diffusers import DDPMScheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
Image.fromarray(

本教程介绍如何在Smithsonian Butterflies数据集子集上从头训练UNet2DModel,生成无条件图片生成模型。内容包括下载数据、配置模型文件、加载数据、创建模型与调度器、训练模型等步骤,还给出了完整版代码,训练完成后可查看模型生成图片的效果。

2112

被折叠的 条评论
为什么被折叠?



