Context Autoencoder for Self-Supervised Representation Learning 用于自监督表示学习的上下文自动编码器
论文:https://arxiv.org/pdf/2202.03026.pdf
code: open-mmlab
摘要:
- 新的 masked image modeling (MIM) 方法:context autoencoder (CAE)
-
pretext task(前置任务):从图像中的可见 patches 估计mask 的 patches
- 与之前结合 encoding 和 pretext task 的MIM方法(例如,BEiT)相比,我们的方法有利于分离表征学习(encoding)和 pretext task
- 解释了为什么对比预训练和监督预训练表现相似,以及为什么MIM可能表现更好。
- 我们通过在下游任务(语义分割、对象检测和实例分割)中证明了CAE的有效性(优于监督预训练、对比预训练和其他MIM方法)
Introduction & Approach
CAE结构如图:
- 潜在的上下文回归(latent contextual regressor): 输入可见部编码、不可见部分queries,输出不可见部分编码(预测)
- 对齐约束(alignment):输入不可见部分编码(预测)、不可见部分编码,输出Loss
- 图中 top stream 中的 encoder 对 visible patches 进行操作,仅专注于学习语义表示。分离!
- 结构特点:在编码表示空间增加了新的约束
- 与BEiT和ViT论文中的方法相比,我们的CAE编码器利用了更大的学习表示的能力,从而提高了表示质量。

不同编码解码结构比较,图2:
(a)上下文自动编码器(CAE):
- 编码器 F 接收 visible patches: Xv 并输出 潜在表征:Zv。
- 潜在上下文回归器 :H 从Zv 预测 mask patches 的 潜在表征: Zm。
- 解码器从 Zm 预测 mask patches 的 目标 Ym。
- Lz 和 Ly是损失函数。
(b)BEiT(Bao等人,2021):
- 输入: visible patches: Xv 和 mask 查询: Qm,
- 并且在 函数 R 内更新它们的表示。
(c)去噪自动编码器(DAE)的计算图 :
- 噪声函数:N 从输入X 生成 噪声版本~X
- F 和 G 分别是普通编码器和解码器。
- 为了简单起见,positional embeddings 不包括在计算图中
比较:CAE 和 DAE显式地分别执行编码和解码;BEiT隐式地同时执行编码和译码

CAE详细解释:
回顾:
- Xv--->编码器 F--->Zv。
- Zv--->潜在上下文回归器 H--->Zm。
- Zm---> 解码器 G --->Ym。
编码器( Xv--->编码器 F--->Zv )
编码器 F 将 visible patches: Xv 映射到 潜在表征:Zv。它只处理 visible patches 。
编码器 F结构:
我们使用ViT来形成我们的 编码器。
它首先通过 linear projection 将 visible patches 编码为 patch embeddings ,并添加 positional embeddings(Pv)。
然后,它将组合的 embeddings 发送到基于 self-attention 的 transformer blocks 序列中,生成 Zv。
潜在的上下文回归(Zv--->潜在上下文回归器 H--->Zm)
潜在上下文回归器H从 Zv 预测 mask patches 的潜在表征:Zm。
潜在上下文回归器H结构:
我们使用一系列基于cross-attention 的 transformer blocks 来形成潜在的上下文回归器H。
被称为 mask queries 的初始查询Qm 是作为模型参数学习的 mask tokens ,并且对于所有 masked patches 都是相同的。
键和值是相同的,并且由visible patches 表示 Zv 和 cross-attention 的输出(第一个 cross-attention 层的Qm)组成。
在计算 queries and keys之间的 cross-attention 权重时,考虑了相应的positional embeddings。
在该过程中,不更新Zv。
对齐约束(Alignment constraint)
潜在表示对齐约束 被施加 Zm上(在由潜在的上下文回归器 预测的 Zm)
我们将 masked patches :Xm 输入编码器 F(该 编码器 与用于编码visible patches 的编码器相同),并生成¯ Zm。
然后,我们将masked patches 的两个潜在表示¯Zm和Zm对齐。
解码器:(Zm---> 解码器 G --->Ym)
解码器G将 Zm 映射到 masked patches 的某些形式Ym,
解码器结构:
其中离散 tokens 作为目标,如在BEiT中所做的。
解码器与编码器类似,是一堆基于self-attention 的 transformer blocks,后面是预测目标的线性层。
解码器仅接 Zm 和 positional embeddings of the masked patches 入作为输入,而不直接使用 visible patches 的信息。
目标函数
Masking and targets
- 继BEiT之后(Bao等人,2021),我们采用了随机分块 masking 策略(如图3所示),将输入图像分割为两组 patches,visible and masked patches。
- 对于每个图像,196个(14×14)patches 中的 98个 masked

