用一个具体到数字的例子理解transformer——“The dog runs fast”这句话在transformer中如何翻译为“狗跑得快”?

假设我们要将英文句子 “The dog runs fast” 翻译成中文,设定词表大小 Vsrc=10000V_{src}=10000Vsrc=10000(源语言词表大小),Vtgt=8000V_{tgt}=8000Vtgt=8000(目标语言词表大小),模型维度 dmodel=512d_{model}=512dmodel=512,序列最大长度 max_lensrc=20max\_len_{src}=20max_lensrc=20(源语言序列最大长度),max_lentgt=20max\_len_{tgt}=20max_lentgt=20(目标语言序列最大长度),头的数量 h=8h = 8h=8,头维度 dk=64d_k = 64dk=64,前馈网络中间维度 dff=2048d_ff = 2048dff=2048

已知的矩阵

  • 词嵌入矩阵
    • 对于源语言和目标语言的词嵌入矩阵(如 Esrc∈R10000×512E_{src} \in \mathbb{R}^{10000\times512}EsrcR10000×512Etgt∈R8000×512E_{tgt} \in \mathbb{R}^{8000\times512}EtgtR8000×512),维度依据词表大小和模型维度确定。这里的512是词向量(词嵌入)维度,不一定是512,比如glove词嵌入就有很多版本。词表维度也是方便计算假设的。
  • 位置编码矩阵
    • 位置编码矩阵(如 PEsrc∈R20×512PE_{src} \in \mathbb{R}^{20\times512}PEsrcR20×512PEtgt∈R20×512PE_{tgt} \in \mathbb{R}^{20\times512}PEtgtR20×512)依据位置编码公式(PE(pos,2i)=sin⁡(pos/100002i/512)PE_{(pos, 2i)} = \sin(pos / 10000^{2i / 512})PE(pos,2i)=sin(pos/100002i/512)PE(pos,2i+1)=cos⁡(pos/100002i/512)PE_{(pos, 2i + 1)} = \cos(pos / 10000^{2i / 512})PE(pos,2i+1)=cos(pos/100002i/512))预先计算得出。

初始化之后慢慢训练得到的矩阵

  • 多头注意力机制中的权重矩阵
    • 在多头注意力机制中,WQ,WK,WV∈R512×64W_Q, W_K, W_V \in \mathbb{R}^{512\times64}WQ,WK,WVR512×64WO∈R512×512W_O \in \mathbb{R}^{512\times512}WOR512×512 这些权重矩阵需初始化,然后在训练过程中借助反向传播不断更新。其初始化通常采用随机初始化方式(如从正态分布或均匀分布中采样),训练时用损失函数计算梯度,用优化器(如 Adam)更新矩阵参数。
  • 前馈神经网络中的权重矩阵
    • 前馈神经网络中的 W1∈R512×2048W_1 \in \mathbb{R}^{512\times2048}W1R512×2048W2∈R2048×512W_2 \in \mathbb{R}^{2048\times512}W2R2048×512 以及偏置向量 b1b_1b1b2b_2b2 同样初始化后经训练得到。在翻译任务重前馈神经网络可以使模型能学习到数据中的语义和语法等复杂模式。
  • 层归一化中的参数
    • 层归一化中的 γ\gammaγβ\betaβ 参数可学习。初始化后在训练过程中调整,以适应数据分布特点,助力模型更好地训练和收敛。在 LN(X)=γX−μσ2+ϵ+βLN(X) = \gamma\frac{X - \mu}{\sqrt{\sigma^2+\epsilon}}+\betaLN(X)=γσ2+ϵXμ+β 中,γ\gammaγβ\betaβ 会依据训练数据的统计特性(均值 μ\muμ 和方差 σ2\sigma^2σ2)以及模型训练目标进行更新。

    开始计算!!!!!!!!!!

1. 词嵌入

  • 源语言词嵌入:源语言词嵌入矩阵 Esrc∈R10000×512E_{src} \in \mathbb{R}^{10000\times512}EsrcR10000×512。假设 “The”、“dog”、“runs”、“fast” 在源语言词表中的索引分别为 100100100200200200300300300400400400。对于单词 “The”,其嵌入向量 x100=Esrc[100,:]∈R512x_{100}=E_{src}[100,:] \in \mathbb{R}^{512}x100=Esrc[100,:]R512,同理可得其他单词的嵌入向量。整个源语言输入序列 Xsrc∈R1×4×512X_{src} \in \mathbb{R}^{1\times4\times512}XsrcR1×4×512(这里假设 batch_size 为 1 )。
  • 目标语言词嵌入:在训练解码器时,目标语言输入序列同样需要词嵌入。目标语言词嵌入矩阵 Etgt∈R8000×512E_{tgt} \in \mathbb{R}^{8000\times512}EtgtR8000×512。假设目标语言序列 “狗 跑 得 快” 的词在词表中索引分别为 505050606060707070808080,其嵌入向量组成目标语言输入序列 Xtgt∈R1×4×512X_{tgt} \in \mathbb{R}^{1\times4\times512}XtgtR1×4×512

