深度学习之DAT

DAT是一种结合了可变形卷积的Vision Transformer模型,旨在解决Transformer在计算量和注意力聚焦上的问题。通过可变形注意力机制,DAT能够专注于图像中的重要区域,减少计算需求,同时保持对全局相关性的捕捉。DAT在目标检测任务上表现出色,通过在Swin-T基础上堆叠可变形注意力层,构建了一个金字塔网络结构。

在这里插入图片描述

这篇文章是2022年Vision-Transformer领域的CVPR论文。文章推出了一种新的Vision Transformer模型——Deformable Attention Transformer(DAT)。DAT将DCN运用到Transformer中,从而将注意力的运算集中在重要性区域上,为目标检测带来了一定程度的性能提升。DAT最大的价值在于其使得自注意力层可以聚焦于相关区域来捕获信息

参考目录:
源码
CVPR 2022 | 清华开源DAT:具有可变形注意力的视觉Transformer
DAT论文解读(暖风)

Abstract

  1. 相比CNN,Vision Transformer具有更大的感受野或者说具有建模长距离范围内相关性的能力,这也就为其产生更高表现力奠定了基础。但是目前的Vision Transformer仍然暴露出2大问题:①使用稠密注意力机制的模型,如ViT将会产生较大的计算消耗以及其无法聚焦于重要区域,要知道一幅图像中并不是所有的区域信息都是有用的,比如背景等;②稀疏注意力机制的模型,如Swin-T、PVT等不具备长距离建模的能力,其更适合于一定窗口内的相关性捕捉。
  2. 为了解决上述2大问题,作者推出了Deformable Self-Attention机制,即可变形自注意力。该注意力机制利用可变形卷积技术产生offset,并通过可导采样产生sampled-K和sampled-V,之后再进行常规的MHSA从而可以使得注意力计算集中于和 Q Q Q相关的区域。需要注意的是,DA可以产生较小的K和V,并且DCN天然可学习重要区域的特性可帮助改模型减少计算量以及避免不相关token的干扰;此外DA仍然可以捕捉全局相关性,只不过其捕捉的范围被合理缩小了(DCN)
  3. 利用Swin-T、DA进行堆叠,从而就产生了可变形注意力模型——DAT

Note:

  1. PVT、Swin-T属于数据不可知型模型(data-agnostic):当前窗口内的 Q Q Q无法在全局范围内搜索,他无法知道窗口外面的世界。
  2. DAT属于数据依赖型(data-dependent)模型:这是因为其 K , V K,V K,V的产生是依赖于 Q Q Q的,即 Q Q Q产生了 K , V K,V K,V。它的搜索范围也是全局,故其不属于数据不可知模型。
  3. 关于DCN如何聚焦于重要区域,这是分类等任务的loss驱动产生的,loss驱动去学习合适的offset,从而让卷积具有合理的采样区域。下图分别展示了DCN对采样区域缩小、变大、变大的作用:
    在这里插入图片描述
    有关DCN的知识可参考我的2篇文章:①深度学习之DCN;②深度学习之DCN-v2
  4. DCN本身和视频超分中各有不同作用,即DCN和超分中的DCN用法是不一样的,前者聚焦于重要区域,后者更像flow,侧重于对齐。
  5. 可变形注意力的本质:利用DCN的offset聚焦到目标的核心重要位置,然后进行放大,而对于背景区域则会缩小。放大后的重要目标会拿来做注意力,从而凸显了重要区域(或者说informative、relavant)的重要性。

1 Introduction

Vision Transformer是一把双刃剑:

  1. 优点:具有较大的感受野以及具备捕捉较长范围内相关性的能力。
  2. 缺点:由于搜索范围往往很大,所以相似度计算量较大,需要较大的计算资源。

Swin-T && PVT \colorbox{lightskyblue}{Swin-T \&\& PVT} Swin-T && PVT
为了减少Vision Transformer的计算量,Swin-T和PVT各自发表了自己的模型:Swin-T基于窗口的局部相关性建模将计算量缩小到了窗口而非全图;PVT则是将Key和Value下采样从而节约计算资源。这些手工设计的注意力模式虽然可以降低相似度计算次数,但是两者均属于数据不可知模型——显然这不是最优的方式。这是因为你可能丢弃了一些较好的token,而只在一些次优的token里来回计算,这样一定会陷入局部最优的。

