【生成模型之十】Scalable Diffusion Models with Transformers

论文:https://arxiv.org/pdf/2212.09748

代码:GitHub - bubbliiiing/DiT-pytorch: 这是一个DiT-pytorch的代码,主要用于学习DiT结构。

支持CV算法简历修改、面试辅导

相关推荐:AIGC专栏9——Scalable Diffusion Models with Transformers (DiT)结构解析_transformer_Bubbliiiing-GitCode 开源社区

一、背景

我们探索了一类基于Transformer架构的新型扩散模型。我们训练图像的潜在扩散模型,用一个在latent patch上运行的Transformer代替常用的U-Net骨干网。我们通过Gflops测量的前向复杂性来分析扩散变换器(DiTs)的可扩展性。我们发现,通过增加Transformer深度/宽度或增加token的数量,具有较高Gflops的DiT始终具有较低的FID。除了具有良好的可扩展性特性外,我们最大的DiT XL/2模型在类条件ImageNet 512×512和256×256基准上的表现优于所有先前的扩散模型,在后者上实现了2.27的最新FID。

DDPM中的开创性工作首次引入了用于扩散模型的U-Net骨干网。最初在像素级自回归模型和条件GAN[23]中取得了成功,U-Net是从pixel-CNN++[52,58]继承而来的,但有一些变化。 该模型是卷积形式,主要由ResNet[15]块组成。与标准U-Net[49]相比,额外的spatial self-attention是变压器中的重要组成部分,它们以较低的分辨率散布。Dhariwal和Nichol[9]消除了U-Net的几种架构选择,例如使用自适应归一化层[40]为卷积层注入条件信息和channel counts。

我们表明,U-Net归纳偏执对扩散模型的性能并不重要,它们可以很容易地用Transformer等标准设计替换。本文主要研究一类新的基于Transformer的扩散模型。我们称之为Diffusion Transform- ers,简称DiTs。DiTs遵循视觉变换器(ViTs)的最佳实践[10],已被证明比传统卷积网络(例如ResNet[15])更有效地扩展视觉识别。更具体地说,我们研究了Transformer在network complexity vs. sample quality的缩放行为。通过简单地扩展DiT并使用高容量骨干网(118.6 Gflops)训练LDM,我们能够在类条件256×256 ImageNet生成基准上实现2.27 FID的最新结果。

Architecture complexity. 在评估图像生成文献中的架构复杂性时,使用参数量是相当常见的做法。一般来说,参数量不能很好地反映图像模型的复杂性,因为它们没有考虑到图像分辨率等对性能有重大影响的因素。相反,本文中的大部分模型复杂性分析都是通过理论Gflops的视角进行的。

二、Method

在本文中,我们将DiTs应用于潜在空间,尽管它们也可以应用于像素空间而无需修改。这使得我们的图像生成pipeline成为一种基于混合的方法;我们使用现成的卷积VAE和transformer-based DDPM。

Patchify

DiT的输入是一个空间表示z(对于256×256×3的图像,z的形状为32×32×4)。DiT的第一层是“patchify”,它通过将每个patch线性嵌入输入中,将空间输入转换为一系列T标记,每个标记的维度为d。 遵循patchify,我们将标准的基于ViT频率的位置嵌入(正弦余弦位置编码)应用于所有输入token。patchify创建token的数量T由patch大小超参数p决定。We add p = 2, 4, 8 to the DiT design space.

DiT block design

在patchify之后,input token由一系列Transformer block进行处理。除了噪声图像输入外,扩散模型有时还会处理额外的条件信息,如噪声时间步长t、类标签c、自然语言等。我们探索了四种不同处理条件输入的Transformer blocks变体。这些设计对标准ViT块设计进行了小但重要的修改。所有模块的设计如图3所示。

In-context conditioning。我们只是将t和c的向量嵌入作为两个额外的标记附加到输入序列中,将它们与图像标记区别对待。这类似于ViT中的cls token,它允许我们使用标准ViT块而无需修改。在最后一个块之后,我们从序列中删除条件标记。这种方法在模型中引入了可以忽略不计的新Gflops。

