编码器解码器架构(模板)
简单来说就是:编码器负责特征抽取,解码器负责输出,解码器也有自己的输入

编码器
from torch import nn
#@save
class Encoder(nn.Module):
"""编码器-解码器架构的基本编码器接口"""
def __init__(self, **kwargs):
super(Encoder, self).__init__(**kwargs)
def forward(self, X, *args):
raise NotImplementedError
和正常的模型一样
解码器
#@save
class Decoder(nn.Module):
"""编码器-解码器架构的基本解码器接口"""
def __init__(self, **kwargs):
super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
init_state(self, enc_outputs, *args):用来接收编码器的输出,转换成state
forward(self, X, state):且在forward中解码器也有自己的输入
合并编码器和解码器
#@save
class EncoderDecoder(nn.Module):
"""编码器-解码器架构的基类"""
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
forward(self, enc_X, dec_X, *args):enc_X:编码器输入
dec_X:解码器输入
enc_outputs = self.encoder(enc_X, *args)enc_outputs:经过编码器得到编码器输出
self.decoder.init_state(enc_outputs, *args)通过解码器的
init_state方法将编码器的输出变成一个状态供解码器使用
self.decoder(dec_X, dec_state)根据刚刚得到的状态和解码器自己的输入,得到解码器最终的输出
接下来几节使用的模型都基于这个结构。
本文详细介绍了编码器-解码器架构的实现原理,包括编码器的特征抽取、解码器的输入处理和状态转换,以及如何整合两者进行模型构建。适合理解深度学习中序列模型的基础构建方法。
&spm=1001.2101.3001.5002&articleId=124113024&d=1&t=3&u=98d995e0c6e546679850bf551cbb3009)
1425

被折叠的 条评论
为什么被折叠?