因此下一步的发展就变成了如何既能降低计算量又可以使得 Q Q Q可以搜寻到最佳匹配的token


可变形卷积(DCN)具有可聚焦采样于重要区域,这重要区域的选择是loss驱动的结果,比如一个 3 × 3 3\times 3 3×3的卷积由于感受野不够产生0.1的loss,为了降低loss,DCN的offset就会想办法往外扩张使得新的卷积可以有更大的感受野从而将采样区域聚集于更好的地方。受启发于这样的思想,DAT作者将可变形卷积迁移到Vision-Transformer里来。
这里面有loss的驱动才是DCN有效的关键,关于Conv和DCN的对比如下:在这里插入图片描述


Deformable Attention Transformer \colorbox{tomato}{Deformable Attention Transformer} Deformable Attention Transformer
本文的核心就是可变形注意力——DA,以DA为核心的金字塔网络结构(分辨率逐渐降低)就组成了DAT模型,如下图所示:
在这里插入图片描述
输入是 224 × 224 224\times 224 224×224的图像,经过4个stage,最后输出特征向量进行分类等任务。
①S1:利用卷积使分辨率下降 1 4 \frac{1}{4} 41 56 × 56 56\times 56 56×56;特征通道不变;LA+Shift-LA,即Swin-T。
②S2:利用卷积使分辨率下降 1 2 \frac{1}{2} 21 28 × 28 28\times 28 28×28;特征通道翻倍;同样是一个Swin-T结构
③S3:利用卷积使分辨率下降 1 2 \frac{1}{2} 21 14 × 14 14\times 14 14×14;特征通道翻倍;LA+DA。
④S4:利用卷积使分辨率下降 1 2 \frac{1}{2} 21 7 × 7 7\times 7 7×7;特征通道翻倍;LA+DA。

Note:

  1. 这种卷积型的patch embedding(即有重叠区域)比无重叠区域的卷积性能要高出 0.5 % − 1 % 0.5\%-1\% 0.5%1%在这里插入图片描述
    但并非所有想要利用DA的任务都适合于卷积型patch embedding,尤其是需要细节恢复的任务要格外注意这一点。现在许多涉及Vision-Transformer的模型获取token的方式都是利用卷积来做的( s t r i d e < c o n v _ s i z e stride < conv\_size stride<conv_size),比如VRT;还有一部分利用Unfold-fold来做;剩余的都像DAT-baseline或者ViT取不重叠的token。
  2. 总的来说,DAT中分为3个模块:L、D、S,分别为局部local注意力、可变形注意力、shift注意力。根据源码来看,其中L、S都是swin-T的一部分。
  3. 卷积利用窗口内的局部相关性提取特征;Transformer利用全局(局部)相关性提取特征。
  4. 上下采样之后记得LN(层归一化)。

首先简单描述下DA的inference,具体的见第三节:
在这里插入图片描述
输入 x x x是token结构(源码是四维张量),首先要经过全连接层输出 Q Q Q;然后利用 Q Q Q通过轻量级网络学习offset(源码里只学了1对方向,类同于光流)以及产生和offset相同大小的网格点坐标reference points;接着利用offset、reference points、 x x x通过双线性插值输出sampled-K和sampled-V。最后 Q , s a m p l e d − K , s a m p l e d − V Q,sampled-K,sampled-V Q,sampledK,sampledV三者之间通过MHSA输出 z z z

Note:

  1. ★★★由于 K , V K,V K,V都是由offset产生的,因此其在loss的驱动下会聚焦于图像的重要区域;且由于它两都是反向采样生成的,因此重要区域往往会被放大,背景区域往往会缩小,从而展现了突出重要区域,弱化次要区域的视觉效果;具体如下图所示:在这里插入图片描述

  2. 这些集中的regions由offset网络从Query中学习到的多组Deformable sampling点确定。采用双线性插值对特征映射中的特征进行采样,然后将采样后的特征输入key投影得到Deformable Key。

  3. 坐标格点归一化为 [ − 1 , 1 ] [-1, 1] [1,1],flow(offset)也同样需要归一化为 [ − 1 , 1 ] [-1, 1] [1,1]之间;光流是图像坐标,而非矩阵坐标, f l o w ∈ R H × W × 2 flow\in\mathbb{R}^{H\times W\times 2} flowRH×W×2,这个2里面,默认[0]为x,[1]为y。比如TTVSR中的位图是按图像坐标来的,而非矩阵坐标。

  4. Offset的源码如下:

ksizes = [9, 7, 5, 3]
kk = ksizes[stage_idx]  # DA只有ksizes[2],ksizes[3]用到
# stride=1
self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, kk//2, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

DAT的贡献:

  1. 这是首次将可变形注意力机制用于目标分类、目标识别任务。
  2. 在ImageNet、ADE20K、COCO数据集上的展现了非常不错的表现力!

2 Related Work

3 Deformable Attention Transformer

3.1 Preliminaries

首先回顾下经典的Vision-Transformer中的注意力模块。
设输入为shape为 ( B , N , C ) (B, N, C) (B,N,C),则对于每一张图片 x ∈ R N × C x\in\mathbb{R}^{N\times C} xRN×C,其MHSA模块的数学表达式如下(设head的总数为 M M M):
q = x W q , k = x W k , v = x W v . (1) q = x W_q, k=x W_k, v = x W_v.\tag{1} q=xWq,k=xWk,v=xWv.(1) z ( m ) = σ ( q ( m ) k ( m ) T d ) v ( m ) , m = 1 , ⋯   , M . (2) z^{(m)} = \sigma(\frac{q^{(m)} {k^{(m)}}^T}{\sqrt{d}}) v^{(m)}, m=1,\cdots,M.\tag{2} z(m)=σ(d q(m)k(m)T)v(m),m=1,,M.(2) z = C o n c a t ( z ( 1 ) , ⋯   , z ( M ) ) . (3) z = Concat(z^{(1)}, \cdots, z^{(M)}).\tag{3} z=Concat(z(1),,z(M)).(3)其中 W q , W k , W v ∈ R C × C W_q,W_k, W_v \in\mathbb{R}^{C\times C} Wq,Wk,WvRC×C q ( m ) ∈ R N × d q^{(m)}\in\mathbb{R}^{N\times d} q(m)RN×d d = C / M d=C/M d=C/M z ∈ R N × C z\in\mathbb{R}^{N\times C} zRN×C为最后输出的新token; σ ( ⋅ ) \sigma(\cdot) σ()为softmax函数。

M L P ( ⋅ ) MLP(\cdot) MLP()为前馈神经网络FFN,通常由2个线性层中间夹一个GELU非线性层组成;故整个Vision Transformer的Encoder部分如下:
z l ′ = M H S A ( L N ( z l − 1 ) ) + z l − 1 . (4) z'_l = MHSA(LN(z_{l-1})) + z_{l-1}.\tag{4} zl=MHSA(LN(zl1))+zl1.(4) z l = M L P ( L N ( L N ( z l ′ ) ) + z l ′ ) . (5) z_l = MLP(LN(LN(z'_l)) + z'_l).\tag{5} zl=MLP(LN(LN(zl))+zl).(5)

3.2 Deformable Attention

现有的层级Transformer模型都存在一定的缺陷,诸如PVT利用下采样技术会造成一定的信息丢失;Swin-T基于窗口的注意力机制使得感受野收到限制,虽然会有增加感受野的操作但是作者指出这种方式增长太慢。此外为了避免计算量过大,故要设计出一种sparse attention,因此Deformable Attention应运而生。


首先给出DA的两大优势,然后接下去分析为何DA具有这些优点:

  1. 降低计算量:不通过下采样的方式学习出 H r × W r \frac{H}{r} \times \frac{W}{r} rH×rW尺寸的Key和Value,从而降低了 Q Q Q K K K之间的相似度计算量。
  2. 利用可变形卷积机制将 K , V K,V K,V集中在重要性区域,从而让注意力计算聚焦在informative区域。

对于shape为 H × W × C H\times W\times C H×W×C的feature map,用 3 × 3 3\times 3 3×3的卷积核会造成计算复杂度为 9 H W C 9HWC 9HWC,如果将DCN直接用于Transformer会产生 N q N k C N_q N_k C NqNkC的计算复杂度。Deformable-DETR采取直接将 N k = 4 N_k=4 Nk=4的做法虽然会节省很多计算量,但会损失很多信息。GCNetDeepViT均指出不同的Query具有类似的注意力(这个应该指的是一个范围内),因此我们可以设计出共享移动的Key和Value,即让一个范围内的Query共享同一个Key和Value。由于DA种Key和Value通过offset产生,因此才说是“移动”的。