2. 位置编码

  • 源语言位置编码:位置编码矩阵 PEsrc∈R20×512PE_{src} \in \mathbb{R}^{20\times512}PEsrcR20×512。位置编码公式为:
    PE(pos,2i)=sin⁡(pos/100002i/512)PE_{(pos, 2i)} = \sin(pos / 10000^{2i / 512})PE(pos,2i)=sin(pos/100002i/512)
    PE(pos,2i+1)=cos⁡(pos/100002i/512)PE_{(pos, 2i + 1)} = \cos(pos / 10000^{2i / 512})PE(pos,2i+1)=cos(pos/100002i/512)
    例如,第一个位置 pos=0pos = 0pos=0 时,PE(0,0)=sin⁡(0)PE_{(0,0)}=\sin(0)PE(0,0)=sin(0)PE(0,1)=cos⁡(0)PE_{(0,1)}=\cos(0)PE(0,1)=cos(0) 。源语言输入序列 XsrcX_{src}Xsrc 与位置编码相加:Xsrc′=Xsrc+PEsrc[:,:4,:]X_{src}' = X_{src} + PE_{src}[:, :4, :]Xsrc=Xsrc+PEsrc[:,:4,:] ,得到 Xsrc′∈R1×4×512X_{src}' \in \mathbb{R}^{1\times4\times512}XsrcR1×4×512
  • 目标语言位置编码:同理,目标语言位置编码矩阵 PEtgt∈R20×512PE_{tgt} \in \mathbb{R}^{20\times512}PEtgtR20×512,目标语言输入序列 XtgtX_{tgt}Xtgt 与位置编码相加:Xtgt′=Xtgt+PEtgt[:,:4,:]X_{tgt}' = X_{tgt} + PE_{tgt}[:, :4, :]Xtgt=Xtgt+PEtgt[:,:4,:] ,得到 Xtgt′∈R1×4×512X_{tgt}' \in \mathbb{R}^{1\times4\times512}XtgtR1×4×512

3. 编码器

