AIGC专栏9——Scalable Diffusion Models with Transformers (DiT)结构解析

开发板推荐:天空星STM32F407VET6开发板

超高性价比 STM32主控 | 超高主频 | 一板兼容百芯 | 比赛神器 | 沉金彩色丝印

网络构建
一、什么是Diffusion Transformer (DiT)
二、DiT的组成
三、生成流程
1、采样流程
a、生成初始噪声
b、对噪声进行N次采样
c、单次采样解析
I、预测噪声
II、施加噪声
d、预测噪声过程中的网络结构解析
i、adaLN-Zero结构解析
ii、patch分块处理
iii、Transformer特征提取
iv、上采样
3、隐空间解码生成图片
类别到图像预测过程代码
学习前言
近期Sora大火,它底层是Diffusion Transformer,本质上是使用Transformer结构代替原本的Unet进行噪声预测,好处是统一了文本生成与视频生成的结构。这训练优化和预测优化而言是个好事,因为只需要优化一种结构就够了。虽然觉得OpenAI是大力出奇迹,但还是得学!


源码下载地址
https://github.com/bubbliiiing/DiT-pytorch

喜欢的可以点个star噢。

网络构建
一、什么是Diffusion Transformer (DiT)
DiT基于扩散模型,所以不免包含不断去噪的过程,如果是图生图的话,还有不断加噪的过程,此时离不开DDPM那张老图,如下:

DiT相比于DDPM,使用了更快的采样器,也使用了更大的分辨率,与Stable Diffusion一样使用了隐空间的扩散,但可能更偏研究性质一些,没有使用非常大的数据集进行预训练,只使用了imagenet进行预训练。

与Stable Diffusion不同的是,DiT的网络结构完全由Transformer组成,没有Unet中大量的上下采样,结构更为简单清晰。

本文主要是解析一下整个DiT模型的结构组成,并简单一次扩散,多次扩散的流程。本文代码来自于Diffusers,Diffusers代码较为简单清晰,是一个非常好的仓库,学习起来也比较快。

二、DiT的组成
DiT由三大部分组成。
1、Sampler采样器。
2、Variational Autoencoder (VAE) 变分自编码器。
3、UNet 主网络,噪声预测器。

每一部分都很重要,由于DiT的官方版本并没有在 大规模文本图片 的 数据集上训练,只使用了imagenet进行预训练。所以它并没有文本输入,而是以标签作为输入。因此,DiT只能按照类别进行图片生成,可以生成imagenet中的1000类

三、生成流程

生成流程分为两个部分:
1、生成正态分布向量后进行若干次采样。
2、进行解码。

由于DiT只能按照类别进行图片生成,所以无需对文本进行编码,直接传入类别的对应的id(0-1000)即可指定类别。

# --------------------------------- #
#   前处理
# --------------------------------- #
# 生成latent
latents = randn_tensor(
    shape=(batch_size, latent_channels, latent_size, latent_size),
    generator=generator,
    device=self._execution_device,
    dtype=self.transformer.dtype,
)
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents

# 将输入的label 与 null label进行concat,null label是负向提示类。
class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1)
class_null = torch.tensor([1000] * batch_size, device=self._execution_device)
class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels

# 设置生成的步数
self.scheduler.set_timesteps(num_inference_steps)

# --------------------------------- #
#   扩散生成
# --------------------------------- #
# 开始N步扩散的循环
for t in self.progress_bar(self.scheduler.timesteps):
    if guidance_scale > 1:
        half = latent_model_input[: len(latent_model_input) // 2]
        latent_model_input = torch.cat([half, half], dim=0)
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
    
    # 处理timesteps
    timesteps = t
    if not torch.is_tensor(timesteps):
        is_mps = latent_model_input.device.type == "mps"
        if isinstance(timesteps, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(latent_model_input.device)
    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(latent_model_input.shape[0])

    # 将隐含层特征、时间步和种类输入传入到transformers中
    noise_pred = self.transformer(
        latent_model_input, timestep=timesteps, class_labels=class_labels_input
    ).sample

    # perform guidance
    if guidance_scale > 1:
        # 在通道上做分割,取出生图部分的通道
        eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

        half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)

        noise_pred = torch.cat([eps, rest], dim=1)

    # 对结果进行分割,取出生图部分的通道
    if self.transformer.config.out_channels // 2 == latent_channels:
        model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
    else:
        model_output = noise_pred

    # 通过采样器将这一步噪声施加到隐含层
    latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

if guidance_scale > 1:
    latents, _ = latent_model_input.chunk(2, dim=0)
else:
    latents = latent_model_input

# --------------------------------- #
#   后处理
# --------------------------------- #
# 通过vae进行解码
latents = 1 / self.vae.config.scaling_factor * latents
samples = self.vae.decode(latents).sample

samples = (samples / 2 + 0.5).clamp(0, 1)

# 转化为float32类别
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
4

开发板推荐:天空星STM32F407VET6开发板

超高性价比 STM32主控 | 超高主频 | 一板兼容百芯 | 比赛神器 | 沉金彩色丝印

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI周红伟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值