Note:

  1. 这种做法是一种在表现力(准确度)和计算效率之间的trade-off。
  2. DA下采样的是offset,而不是feature map
  3. DA的Key和Value是根据offset采样形成的,称之为sampled-key(value)或者deformable-key(value),其1个向量代表着1个区域的信息。这个区域的大小为 s × s s\times s s×s s s s是offset的最大偏移值。具体示意图如下:在这里插入图片描述

其中粉红色线代表着常规DA(offset_n=1)的做法(即不对offset下采样——DCN做法);蓝色线代表DA中利用可变形卷积的做法。我们可以借用DCN来理解DA是如何使用1个向量来代表1个区域的(如黑色线所示)。这样一来不仅节省了计算量又不会像直接对feature map降采样那样损失太多信息;此外DA实现了多个Query共享1个Key的场面,你可以这样理解,假如offset只有1个,那么最终就是所有 Q Q Q和同一个 K K K做attention。至于为何“火柴人”会被放大,这是因为可变形机制会聚焦重要区域,通过反向采样或者可变形卷积就可以实现放大的效果。
4. DA只是借用了DCN中的可变形机制,而之后的并没有根据offset做卷积,而是将offset看成flow来做反向采样实现warp。根据Understanding Deformable Alignment in VSR一文,我们可以增加offset网络的个数来实现多方向的offset。

Deformable attention module \colorbox{hotpink}{Deformable attention module} Deformable attention module
明白了DA的机制后,我们来分析DA的结构。DA的pipeline如下图所示:

在这里插入图片描述
设输入为 x ∈ R H × W × C x\in\mathbb{R}^{H\times W\times C} xRH×W×C(这也是token的一种形式)。
①首先使用线性层或者卷积层 W q W_q Wq输出空间尺寸不变的Query;
Q Q Q经过offset网络输出offset并归一化和tanh()处理—— Δ p = θ o f f s e t ( q ) \Delta p = \theta_{offset}(q) Δp=θoffset(q),这个网络会将空间尺寸缩小为原来的 1 r \frac{1}{r} r1,记为 H G 、 W G H_G、W_G HGWG,并用这个尺寸输出网格点坐标inference points并进行归一化至 [ − 1 , 1 ] [-1, 1] [1,1],其属于图像坐标(即宽度坐标在前,高度坐标在后);
③利用 o f f s e t + r e f e r e n c e offset+reference offset+reference在输入 x x x上进行可导采样(双线性插值)——F.grid_sample()输出 x ~ \tilde{x} x~
④对 x ~ \tilde{x} x~使用2个线性层 W k , W v W_k,W_v Wk,Wv输出 k ~ , v ~ \tilde{k},\tilde{v} k~,v~
⑤最后利用MHSA来输出最后的新token并进行投影 W 0 W_0 W0的结果—— z z z

上述过程的具体数学表达式如下:
q = x W q , k ~ = x ~ W k , v ~ = x ~ W v . (6) q = x W_q, \tilde{k} = \tilde{x} W_k, \tilde{v} = \tilde{x}W_v.\tag{6} q=xWq,k~=x~Wk,v~=x~Wv.(6) Δ p = θ o f f s e t ( q ) , x ~ = ϕ ( x ; p + Δ p ) . (7) \Delta p = \theta_{offset}(q), \tilde{x} = \phi(x;p+\Delta p).\tag{7} Δp=θoffset(q),x~=ϕ(x;p+Δp).(7)
ϕ ( ⋅ ) \phi(\cdot) ϕ()算子是双线性插值函数:
ϕ ( z ; ( p x , p y ) ) = ∑ ( r x , r y ) g ( p x , r x ) g ( p y , r y ) z [ r y , r x , : ] . (8) \phi(z;(p_x, p_y)) = \sum_{(r_x, r_y)} g(p_x, r_x) g(p_y, r_y) z[r_y, r_x, :] .\tag{8} ϕ(z;(px,py))=(rx,ry)g(px,rx)g(py,ry)z[ry,rx,:].(8)其中 g ( a , b ) = m a x ( 0 , 1 − ∣ a − b ∣ ) g(a,b) = max(0, 1- |a-b|) g(a,b)=max(0,1ab)
MHSA过程如下:
z ( m ) = σ ( q ( m ) k ~ ( m ) T d + ϕ ( B ^ ; R ) ) v ~ ( m ) . (9) z^{(m)} = \sigma(\frac{q^{(m)} {\tilde{k}^{(m)}}^T}{\sqrt{d}} + \phi(\hat{B};R)) \tilde{v}^{(m)}.\tag{9} z(m)=σ(d q(m)k~(m)T+ϕ(B^;R))v~(m).(9)其中 ϕ ( B ^ ; R ) ∈ R H W × H G W G \phi(\hat{B};R)\in\mathbb{R}^{HW\times H_GW_G} ϕ(B^;R)RHW×HGWG属于相对位置编码(RPE),其基本遵循Swin-T提出的RPE原理并作出恰当的修改; B ^ 、 R \hat{B}、R B^R分别表示相对位置编码表和相对位置索引。

