网络构建
一、什么是Diffusion Transformer (DiT)
二、DiT的组成
三、生成流程
1、采样流程
a、生成初始噪声
b、对噪声进行N次采样
c、单次采样解析
I、预测噪声
II、施加噪声
d、预测噪声过程中的网络结构解析
i、adaLN-Zero结构解析
ii、patch分块处理
iii、Transformer特征提取
iv、上采样
3、隐空间解码生成图片
类别到图像预测过程代码
学习前言
近期Sora大火,它底层是Diffusion Transformer,本质上是使用Transformer结构代替原本的Unet进行噪声预测,好处是统一了文本生成与视频生成的结构。这训练优化和预测优化而言是个好事,因为只需要优化一种结构就够了。虽然觉得OpenAI是大力出奇迹,但还是得学!
源码下载地址
https://github.com/bubbliiiing/DiT-pytorch
喜欢的可以点个star噢。
网络构建
一、什么是Diffusion Transformer (DiT)
DiT基于扩散模型,所以不免包含不断去噪的过程,如果是图生图的话,还有不断加噪的过程,此时离不开DDPM那张老图,如下:
DiT相比于DDPM,使用了更快的采样器,也使用了更大的分辨率,与Stable Diffusion一样使用了隐空间的扩散,但可能更偏研究性质一些,没有使用非常大的数据集进行预训练,只使用了imagenet进行预训练。
与Stable Diffusion不同的是,DiT的网络结构完全由Transformer组成,没有Unet中大量的上下采样,结构更为简单清晰。
本文主要是解析一下整个DiT模型的结构组成,并简单一次扩散,多次扩散的流程。本文代码来自于Diffusers,Diffusers代码较为简单清晰,是一个非常好的仓库,学习起来也比较快。
二、DiT的组成
DiT由三大部分组成。
1、Sampler采样器。
2、Variational Autoencoder (VAE) 变分自编码器。
3、UNet 主网络,噪声预测器。
每一部分都很重要,由于DiT的官方版本并没有在 大规模文本图片 的 数据集上训练,只使用了imagenet进行预训练。所以它并没有文本输入,而是以标签作为输入。因此,DiT只能按照类别进行图片生成,可以生成imagenet中的1000类
三、生成流程
生成流程分为两个部分:
1、生成正态分布向量后进行若干次采样。
2、进行解码。
由于DiT只能按照类别进行图片生成,所以无需对文本进行编码,直接传入类别的对应的id(0-1000)即可指定类别。
# --------------------------------- #
# 前处理
# --------------------------------- #
# 生成latent
latents = randn_tensor(
shape=(batch_size, latent_channels, latent_size, latent_size),
generator=generator,
device=self._execution_device,
dtype=self.transformer.dtype,
)
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
# 将输入的label 与 null label进行concat,null label是负向提示类。
class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1)
class_null = torch.tensor([1000] * batch_size, device=self._execution_device)
class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
# 设置生成的步数
self.scheduler.set_timesteps(num_inference_steps)
# --------------------------------- #
# 扩散生成
# --------------------------------- #
# 开始N步扩散的循环
for t in self.progress_bar(self.scheduler.timesteps):
if guidance_scale > 1:
half = latent_model_input[: len(latent_model_input) // 2]
latent_model_input = torch.cat([half, half], dim=0)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# 处理timesteps
timesteps = t
if not torch.is_tensor(timesteps):
is_mps = latent_model_input.device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(latent_model_input.shape[0])
# 将隐含层特征、时间步和种类输入传入到transformers中
noise_pred = self.transformer(
latent_model_input, timestep=timesteps, class_labels=class_labels_input
).sample
# perform guidance
if guidance_scale > 1:
# 在通道上做分割,取出生图部分的通道
eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
noise_pred = torch.cat([eps, rest], dim=1)
# 对结果进行分割,取出生图部分的通道
if self.transformer.config.out_channels // 2 == latent_channels:
model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
else:
model_output = noise_pred
# 通过采样器将这一步噪声施加到隐含层
latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
if guidance_scale > 1:
latents, _ = latent_model_input.chunk(2, dim=0)
else:
latents = latent_model_input
# --------------------------------- #
# 后处理
# --------------------------------- #
# 通过vae进行解码
latents = 1 / self.vae.config.scaling_factor * latents
samples = self.vae.decode(latents).sample
samples = (samples / 2 + 0.5).clamp(0, 1)
# 转化为float32类别
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
4

结构解析&spm=1001.2101.3001.5002&articleId=136559418&d=1&t=3&u=3fef5ff3bd754ed4a6d725b59a9cb599)
3171

被折叠的 条评论
为什么被折叠?



