【Diffusers库】第四篇 训练一个扩散模型(Unconditional)

本教程介绍如何在Smithsonian Butterflies数据集子集上从头训练UNet2DModel,生成无条件图片生成模型。内容包括下载数据、配置模型文件、加载数据、创建模型与调度器、训练模型等步骤,还给出了完整版代码,训练完成后可查看模型生成图片的效果。

写在前面的话

  这是我们研发的用于 消费决策的AI助理 ,我们会持续优化,欢迎体验与反馈。微信扫描二维码,添加即可。
  官方链接:https://ailab.smzdm.com/

************************************************************** 分割线 *******************************************************************

  本教程将讲述 如何在Smithsonian Butterflies数据集的子集上,从头开始训练UNet2DModel,最终训练个【无条件图片生成模型】,就是不能进行文生图的啊,我觉得比较适合垂直领域的数据训练。

下载数据

  训练的数据集在这个:https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset。可以使用代码进行下载。
   完整的代码在最后,因为网络的原因,调代码花了一些时间(官网默认上传hugging face,我没上传),所以要运行的话,copy最后的全部代码。我的显卡是3050,8G显存。

from datasets import load_dataset
dataset = load_dataset("huggan/smithsonian_butterflies_subset")

  代码运行完成后,它的默认下载路径在:

/Users/用户名/.cache/huggingface/datasets

  进入该目录后,可以看见下载的文件夹。

模型配置文件

  为了方便起见,训练一个包含超参数的配置文件:

from dataclasses import dataclass


@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub

    push_to_hub = True  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0


config = TrainingConfig()

加载数据

from datasets import load_dataset

config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")

  大家也可以添加一下,Smithsonian Butterflies 数据集中一些其他数据(创建一个ImageFolder文件夹),但是在 配置文件中 要进行添加对应的变量 imagefolder。当然,也可以使用自己的数据。

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):
    axs[i].imshow(image)
    axs[i].set_axis_off()
fig.show()

在这里插入图片描述
  不过,这些图像的大小都不一样,所以你需要先对它们进行预处理:

  1. 统一图像尺寸:缩放到配置文件中的指定尺寸;
  2. 数据增强:通过裁剪、翻转等方法
  3. 标准化:将像素值的范围控制在[-1, 1]
from torchvision import transforms

preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

  对图像进行预处理,将图像通道转化为RGB

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {
   
   "images": images}


dataset.set_transform(transform)

  可以再次可视化图像,以确认它们是否已经被调整。之后就可以将数据集打包到DataLoader中进行训练了!

import torch
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

创建一个UNet2DModel

from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

  还有一个方法,快速检查样本图像的形状是否与模型输出形状匹配。

sample_image = dataset[0]["images"].unsqueeze(0)

print("Input shape:", sample_image.shape)
print("Output shape:", model(sample_image, timestep=0).sample.shape)

  还需要一个调度器来为图像添加一些噪声。

创建一个调度器

  调度器的作用在不同的场景下会生成不同的作用,这取决于您是使用模型进行训练还是推理。
  在推理过程中,调度器从噪声中生成图像。
  在训练过程中,调度器从图像上生成噪声。
  可以看下DDPMScheduler调度器给图像增加噪声的效果:

import torch
from PIL import Image
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

Image.fromarray(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值