Offset generation \colorbox{lightskyblue}{Offset generation} Offset generation
这个不多说了,就是一个轻量级网络 θ o f f s e t \theta_{offset} θoffset,用于产生offset的:
在这里插入图片描述
可以看出:

  1. 该网络需要将输出的空间尺寸下降 1 r \frac{1}{r} r1
  2. 最后的输出通道维或者特征维度为2,即DA将offset当作optical flow处理。

Offset groups \colorbox{gold}{Offset groups} Offset groups
这个其实就是DCN中的offset分组,比如分组 G = 4 , h e a d = 16 G=4,head=16 G=4,head=16,那么每4个head共享1个offset。
Note:

  1. G = 4 G=4 G=4代表每4个通道学习一组offset,比如当你的 C = 128 C=128 C=128,那么一共要学习4组offset,每 128 / 4 = 32 128/4=32 128/4=32个通道共同学习一组offset。

Deformable relative position bias \colorbox{lightseagreen}{Deformable relative position bias} Deformable relative position bias
位置编码主要分为绝对位置编码APE和相对位置编码RPE两种,相对而言后者更好,因为其可以表示token之间的相对位置关系。现在基本上需要用到Vision Transformer的任务都会加入位置编码,因为位置编码提供了另一种空间信息,诸如VRT中MSA使用相对位置编码,其和Swin-T几乎一样,只是多加时间维度;MMA使用绝对位置编码中的余弦位置编码。
Note:

  1. Swin-T中启示我们一般是使用相对位置编码后再进行shift-mask。VRT的缺陷在于其无法在空间上捕捉叫长距离的信息。

1、相对位置编码
对于输入shape为 R H × W 的图像, \mathbb{R}^{H\times W}的图像, RH×W的图像,相对位置编码包括2部分:①一个是相对位置编码查找表 B ^ ∈ R ( 2 H − 1 ) × ( 2 W − 1 ) \hat{B}\in\mathbb{R}^{(2H-1)\times (2W-1)} B^R(2H1)×(2W1),这是因为 x , y x,y x,y方向的位置偏差都在 [ − H , H ] , [ − W , W ] [-H, H],[-W, W] [H,H],[W,W]之内;这个表是可训练的,一般初始化为高斯分布。②:相对位置索引,这个张量是个固定值,一般都会存到buffer里,它记录了所有 Q Q Q K K K的相对位置偏差displacement,关于这个索引的制作源码如下(参考Swin-T,VRT):

window_size = [3, 3]
coords_h = torch.arange(window_size[0])  # tensor([0,1,2])
coords_w = torch.arange(window_size[1])  # tensor([0,1,2])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww    # 生成网格3*3个x,3*3个y坐标
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww    # 展开  49个x,49个y坐标
#print(f'coo is {coords_flatten}')
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
#(2,9,1)-(2,1,9)  得到每个格子和49个格子的相对位置
#print(f'relative_coords is \n{relative_coords}')
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2    # 换维
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0    # x都加2,变成非负
relative_coords[:, :, 1] += window_size[1] - 1  # y都加2,变成非负
#print(f'relative_coords is \n{relative_coords}')
relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # 乘5
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww   # 最后一维相加,生成0-24的数

上述代码可以自己跑着试一试,最后的输出结果如下(假设 3 × 3 3\times 3 3×3窗口):
在这里插入图片描述