- 我们使用pre-trained DALL-E(Ramesh等人,2021)tokenizer 生成离散 tokens 以形成目标。
- 输入图像被送到 DALL-E tokenizer ,为每个 patch 分配一个离散的token 。masked patches 的 target tokens 表示为“¯Ym”
Loss function
损失函数(如图2(a)所示,蓝色部分)包括 decoding loss:Ly(Ym, ¯Ym) 和 alignment loss:Lz(Zm, ¯Zm)。
整个损失是一个加权和:

其中:
1) Lz(Zm,'Zm)的 MSE 损失
2) Ly(Ym,'Ym)的交叉熵损失。
3) sg[·]代表停止梯度。在我们的实验中λ是2。
CODE部分
类似VisionTransformer的编码器部分

backbone
class CAEViT(VisionTransformer):
def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
""" 生成mask图像的特征。
此函数生成 mask 图像并获取可见 patch的特征。
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
mask (torch.Tensor): Mask for input, which is of shape B x L.
Returns:
torch.Tensor: hidden features.
"""
# step1 计算 img 切成几个 patch,假设输入是224*224 切成16*16 共有14*14个
# step2 conv(img)输入是224*224 output channels: 768 [Batch, 14*14, 768]
x, _ = self.patch_embed(img)
batch_size, _, dim = x.size() # dim = 14*14
# 创建可学习的类别标签 [1,1,768]
# self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 创建可学习的位置编码 [1,14*14 +1,768]
# self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+self.num_tokens, embed_dim))
# 权重以正态分布初始化:
# nn.init.trunc_normal_(self.pos_embed, std=0.02)
# nn.init.trunc_normal_(self.cls_token, std=0.02)
# 扩维度 XXX.expand(, , )
# 类别标签 ****************************************
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # batch_size 扩展维度
x_unmasked = x[~mask].reshape(batch_size, -1, dim)
x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1) # [cls] + 可见部分原图 --->shape:[1,14*14 +1,768]
# 位置编码 ****************************************
pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1,dim) # batch_size 扩展维度
pos_embed_unmasked = pos_embed[:,1:][~mask].reshape(batch_size, -1, dim) # 可见patch位置编码
pos_embed_unmasked = torch.cat((pos_embed[:, :1], pos_embed_unmasked),dim=1) # [0]位置 + 可见patch位置编码
x_unmasked = x_unmasked + pos_embed_unmasked # 对应相加 : [cls] + [0]位置, 可见部分原图 + 可见patch位置编码
x_unmasked = self.drop_after_pos(x_unmasked) #--------------nn.Dropout()
# 特征提取 ***********************************************
# layer 是按配置生成的 list ,由TransformerEncoderLayer(MultiheadAttention,norm,ffn)组成,
for i, layer in enumerate(self.layers):
x_unmasked = layer(x_unmasked)
if i == len(self.layers) - 1 and self.final_norm: # 最后一层norm
x_unmasked = self.norm1(x_unmasked)
return x_unmasked
neck部分,潜在上下文回归器+解码器
class CAETransformerRegressorLayer(BaseModule):
"""
该模块不同于传统的 transformer 编码器层
它的查询Q是不可见图的tokens,
keys:不可见图 tokens
values:可见图 tokens
"""
def __init__( ) -> None:
# NOTE: cross attention
_, self.norm1_q_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2)
_, self.norm1_k_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2)
_, self.norm1_v_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2)
_, self.norm2_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2)
self.cross_attn = CrossMultiheadAttention(embed_dims,num_heads=num_heads,...)
self.ffn = FFN(embed_dims=embed_dims,feedforward_channels=feedforward_channels,...)
...
def forward(x_q, x_kv, pos_q, pos_k) -> torch.Tensor:
# x_q 覆盖token
# x_kv 所有token(key value)
# pos_q 覆盖位置编码
# pos_k 所有位置编码
x = x_q + self.drop_path(self.gamma_1_cross *
self.cross_attn(self.norm1_q_cross(x_q + pos_q),
k=self.norm1_k_cross(x_kv + pos_k),
v=self.norm1_v_cross(x_kv)))
x = self.norm2_cross(x) # norm
x = x + self.drop_path(self.gamma_2_cross * self.ffn(x))
return x class
CrossMultiheadAttention(BaseModule):
""" queries 与 [keys ,values] 之间交叉注意力
Attention 是由 queries + [keys ,values] 计算出来的
"""
def __init__(...)
self.q = nn.Linear(embed_dims, embed_dims, bias=False) # 768 ;768
self.k = nn.Linear(embed_dims, embed_dims, bias=False)
self.v = nn.Linear(embed_dims, embed_dims, bias=False)
def forward(self,
x: 覆盖位置编码+覆盖token,
k: 所有位置编码+所有token,
v: 所有token) -> None:
"""Forward function."""
B, N, _ = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
# 准备q,k,v
q = F.linear(input=x, weight=self.q.weight, bias=q_bias) # (B, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias) # (B, N_k, dim)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
# 多头
q = q.reshape(B, N , 1, self.num_heads,-1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, num_heads, N_q, dim)
k = k.reshape(B, N_k, 1, self.num_heads,-1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, num_heads, N_k, dim)
v = v.reshape(B, N_v, 1, self.num_heads,-1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, num_heads, N_v, dim)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class CAENeck(BaseModule): # cae_neck.py
def forward(self, x_unmasked, pos_embed_masked, pos_embed_unmasked) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the latent prediction and final prediction.
Args:
x_unmasked (torch.Tensor): 可见图 tokens.
pos_embed_masked (torch.Tensor): 位置编码 masked tokens.
pos_embed_unmasked (torch.Tensor): 位置编码 unmasked tokens.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Final prediction and latent prediction.
"""
# 权重以正态分布初始化:
# trunc_normal_(self.mask_token, std=0.02)
# trunc_normal_(self.head.weight, std=0.02)
x_masked = self.mask_token.expand(x_unmasked.shape[0],self.mask_token_num, -1) # shape [batchsize, 75, 768] 批量生成token
# regressor 是modellist,由CAETransformerRegressorLayer组成
for regressor in self.regressors:
x_masked = regressor(x_masked, torch.cat([x_unmasked, x_masked], dim=1),pos_embed_masked,torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1))
x_masked = self.norm_regressor(x_masked) # 得到mask部分的预测token
latent_pred = x_masked
# 解码器
x_masked = x_masked + pos_embed_masked # mask预测的token + mask位置编码
for decoder in self.decoders: # TransformerEncoderLayer 的list
x_masked = decoder(x_masked)
logits= self.norm_decoder(x_masked)
return logits, latent_pred
整体流程
class cae # algorithom
def loss()
"""forward 函数 in training.
Args:
inputs (List[torch.Tensor]): The input images.inputs[0] 原图 images.inputs[1] 目标图
data_samples (List[SelfSupDataSample]): All elements required during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# 数据集 config
# { type='RandomResizedCropAndInterpolationWithTwoPic',
# size=224,
# second_size=112 }
# >>> inputs[0].shape = [64, 3, 224, 224]
# >>> inputs[1].shape = [64, 3, 112, 112] 从inputs[0] 直接 cv2.resize 到 second_size
mask = torch.stack([data_sample.mask.value for data_sample in data_samples])
mask = mask.flatten(1).to(torch.bool)
# 编码可见部分
unmasked = self.backbone(inputs[0], mask) # 编码
# get the latent prediction for the masked patches
# 获得不可见部分标签,此标签是经过backbone生成的
with torch.no_grad():
# inputs[0] is the prediction image
latent_target = self.teacher(inputs[0], ~mask)
latent_target = latent_target[:, 1:, :]
self.momentum_update()
# 获取 mask部分 和 可见部分的位置编码
pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1)
pos_embed_masked = pos_embed[:,1:][mask].reshape(inputs[0].shape[0], -1,pos_embed.shape[-1])
pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(inputs[0].shape[0], -1, pos_embed.shape[-1])
# 上下文预测器+解码器
# input:unmasked tokens 、 所有的位置编码
# output: masked_预测分类、masked_预测tokens
logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,pos_embed_unmasked)
logits = logits.view(-1, logits.shape[-1])
# logits_target.shape = [batchsize, 8192, 14, 14]
logits_target = self.target_generator(inputs[1]) # Dall-e 做分类标签
# LOSS*************************
# loss_main = self.loss_cross_entropy(logits, target)
# loss_align = self.loss_mse(latent_pred,latent_target.detach()) * self.lambd
loss_main, loss_align = self.head(logits, logits_target, latent_pred,latent_target, mask)
losses = dict()
losses['loss'] = loss_main + loss_align
losses['main'] = loss_main
losses['align'] = loss_align
return losses
2.Analysis and Connection
2.1 Analysis
观点:CAE编码器 关心 patch representations 。
理由:CAE根据 visible patches 对 masked patches 预测。这要求CAE编码器关心 patches 的表示,而不仅仅是全局表示,以便CAE探索 patches 之间的关系,以进行预测。
实验:有无约束
现象:上:原图 中:有约束重建图 下:无约束重建图(无意义)
结论:对齐约束对于确保在 encoded representation space 中进行预测至关重要。