Cross-attention block。我们将t和c的嵌入连接成一个长度为2的序列,与图像标记序列分开。Transformer block经过修改在Multi-head self-attention之后增加了一个额外的Multi-head cross-attention,类似于Vaswani等人[60]的原始设计,也类似于LDM用于在类别标签上进行调节的设计。cross-attention为模型增加了最多的Gflops,大约15%的开销。

–Adaptive layer norm。在GAN[2,28]中广泛使用自适应归一化层[40]和具有U-Net骨干网的扩散模型[9]之后,我们探索用自适应层范数(adaLN)替换Transformer block中的标准层范数层。我们不是直接学习缩放尺度和移位参数γ和β,而是从t和c的嵌入向量之和中回归它们。在我们探索的三个块设计中,adaLN添加的Gflops最少,因此计算效率最高。它也是唯一被限制对所有token应用相同功能的调节机制。

adaLN-Zero block。之前对ResNets的研究发现,将每个残差块初始化为恒等函数是有益的。例如,Goyal等人发现,在每个块中初始化the final batch norm scale factor γ为零,可以加速支持监督学习设置中的大规模训练。扩散U-Net模型使用类似的初始化策略,在任何残差连接之前,对每个块中的最终卷积层进行零初始化。我们探索了adaLN-DiT块的一种修改,它也做了同样的事情。除了回归γ和β外,我们还回归了在DiT块内任何残差连接之前立即应用的维度缩放参数α。

Model size。我们应用了一系列N个DiT块,每个块都以隐藏的尺寸d运行。在ViT之后,我们使用标准Transformer配置来联合缩放N、d和注意头。四种配置DiT-S, DiT-B, DiT-L and DiT-XL.

Transformer decoder。在最后的DiT块之后,我们需要将图像token序列解码为输出噪声预测和输出对角协方差预测。这两个输出的形状都等于原始空间输入。我们使用标准的线性解码器来实现这一点;我们应用最后一层范数(如果使用adaLN,则自适应),并将每个标记解码为p×p×2C张量,其中C是DiT空间输入中的通道数。最后,我们将解码的令牌重新排列到它们的原始空间布局中,以获得预测的噪声和协方差。

三、Experiments

我们的模型根据其配置和latent patch size p命名;例如,DiT XL/2是指XLarge配置,p=2。我们没有发现warm-up或regularization是训练DiTs达到高性能所必需的。

我们使用Stable Diffusion中现成的预训练变分自动编码器(VAE)模型。VAE编码器的下采样因子为8——给定形状为256×256×3的RGBimex,z=E(x)的形状为32×32×4。在本节的所有实验中,我们的扩散模型都在这个Z空间中运行。 从我们的扩散模型中采样一个新的latent,我们使用VAE解码器x=D(z)将其解码为像素。我们从ADM保留扩散超参数;具体而言,我们使用tmax=1000 linear variance schedule ranging from 1×10−4到2×10‐2,ADM的协方差参数化∑θ及其嵌入输入timestep and label的方法。

ADM:Diffusion models beat gans on image synthesis. In NeurIPS, 2021

DiT block design. 我们训练了四种最高Gflop的DiT XL/2模型,每种模型都使用不同的块设计——in-context(119.4 Gflops)、cross-attention(137.6 Gflop)、adaptive layer norm(adaLN,118.6 Gflops”)或adaLN-zero(118.6 Gflups)。我们在培训过程中测量FID。

Scaling model size and patch size.  在所有四个配置中,通过使Transformer更深更宽,FID在training的各个阶段都得到了显著改善。同样,图6(底部)显示了当补丁大小减小而模型大小保持不变时的FID。我们再次观察到,在整个训练过程中,通过简单地缩放DiT处理的token数量,保持参数大致固定,FID得到了显著改善。

DiT Gflops are critical to improving performance. 这些结果表明,缩放模型Gflops实际上是提高性能的关键。

Larger DiT models are more compute-efficient. 我们发现,与训练步骤较少的较大DiT模型相比,即使训练时间较长,小DiT模型最终也会变得计算效率低下。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Jeremy-Sky

你的鼓励是我的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值