解释一下这张相对位置索引为何是这样的:

  1. 首先每一个格点都代表了token和token之间的相对位置关系,比如右上角的“0”就是 Q Q Q中第一个token和 K K K中最后一个token之间的相对位置关系。
  2. 其值是这么来的:横向有5种位置关系: [ − 2 , − 1 , 0 , 1 , 2 ] [-2, -1, 0, 1, 2] [2,1,0,1,2];纵向也有5种位置关系 [ − 2 , − 1 , 0 , 1 , 2 ] [-2, -1, 0, 1, 2] [2,1,0,1,2],这就解释了①②为何是这样的关系。那么排列组合共有25种位置关系(包括了2个方向的12种,和自己和自己的1种):因此最远的距离在于左上和右下,其距离一定大于等于12,故有 ( 1 + 1 + x + x ) ≥ 12 (1+1+x+x)\ge12 (1+1+x+x)12,因此 x ≥ 5 x\ge 5 x5,这就解释了为何③④为何是这样的关系。

关于相对位置编码relative position embedding(RPE),DAT参照Swin-T中相对位置编码的做法,但做了一定的改动。因为在DA中有offset在,这是个天然的displacement,而且其描述的相对位置关系是连续的值,比上面的离散值要更好;相对位置编码表还是不变。DA中的相对位置编码最后需要通过F.grid_sample()来采样获取,而Swin-T中最后的相对位置编码值是通过用相对位置索引值在相对位置表中索引得到的
接下来我们分析下DA种的RPE,源码如下:

rpe_table = self.rpe_table
rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
q_grid = self._get_ref_points(H, W, B, dtype, device)

 # 偏移:q和k的相对位置偏移
displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)

attn_bias = F.grid_sample(
				input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1), 
				grid=displacement[..., (1, 0)], 
				mode='bilinear', 
				align_corners=True) # B * g, h_g, HW, Ns

attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample
attn = attn + attn_bias

我们来看displacement即相对位置索引是什么样的(只研究 x x x方向):
我们假设 q _ g r i d q\_grid q_grid
在这里插入图片描述
p o s = o f f s e t + r e f e r e n c e pos=offset+reference pos=offset+reference
在这里插入图片描述
因此根据F.grid_sample()所得 K K K的位置信息大致如下:
在这里插入图片描述
Q Q Q K K K的相对位置为:
在这里插入图片描述

Note:

  1. 接下来只要拿着relative_position_index去table里找值就可以了,显然相同的值所获取的table值也是一样的。找到的值用来和attn相加,由于这个值也是需要训练的,所以其需要初始化(代码中使用高斯分布)以及参数更新。
  2. 值不重要,重要的是值与值之间的关系。
  3. RPE的一个缺陷是其table制约于token的个数,这样就会造成训练和测试的时候因为分辨率不同而无法使用,比如超分任务。
  4. DAT的相对位置嵌入进行了略微的修改,主要利用offset来表示相对位置关系。同样都不是自注意力,VRT无法找到DAT这种利用offset来表示相对位置的方法,而退而求其次使用余弦位置编码。
  5. 这里额外补充下shift_mask的做法(参考Swin-T、VRT):参照swin-T的做法,需要对shift之后超出图像部分区域增加mask,我这里取名为shifting-mask感觉更为贴切,源码如下:
