位置编码源码分析

该文详细介绍了位置编码在自注意力机制中的作用,其目的是为了解决并行计算时丢失序列顺序信息的问题。通过正弦和余弦函数生成的固定位置编码,注入绝对或相对位置信息。文章还探讨了如何利用矩阵变换表示两个位置之间的相对位置,并提供了相应的代码实现。位置编码对于捕获序列结构至关重要,特别是在Transformer等模型中。

1. 来源

位置编码的提出是为了解决自注意力因为并行计算而放弃了顺序操作,为了使用序列的顺序信息,我们通过在输入表示中添加位置编码来注入绝对或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。以下是基于正弦函数和余弦函数的固定位置编码

2. 公式

假设输入表示 X ∈ R n × d X\in R^{n \times d} XRn×d包含了一个序列中的n个词元的d维嵌入表示。位置编码使用相同形状的位置嵌入矩阵P 输出 X+P,矩阵的第 i 行,第 2j 列和第 2j列上的元素为:
P i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) (1) P_{i,2j}=\sin(\frac{i}{10000^{2j/d}})\tag1 Pi,2j=sin(100002j/di)(1)
P i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) (2) P_{i,2j+1}=\cos(\frac{i}{10000^{2j/d}})\tag2 Pi,2j+1=cos(100002j/di)(2)

3. 相对位置求解

对于上式我们令 w j = 1 / 1000 0 2 j / d w_j=1/10000^{2j/d} wj=1/100002j/d,那么可以得到如下:
P i , 2 j = sin ⁡ i w j (3) P_{i,2j}=\sin iw_j\tag3 Pi,2j=siniwj(3)
P i , 2 j + 1 = cos ⁡ i w j (4) P_{i,2j+1}=\cos iw_j\tag4 Pi,2j+1=cosiwj(4)
假设有一个点坐标为 A点 ( P i , 2 j , P i , 2 j + 1 ) (P_{i,2j},P_{i,2j+1}) (Pi,2j,Pi,2j+1),距离A点距离为 σ \sigma σ有B点 ( P i + σ , 2 j , P i + σ , 2 j + 1 ) (P_{i+\sigma,2j},P_{i+\sigma,2j+1}) (Pi+σ,2j,Pi+σ,2j+1),他们之间的转换如下
[ cos ⁡ ( σ w j ) sin ⁡ ( σ w j ) − sin ⁡ ( σ w j ) cos ⁡ ( σ w j ) ] [ P i , 2 j P i , 2 j + 1 ] (5) \begin{bmatrix}\cos(\sigma{w_j})&\sin(\sigma w_j)\\-\sin(\sigma w_j)&\cos(\sigma w_j)\end{bmatrix}\begin{bmatrix}P_{i,2j}\\P_{i,2j+1}\end{bmatrix} \tag5 [cos(σwj)sin(σwj)sin(σwj)cos(σwj)][Pi,2jPi,2j+1](5)
= [ cos ⁡ ( σ w j ) P i , 2 j + sin ⁡ ( σ w j ) P i , 2 j + 1 − sin ⁡ ( σ w j ) P i , 2 j + cos ⁡ ( σ w j ) P i , 2 j + 1 ] (6) =\begin{bmatrix}\cos(\sigma{w_j})P_{i,2j}+\sin(\sigma w_j)P_{i,2j+1}\\-\sin(\sigma w_j)P_{i,2j}+\cos(\sigma w_j)P_{i,2j+1}\end{bmatrix}\tag 6 =[cos(σwj)Pi,2j+sin(σwj)Pi,2j+1sin(σwj)Pi,2j+cos(σwj)Pi,2j+1](6)
= [ cos ⁡ ( σ w j ) sin ⁡ i w j + sin ⁡ ( σ w j ) cos ⁡ i w j − sin ⁡ ( σ w j ) sin ⁡ i w j + cos ⁡ ( σ w j ) cos ⁡ i w j ] (7) =\begin{bmatrix}\cos(\sigma{w_j})\sin iw_j+\sin(\sigma w_j)\cos iw_j\\-\sin(\sigma w_j)\sin iw_j+\cos(\sigma w_j)\cos iw_j\end{bmatrix} \tag 7 =[cos(σwj)siniwj+sin(σwj)cosiwjsin(σwj)siniwj+cos(σwj)cosiwj](7)
= [ sin ⁡ ( i w j + σ w j ) cos ⁡ ( i w j + σ w j ) ] (8) =\begin{bmatrix}\sin(iw_j+\sigma w_j)\\\cos(iw_j+\sigma w_j)\end{bmatrix} \tag 8 =[sin(iwj+σwj)cos(iwj+σwj)](8)
= [ sin ⁡ [ ( i + σ ) w j ) ] cos ⁡ [ ( i + σ ) w j ) ] ] (9) =\begin{bmatrix}\sin[(i+\sigma )w_j)]\\\\\cos[(i+\sigma) w_j)]\end{bmatrix}\tag 9 =sin[(i+σ)wj)]cos[(i+σ)wj)](9)
= [ P i + σ , 2 j P i + σ , 2 j + 1 ] (10) =\begin{bmatrix}P_{i+\sigma,2j}\\\\P_{i+\sigma,2j+1}\end{bmatrix} \tag{10} =Pi+σ,2jPi+σ,2j+1(10)

即:可以得到如下关系:
[ cos ⁡ ( σ w j ) sin ⁡ ( σ w j ) − sin ⁡ ( σ w j ) cos ⁡ ( σ w j ) ] [ P i , 2 j P i , 2 j + 1 ] = [ P i + σ , 2 j P i + σ , 2 j + 1 ] (11) \begin{bmatrix}\cos(\sigma{w_j})&\sin(\sigma w_j)\\\\-\sin(\sigma w_j)&\cos(\sigma w_j)\end{bmatrix}\begin{bmatrix}P_{i,2j}\\\\P_{i,2j+1}\end{bmatrix}=\begin{bmatrix}P_{i+\sigma,2j}\\\\P_{i+\sigma,2j+1}\end{bmatrix}\tag{11} cos(σwj)sin(σwj)sin(σwj)cos(σwj)Pi,2jPi,2j+1=Pi+σ,2jPi+σ,2j+1(11)
我们惊奇发现矩阵M是不依赖于i的,矩阵M存储了从点A到点B的相对位置信息;
M = [ cos ⁡ ( σ w j ) sin ⁡ ( σ w j ) − sin ⁡ ( σ w j ) cos ⁡ ( σ w j ) ] (12) M=\begin{bmatrix}\cos(\sigma{w_j})&\sin(\sigma w_j)\\\\-\sin(\sigma w_j)&\cos(\sigma w_j)\end{bmatrix}\tag{12} M=cos(σwj)sin(σwj)sin(σwj)cos(σwj)(12)

4. 代码

固定位置编码

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: positional_test
# @Create time: 2022/2/27 11:26
#@save
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P=(1,1000,32)
        self.P = torch.zeros((1, max_len, num_hiddens))
        # X=(1000,16)
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        # self.P = (1,1000,32)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
# X=(1,60,32); P=(1,60,32)
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

plt.show()

5. 小结

  • 在自注意力中,查询queries,键keys,值values都来自同一输入
  • 为了使用序列的顺序信息,我们可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值