- 图4提供了从ImageNet-1K验证集中随机采样的几个示例的重建结果。
- 可以看出,我们的方法可以成功地重建图像,这意味着潜在上下文回归器的输入和输出表示在同一空间中。
- 相反,在没有对齐约束的情况下,重建的图像是有噪声的,这表明潜在上下文回归器的输入和输出表示在不同的空间中。
- 结果表明,对齐约束对于确保在 encoded representation space 中进行预测至关重要。
概率公式
MIM问题可以用概率形式表示:在给定条件、visible patches Xv、visible patches 的位置Pv 和 masked patches 的位置Pm的情况下,最大化 masked patches 的预测Ym的概率:P(Ym|Xv,Pv,Pm)。它可以通过引入潜在表示Zm 和 Zv 来解决,假设 Zv 和 Pm(Ym 和 Pv)是条件独立的: ( 潜在表示对齐约束可以写成条件概率 P(Zm| Zm))

有利于将 表征学习 与 pretext task 完全分离
- CAE编码器处理 visible patches ,以提取其表示,而无需预测 masked patches 。
- 潜在上下文回归器不会更新 visible patches 的表示:回归器中 visible patches 的表达是 cross-attention 的 values and keys ;
- 对齐约束期望潜在上下文回归器的输出与编码器输出 在相同的表示空间中。解码器仅处理 masked patches 的预测表示。因此,编码器承担表示学习的责任,并且仅用于表示学习
直观的解释
类比人类:人类能够对遮盖区域中出现的东西以及它们如何根据可见区域出现产生幻觉。我们推测,人类这样做的方式可能与下面的例子类似:假设只有狗头的区域部分可见,而其余部分缺失,则可以