img_mask = torch.zeros(*self.fmap_size)  # H W
h_slices = (slice(0, -self.window_size[0]),
            slice(-self.window_size[0], -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size[1]),
            slice(-self.window_size[1], -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
   for w in w_slices:
       img_mask[h, w] = cnt
       cnt += 1
mask_windows = einops.rearrange(img_mask, '(r1 h1) (r2 w1) -> (r1 r2) (h1 w1)',
                                              h1=self.window_size[0],w1=self.window_size[1])
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW ww ww
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
self.register_buffer("attn_mask", attn_mask)

### 然和和attn进行相加
if mask is not None:
    # attn : (b * nW) h w w
    # mask : nW ww ww
    nW, ww, _ = mask.size()
    attn = einops.rearrange(attn, '(b n) h w1 w2 -> b n h w1 w2', n=nW, h=self.heads, w1=ww, w2=ww) + mask.reshape(1, nW, 1, ww, ww)
    attn = einops.rearrange(attn, 'b n h w1 w2 -> (b n) h w1 w2')
attn = self.attn_drop(attn.softmax(dim=3))

2、绝对位置编码

关于位置编码,绝对位置主要是利用余弦函数来做,具体参考绝对位置编码位置编码综述
余弦位置编码源码如下(参考VRT):

def get_sine_position_encoding(self, HW=(6, 12), num_pos_feats=128/2, temperature=10000, normalize=False, scale=None):
    # 128、2,之所以除以2是因为有x,y两个方向
    # 余弦位置编码公式中的d_model=64
    if scale is not None and normalize is False:
        raise ValueError("normalize should be True if scale is passed")
    if scale is None:
        scale = 2 * math.pi
    not_mask = torch.ones([1, HW[0], HW[1]])
    # 符合图像坐标
    y_embed = not_mask.cumsum(1, dtype=torch.float32)  # 1 6 12
    x_embed = not_mask.cumsum(2, dtype=torch.float32)
    
    if normalize:
        eps = 1e-6
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale
        
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)  # 64
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)  # //是关键
    
    pos_x = x_embed[:, :, :, None] / dim_t  # 1 6 12 64
    pos_y = y_embed[:, :, :, None] / dim_t
    # sin、cos各自要计算32(64/2)次
    pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)  # 1 6 12 64
    pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # 1 128 6 12
    return pos_embed.flatten(2).permute(0, 2, 1).contiguous()  # B N C = (1, 72, 128)

3.3 Model Architectures

在这里插入图片描述

  1. 上图所示就是DAT的pipeline,它整体呈现一个金字塔结构——特征图的分辨率从浅到深依次从大到小,但特征维度逐渐增大。这种下采样的方式有点类似于VRT、Swin-T。
  2. DAT的下采样都是通过卷积来做的,大致分为2个版本:①(Baseline)一个是非重叠的类似于ViT;②另一个是重叠的patch embedding。后者可以增加局部性从而提升性能。
  3. 下图展示了DAT的3种不同参数量的模型:
    在这里插入图片描述
    我们分析下Baseline:①首先 224 × 224 224\times 224 224×224的图像经过 4 × 4 , s = 4 4\times 4,s=4 4×4,s=4的非重叠卷积产生 56 × 56 × 128 56\times 56\times 128 56×56×128的图像,接下去通过1个Swin-T模块;②接下去使用 2 × 2 , s = 2 2\times 2,s=2 2×2,s=2的非重叠卷积继续下采样产生 28 × 28 × 256 28\times 28\times 256 28×28×256的图像,并通过1个Swin-T模块;③然后同样使用 2 × 2 , s = 2 2\times 2,s=2 2×2,s=2的非重叠卷积继续下采样产生 14 × 14 × 512 14\times 14\times 512 14×14×512的图像,并通过9个不含shift的Swin-T以及DA模块;④最后使用 2 × 2 , s = 2 2\times 2,s=2 2×2,s=2的非重叠卷积继续下采样产生 14 × 14 × 512 14\times 14\times 512 14×14×512的图像,并通过1个不含shift的Swin-T以及DA模块;⑤根据不同类型任务做出最后几层的设计,比如分类的话就做个FC和softmax等。可以看出DAT的设计基于金字塔结构,体现出空间尺寸下降,特征维度上升的特性
  4. 这种设计的好处在于:①使用Swin-T聚集局部信息,然后通过可以建模全局感受野的DAT来捕捉具有局部增强特性token之间的相关性;②前面使用Swin-T除了可以先捕捉局部相关性以外,还可以减轻计算量,因为DA属于全局范围内的注意力机制,如果将其放在分辨率较大的前面几层那么势必带来较大的计算复杂度。

4 Experiments

略。
(详细可参考CVPR 2022 | 清华开源DAT:具有可变形注意力的视觉Transformer一文)

5 Conclusion

  1. 文章推出了一种新的多层级Vision-Transformer模型——由Swin-T、DA组成的金字塔结构DAT。其中Swin-T负责收集空间上的局部信息,DA负责在增强的局部token上建立全局的相关性
  2. DA是DAT的核心,其利用可变形机制将搜索区域集聚在重要关键这种相关区域;此外通过对offset进行下采样而非feature map本身来降低计算复杂度并不丢失细节信息——这是通过多个 Q u e r y Query Query共享同一个 K e y Key Key获取的。
  3. 和一般的Vision Transformer不同,DA的 K e y 、 V a l u e Key、Value KeyValue通过采样获取的,即sample-Ksampled-V
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值