多头注意力机制
  • 计算注意力分数(Attention Scores):将输入 Xsrc′X_{src}'Xsrc 线性变换为查询(Query)、键(Key)和值(Value)矩阵:
    Q=Xsrc′WQQ = X_{src}'W_QQ=XsrcWQ, K=Xsrc′WKK = X_{src}'W_KK=XsrcWK, V=Xsrc′WVV = X_{src}'W_VV=XsrcWV
    其中 WQ,WK,WV∈R512×64W_Q, W_K, W_V \in \mathbb{R}^{512\times64}WQ,WK,WVR512×64
    Q∈R1×4×64Q \in \mathbb{R}^{1\times4\times64}QR1×4×64K∈R1×4×64K \in \mathbb{R}^{1\times4\times64}KR1×4×64V∈R1×4×64V \in \mathbb{R}^{1\times4\times64}VR1×4×64
    注意力分数矩阵 AAA 计算如下:
    A=Softmax(QKT64)A = \text{Softmax}(\frac{QK^T}{\sqrt{64}})A=Softmax(64QKT)
    QKT∈R1×4×4QK^T \in \mathbb{R}^{1\times4\times4}QKTR1×4×4,经过Softmax后 A∈R1×4×4A \in \mathbb{R}^{1\times4\times4}AR1×4×4
  • 计算注意力输出(Attention Output):注意力输出 ZZZ 为:
    Z=AVZ = AVZ=AV
    Z∈R1×4×64Z \in \mathbb{R}^{1\times4\times64}ZR1×4×64
  • 多头拼接与线性变换:多头注意力有 h=8h = 8h=8 个头,每个头都进行上述计算,得到 Z1,Z2,⋯ ,Z8Z_1, Z_2,\cdots, Z_8Z1,Z2,,Z8,每个 Zi∈R1×4×64Z_i \in \mathbb{R}^{1\times4\times64}ZiR1×4×64
    拼接后 Zconcat=[Z1;Z2;⋯ ;Z8]∈R1×4×512Z_{concat}=[Z_1; Z_2;\cdots; Z_8] \in \mathbb{R}^{1\times4\times512}Zconcat=[Z1;Z2;;Z8]R1×4×512
    最终多头注意力输出:
    MultiHeadAttention(Xsrc′)=ZconcatWO\text{MultiHeadAttention}(X_{src}') = Z_{concat}W_OMultiHeadAttention(Xsrc)=ZconcatWO
    其中 WO∈R512×512W_O \in \mathbb{R}^{512\times512}WOR512×512,输出 MultiHeadAttention(Xsrc′)∈R1×4×512\text{MultiHeadAttention}(X_{src}') \in \mathbb{R}^{1\times4\times512}MultiHeadAttention(Xsrc)R1×4×512
前馈神经网络

FFN(Z)=ReLU(ZW1+b1)W2+b2FFN(Z) = \text{ReLU}(ZW_1 + b_1)W_2 + b_2FFN(Z)=ReLU(ZW1+b1)W2+b2
其中 W1∈R512×2048W_1 \in \mathbb{R}^{512\times2048}W1R512×2048W2∈R2048×512W_2 \in \mathbb{R}^{2048\times512}W2R2048×512
Z∈R1×4×512Z \in \mathbb{R}^{1\times4\times512}ZR1×4×512ZW1+b1∈R1×4×2048ZW_1 + b_1 \in \mathbb{R}^{1\times4\times2048}ZW1+b1R1×4×2048,经过ReLU后维度不变,再乘以 W2W_2W2 得到 FFN(Z)∈R1×4×512FFN(Z) \in \mathbb{R}^{1\times4\times512}FFN(Z)R1×4×512

层归一化

对多头注意力输出或前馈神经网络输出进行层归一化:
LN(X)=γX−μσ2+ϵ+βLN(X) = \gamma \frac{X - \mu}{\sqrt{\sigma^2 + \epsilon}} + \betaLN(X)=γσ2+ϵXμ+β
μ\muμσ2\sigma^2σ2 是沿着特征维度(这里是 512512512 维)计算的均值和方差,γ\gammaγβ\betaβ 是可学习参数,输出维度与输入 XXX 相同,假设输入 X∈R1×4×512X \in \mathbb{R}^{1\times4\times512}XR1×4×512,输出也为 R1×4×512\mathbb{R}^{1\times4\times512}R1×4×512

经过上述编码器处理后,最终编码器输出 Zencoder∈R1×4×512Z_{encoder} \in \mathbb{R}^{1\times4\times512}ZencoderR1×4×512

4. 解码器

掩码多头注意力
  • 计算注意力分数:先将输入 Xtgt′X_{tgt}'Xtgt 线性变换为查询(Query)、键(Key)和值(Value)矩阵:
    Qmask=Xtgt′WQmaskQ_{mask} = X_{tgt}'W_{Q_{mask}}Qmask=XtgtWQmask, Kmask=Xtgt′WKmaskK_{mask} = X_{tgt}'W_{K_{mask}}Kmask=XtgtWKmask, Vmask=Xtgt′WVmaskV_{mask} = X_{tgt}'W_{V_{mask}}Vmask=XtgtWVmask
    其中 WQmask,WKmask,WVmask∈R512×64W_{Q_{mask}}, W_{K_{mask}}, W_{V_{mask}} \in \mathbb{R}^{512\times64}WQmask,WKmask,WVmaskR512×64
    Qmask∈R1×4×64Q_{mask} \in \mathbb{R}^{1\times4\times64}QmaskR1×4×64Kmask∈R1×4×64K_{mask} \in \mathbb{R}^{1\times4\times64}KmaskR1×4×64Vmask∈R1×4×64V_{mask} \in \mathbb{R}^{1\times4\times64}VmaskR1×4×64

计算注意力分数矩阵 AmaskA_{mask}Amask 时,引入掩码矩阵 MMM
Amask=Softmax(QmaskKmaskT64+M)A_{mask} = \text{Softmax}(\frac{Q_{mask}K_{mask}^T}{\sqrt{64}} + M)Amask=Softmax(64QmaskKmaskT+M)
掩码矩阵 MMM 是一个下三角矩阵(对于长度为 LLL 的序列,Mij=−∞M_{ij} = -\inftyMij= 如果 i>ji > ji>jMij=0M_{ij} = 0Mij=0 如果 i≤ji \leq jij),这里 M∈R4×4M \in \mathbb{R}^{4\times4}MR4×4。这样可以确保在计算注意力分数时,当前位置不会关注到未来位置的信息。经过Softmax后 Amask∈R1×4×4A_{mask} \in \mathbb{R}^{1\times4\times4}AmaskR1×4×4

  • 计算注意力输出:注意力输出 ZmaskZ_{mask}Zmask 为:
    Zmask=AmaskVmaskZ_{mask} = A_{mask}V_{mask}Zmask=AmaskVmask
    Zmask∈R1×4×64Z_{mask} \in \mathbb{R}^{1\times4\times64}ZmaskR1×4×64
  • 多头拼接与线性变换:同样有 h=8h = 8h=8 个头,每个头都进行上述计算,得到 Zmask1,Zmask2,⋯ ,Zmask8Z_{mask1}, Z_{mask2},\cdots, Z_{mask8}Zmask1,Zmask2,,Zmask8,每个 Zmaski∈R1×4×64Z_{maski} \in \mathbb{R}^{1\times4\times64}ZmaskiR1×4×64
    拼接后 Zmask_concat=[Zmask1;Zmask2;⋯ ;Zmask8]∈R1×4×512Z_{mask\_concat}=[Z_{mask1}; Z_{mask2};\cdots; Z_{mask8}] \in \mathbb{R}^{1\times4\times512}Zmask_concat=[Zmask1;Zmask2;;Zmask8]R1×4×512
    经过线性变换:
    MaskedMultiHeadAttention(Xtgt′)=Zmask_concatWOmask\text{MaskedMultiHeadAttention}(X_{tgt}') = Z_{mask\_concat}W_{O_{mask}}MaskedMultiHeadAttention(Xtgt)=Zmask_concatWOmask
    其中 WOmask∈R512×512W_{O_{mask}} \in \mathbb{R}^{512\times512}WOmaskR512×512,输出 MaskedMultiHeadAttention(Xtgt′)∈R1×4×512\text{MaskedMultiHeadAttention}(X_{tgt}') \in \mathbb{R}^{1\times4\times512}MaskedMultiHeadAttention(Xtgt)R1×4×512
编码器 - 解码器注意力
  • 计算注意力分数:将掩码多头注意力的输出 Zmask_outZ_{mask\_out}Zmask_out 线性变换为查询矩阵 Qenc−decQ_{enc - dec}Qencdec,编码器输出 ZencoderZ_{encoder}Zencoder 线性变换为键矩阵 Kenc−decK_{enc - dec}Kencdec 和值矩阵 Venc−decV_{enc - dec}Vencdec
    Qenc−dec=Zmask_outWQenc−decQ_{enc - dec} = Z_{mask\_out}W_{Q_{enc - dec}}Qencdec=Zmask_outWQencdec, Kenc−dec=ZencoderWKenc−decK_{enc - dec} = Z_{encoder}W_{K_{enc - dec}}Kencdec=ZencoderWKencdec, Venc−dec=ZencoderWVenc−decV_{enc - dec} = Z_{encoder}W_{V_{enc - dec}}Vencdec=ZencoderWVencdec
    其中 WQenc−dec,WKenc−dec,WVenc−dec∈R512×64W_{Q_{enc - dec}}, W_{K_{enc - dec}}, W_{V_{enc - dec}} \in \mathbb{R}^{512\times64}WQencdec,WKencdec,WVencdecR512×64
    Qenc−dec∈R1×4×64Q_{enc - dec} \in \mathbb{R}^{1\times4\times64}QencdecR1×4×64Kenc−dec∈R1×4×64K_{enc - dec} \in \mathbb{R}^{1\times4\times64}KencdecR1×4×64Venc−dec∈R1×4×64V_{enc - dec} \in \mathbb{R}^{1\times4\times64}VencdecR1×4×64

计算注意力分数矩阵 Aenc−decA_{enc - dec}Aencdec
Aenc−dec=Softmax(Qenc−decKenc−decT64)A_{enc - dec} = \text{Softmax}(\frac{Q_{enc - dec}K_{enc - dec}^T}{\sqrt{64}})Aencdec=Softmax(64QencdecKencdecT)
Qenc−decKenc−decT∈R1×4×4Q_{enc - dec}K_{enc - dec}^T \in \mathbb{R}^{1\times4\times4}QencdecKencdecTR1×4×4,经过Softmax后 Aenc−dec∈R1×4×4A_{enc - dec} \in \mathbb{R}^{1\times4\times4}AencdecR1×4×4

  • 计算注意力输出(Attention Output):注意力输出 Zenc−decZ_{enc - dec}Zencdec 为:
    Zenc−dec=Aenc−decVenc−decZ_{enc - dec} = A_{enc - dec}V_{enc - dec}Zencdec=AencdecVencdec
    Zenc−dec∈R1×4×64Z_{enc - dec} \in \mathbb{R}^{1\times4\times64}ZencdecR1×4×64
  • 多头拼接与线性变换:多头操作与前面类似,得到 Zenc−dec_concat∈R1×4×512Z_{enc - dec\_concat} \in \mathbb{R}^{1\times4\times512}Zencdec_concatR1×4×512,再经过线性变换:
    EncoderDecoderAttention(Zmask_out,Zencoder)=Zenc−dec_concatWOenc−dec\text{EncoderDecoderAttention}(Z_{mask\_out}, Z_{encoder}) = Z_{enc - dec\_concat}W_{O_{enc - dec}}EncoderDecoderAttention(Zmask_out,Zencoder)=Zencdec_concatWOencdec
    其中 WOenc−dec∈R512×512W_{O_{enc - dec}} \in \mathbb{R}^{512\times512}WOencdecR512×512,输出 EncoderDecoderAttention(Zmask_out,Zencoder)∈R1×4×512\text{EncoderDecoderAttention}(Z_{mask\_out}, Z_{encoder}) \in \mathbb{R}^{1\times4\times512}EncoderDecoderAttention(Zmask_out,Zencoder)R1×4×512
前馈神经网络与层归一化

与编码器中的前馈神经网络和层归一化类似。
前馈神经网络:
FFN(Zenc−dec_out)=ReLU(Zenc−dec_outW1+b1)W2+b2FFN(Z_{enc - dec\_out}) = \text{ReLU}(Z_{enc - dec\_out}W_1 + b_1)W_2 + b_2FFN(Zencdec_out)=ReLU(Zencdec_outW1+b1)W2+b2
其中 W1∈R512×2048W_1 \in \mathbb{R}^{512\times2048}W1R512×2048W2∈R2048×512W_2 \in \mathbb{R}^{2048\times512}W2R2048×512,输出 FFN(Zenc−dec_out)∈R1×4×512FFN(Z_{enc - dec\_out}) \in \mathbb{R}^{1\times4\times512}FFN(Zencdec_out)R1×4×512

然后进行层归一化,得到最终的解码器输出 Outputdecoder∈R1×4×512Output_{decoder} \in \mathbb{R}^{1\times4\times512}OutputdecoderR1×4×512

5. 最终输出

将解码器输出通过线性变换矩阵 Wfinal∈R512×8000W_{final} \in \mathbb{R}^{512\times8000}WfinalR512×8000,再经过Softmax函数得到输出概率分布,用于生成翻译结果。例如,输出 logits∈R1×4×8000logits \in \mathbb{R}^{1\times4\times8000}logitsR1×4×8000,经过Softmax后得到每个位置对应目标语言词表中每个词的概率,从而选择概率最高的词作为翻译结果。

至此,完成了这句话的训练。好多句话经过这个流程后,需要训练的参数逐渐完善,然后输入一句新的话,经过上述流程就可以得到它的翻译了。

代码实现:

import tensorflow as tf
import numpy as np


def positional_encoding(position, d_model):
    angle_rates = 1 / np.power(10000, (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / np.float32(d_model))
    angle_rads = np.arange(position)[:, np.newaxis] * angle_rates
    # 应用 sin 到偶数索引位置
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # 应用 cos 到奇数索引位置
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis,...]
    return tf.cast(pos_encoding, dtype=tf.float32)


def scaled_dot_product_attention(q, k, v, mask):
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.sqrt(dk)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),
        tf.keras.layers.Dense(d_model)
    ])


class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0
        self.depth = d_model // num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)
        return output, attention_weights


class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)
        return out2


class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(x + attn1)
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(out1 + attn2)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(out2 + ffn_output)
        return out3, attn_weights_block1, attn_weights_block2


class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)
        return x


class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)
        return x


class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, pe_target, rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp, training, enc_padding_mask)
        dec_output = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)
        final_output = self.final_layer(dec_output)
        return final_output


# 示例使用
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
input_vocab_size = 5000
target_vocab_size = 5000
dropout_rate = 0.1
input_sequence_length = 20
target_sequence_length = 20

transformer = Transformer(num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, input_sequence_length, target_sequence_length, dropout_rate)

input_tensor = tf.placeholder(tf.int32, shape=(None, input_sequence_length))
target_tensor = tf.placeholder(tf.int32, shape=(None, target_sequence_length))
enc_padding_mask = None  # 根据需要创建掩码
look_ahead_mask = None  # 根据需要创建掩码
dec_padding_mask = None  # 根据需要创建掩码

outputs = transformer(input_tensor, target_tensor, True, enc_padding_mask, look_ahead_mask, dec_padding_mask)
print(outputs)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值