(a)识别出可见区域是关于狗的,
(b)预测狗的其他部分出现的区域,
(c)猜测其他部分是什么样子。
我们的CAE编码器 在某种意义上类似于人类识别步骤(a)
它通过将 visible patches 映射到位于与类别dog1 对应的子空间中的潜在表示来理解内容。
如图:
图象是 t-SNE 可视化(从 ADE20K 中的图像中提取的潜在表示)投影到2D空间
一个类别一种颜色
左:预训练为ViT的CAE(预训练 ImageNet-1K) 右:具有随机权重的ViT
图5所示的2D投影表明,对于不同的类别,潜在的表示在一定程度上是聚集的

2.2 Connection
与自动编码器的关系
>>> 最初的自动编码器(LeCun,1987;Gallinari等人,1987;Hinton&Zemel,1994)
- 由编码器和解码器组成。
- 编码器将输入映射到潜在表示中,解码器从潜在表示重建输入。
>>> 去噪自动编码器(DAE)(Vincent等人,2010)(如图2(c)所示)
- 是自动编码器的一种变体,它通过添加噪声来破坏输入,并且仍然重建未损坏的输入。
>>> 我们的CAE编码器(如图2(a)所示)
- 类似于原始的自动编码器,还包含一个编码器和一个解码器。
- 与编码器和解码器处理整个图像的自动编码器不同,我们的编码器将 patches 的一部分作为输入,而我们的解码器将 patches 的另一部分的估计潜在呈现作为输入。
- 重要的是,CAE引入了一个潜在的上下文回归因子,可以在潜在空间中预测从 visible patches 到 masked patches 的变化。
与BEiT的关系
- BEiT 将 visible patches 和 masked patches(由masked tokens 表示)都输入到基于 self-attention 的ViT中,
- 然后预测离散 patches tokens,其中只有masked patches 的 tokens 计入损失函数。
- BEiT中的ViT同时理解图像内容,并产生masked patches 的假设。
- 没有显式和单独的表示提取模块,这表明ViT网络使用部分能力进行表示学习。相比之下,CAE编码器只用于理解内容,而不用于预测 masked patches 。CAE中的其他部分不会更新 visible patches 的表示。
- 潜在上下文回归器的输出与编码器计算的masked patches 表示 对齐,从而限制表示提取角色仅由编码器承担。这意味着我们的CAE编码器利用了表示学习的全部功能。
对比学习
典型的对比学习方法,例如:
- SimCLR(Chen et al,2020b)
- MoCo(He et al,2020;Chen et all,2021。
相关实验结论:研究表明,在(Chen et al,2020b)中,随机裁剪在对比学习的视图增强中发挥着重要作用。

通过分析随机裁剪(如图3所示),我们观察到原始图像空间中的中心像素有很大的机会属于随机裁剪。我们怀疑,通过对比学习全局表示往往主要集中在原始图像中的中心像素上,因此来自同一图像的不同crops 的表示可能是相似的。图6(第二行)显示,对于典型的对比学习方法MoCo v3,原始图像的中心区域受到了高度关注。
图6:说明了在ImageNet-1K上预训练的ViT编码器的最后一层中,类令牌和补丁令牌之间的12个注意力头的平均注意力图。蓝色轮廓内的区域是通过阈值化注意力权重来获得的,以保持50%的质量。顶部:输入图像,中部:MoCo v3,这是一种典型的对比学习方法,底部:我们的CAE。可以看出,MoCo v3倾向于主要关注中心区域,而很少关注其他补丁,我们的CAE倾向于考虑几乎所有的补丁。

相反,我们的MIM方法CAE从增强视图中随机采样patches,以形成visible patches 和 masked patches。
对于增强视图和相应的原始图像,所有 patches 都可以被 mask。因此,CAE编码器需要学习所有patches的良好表示
图6(第三行)说明了在我们的CAE 编码器中几乎考虑了原始图像中的所有patches 。
考虑到ImageNet-1K中1000个类别的实例主要位于原始图像的中心附近,典型的对比学习方法,例如MoCo v3,主要学习关于1000个类别,这类似于监督预训练。
但我们的CAE和其他MIM方法能够从非中心图像区域学习超过1000个类别的更多知识。这表明CAE有潜力更好地执行下游任务。
Experiments
训练设置
- 我们研究了标准的ViT small、 base、large 架构,即ViT-S(12个 dim 384的 transformer blocks)、ViT-B(12个 dim 768的 transformer blocks ,以及ViTL(24个dim为1024的transformer blocks)
- 潜在上下文回归器 由4个基于 cross-attention 的 transformer blocks 组成
- 解码器由4个 基于self-attention 的 transformer blocks 和一个额外的 linear projection 组成。
- 在ImageNet-1K上培训CAE。
- 我们将224×224的图像划分为14×14个patches ,patches 大小为16×16。
- 我们使用标准的随机裁剪和水平翻转来增加数据。预训练设置与BEiT几乎相同
linear projection : linear projection 被广泛用作自监督表示学习的预训练质量评估。它通过使用图像的标签,在预训练编码器输出的图像级表示上学习线性分类器
cross-attention:图7:cross-attention 单元示意图。attention map(底部)是在额外 class token 和 patches 之间的12个 heads 上的 cross-attention maps 的平均值。可以看出,关注区域主要位于对象,这有助于图像分类。

Results.
表1显示了三种方案的结果,
- 典型对比预处理(MoCo v3和DINO)
- MIM(BEiT和MAE)方法的(linear probing(LIN)、attentive probing(ATT)和微调(FT))
- 我们的方法CAE。
我们使用官方实现对具有300个 epochs 的MAE和BEiT的模型进行了预训练,其他模型是官方发布的模型。
对同一数据集(ImageNet-1K)进行微调对于预训练评估来说不是一个好的方案。
表格解释:
在 linear probing 方面,对比学习方法MoCov3和DINO的得分高于MIM方法。这是意料之中的,因为对比学习主要集中在学习1000个类别的表示(见第3节的讨论)
对于MIM方法,attentive probing 得分为 linear probing 大得多。这验证了我们的分析:MIM方法提取了所有 patches 的表示,并且分类任务需要关注 patches 的相应部分。

消融实验
- decoder and the alignment constraint in CAE.
- pretrained on ImageNet-1K with 300 epochs.
- 下游任务: semantic segmentation on ADE20K and object detection on COCO
- 结论:当仅添加解码器时,下游任务性能几乎不变,并且当同时添加解码器和对准约束时,性能增加。
- 理论:这也验证了对准约束对于确保masked patch 的预测表示位于编码表示空间中, 并且因此在编码表示空间内进行预测是重要的,以及相应地提高表示质量是重要的。

下游任务
1)Semantic segmentation on ADE20K
- ADE20K上的语义分割。
- 使用multi-crop 预训练增强
- †:这些结果来自(He等人,2021)。

2)Object detection and instance segmentation on COCO
- 基于COCO的对象检测和实例分割。
- 采用Mask R-CNN,并按照1×时间表进行训练。
- ImageNet-1K上预训练
- *:使用 multi-crop 训练增强(见表1)。

本文介绍用于自监督表示学习的上下文自动编码器(CAE),它是新的masked image modeling方法,能分离表征学习和pretext task。文中详细阐述CAE结构、目标函数等,通过实验证明其在下游任务(语义分割、对象检测等)中优于监督预训练、对比预训练和其他MIM方法。

918

被折叠的 条评论
为什么被折叠?



