RNN和LSTM图解

这篇博客通过图解的方式介绍了循环神经网络(RNN)和长短期记忆网络(LSTM)。RNN用于处理序列数据,通过状态来保持上下文信息,但存在梯度消失或爆炸问题。LSTM引入了门控机制,包括输入门、遗忘门和输出门,解决了RNN的问题,更好地处理长期依赖。

        从网上看到一些对RNN和LSTM的讲解,说得实在不明白,我看了好久才理解一点。所以把自己的理解画成图总结了一下。都是从网上找的资料不一定权威,也不是专门搞这个的,只是入门的介绍,其他深入的问题也不会。权当做一个简介和入门吧。

        神经网络

        如果看神经网络的一个单层而且是全连接层的话,可以表示成下面这样。

         其中W就是权重矩阵,Y=WX,非常好理解。

        循环神经网络(RNN)

        首先,为什么要有这个循环神经网络呢?因为普通神经网络中,给定一个输入X_1 就对应一个输出Y_1,给定一个输入X_2就对应一个输出Y_2,这俩互相之间没有半毛钱关系。但我们实际应用中,很多时候要输入一个序列,比如做自然语言处理时,输入一个句子,其中每个单词对应一个编码。我要是分开输入每个单词,倒也可以,但单词之间的关联就被抹去了,它肯定效果不好。事实上,这些单词之间是有关联的。那么我们怎么解决这个问题呢?有人就想了,利用前面三个单词,输入进去得到一个状态,然后利用这个状态和第四个单词一起,对第四个单词进行预测。这个想法就很好。

        所以得到了如下的结构,其中S就代表状态。

         图里为了简洁把激活函数没画出来,实际上一般经过全连接后得到新的结果之前都要加一层激活函数。另外,S_t连过来怎么就成了S_{t-1}了呢?这就是用上一个时刻的状态和该时刻的输入得到该时刻的状态和输出。或者理解为自然语言处理中,输入上一个单词得到了一个状态S_t,然后要输入这个单词了,上面得到的S_t不就是 S_{t-1}了吗?然后把这个S_{t-1}和输入的这个单词X_t得到新的状态S_t,进而得到该单词对应的输出Y_t

        写成公式就是

Y_t=g(WS_t)\\ S_t=f(UX_t+VS_{t-1})

        LSTM

        上面的RNN,能对序列的输入,对每个输入得到一个预测输出,比起我将这个序列的每个输入拆开 分别扔进神经网络中,RNN能够取得更好的效果。

        然而,RNN如果序列很长,可能出现梯度消失或梯度爆炸的问题。我们需要做一些调整。这就有了LSTM,即长短期记忆。下面我们先看结构,再说原理。我觉得这样比起一些一上来就这门那门的教学要更好理解吧。

        我们还是从RNN出发来继续理解LSTM。上面说了,RNN为了处理一个序列输入,它要维护一个“状态”。这个状态能够“记忆”之前的输入,再结合现在的输入,综合给出现在输入对应的输出。而LSTM需要两个状态,即h_tc_t。下面先看一下它的结构。

         看起来仿佛很复杂,其实也很好理解。这里有两个状态,h_t(hidden\ state)c_t(cell\ state),我们可以分别叫它隐状态和单元状态。很明显地,也是隐状态h_t反馈回来,即h_{t-1},和输入X_t共同决定输出。但其中的运算方式和普通的RNN不一样了。原来的RNN是二者乘各自的权重矩阵然后加和得到现在的状态,然后乘输出权重矩阵得到输出。但在这里,把二者拼接,然后乘权重矩阵W得到新的输入Z。为了便于表述,我们把拼接结果叫做(x,h)吧。

        下面就是体现LSTM的优势和特点的几个步骤了。我们这个输入不全要,只选择一部分重要的记录下来。因此上面的拼接结果(x,h)乘“信息权重矩阵”W_i得到“信息门控信号”Z_i,再让Z_iZ逐元素相乘,不就对输入进行了一个选择操作吗? 

        所以上面图中的Ꙩ 即代表哈达玛积,也就是矩阵逐元素相乘,即相同位置元素算乘积。当然了,这个运算肯定两个乘数矩阵和积矩阵都是相同形状的啦。

        得到了选择后的输入之后。我们又遇到了一个类似的操作,还是状态返回来,和现在的输入要共同决定输出。你看,LSTM比起RNN,除了有一些门控选择之外,再就是套了两层状态反馈呢。

        不过这一层,反馈的状态就是单元状态c_t了。反馈回来的c_{t-1}和上面选择之后的状态共同决定新的状态c_t。不过同样地, 由拼接结果(x,h)乘“遗忘权重矩阵”W_f得到“信息门控信号”Z_f,再让Z_fc_{t-1}逐元素相乘,这就是对上一时刻的单元状态c_{t-1}进行了一个选择操作,选择一部分去忘掉。
        这样得到的单元状态c_t经过一个tanh激活函数之后,就应该得到了隐状态h_t了吧?不过这样还不行,就还需要由之前的拼接结果(x,h)乘“输出权重矩阵”W_o得到“输出门控信号”​​​​​​​Z_o,再让Z_o和激活后的c_t逐元素相乘,才得到隐状态h_t。也就是说,从单元状态c_t到隐状态h_t,还需要进行一步选择。

        这样就得到了隐状态h_t,再经过一个权重矩阵V自然就得到输出Y_t啦。

        附上公式,和图解结合食用更加舒适。

        首先记拼接结果为(x,h)

(x,h)=\binom{X_t}{h_{t-1}}

        然后得到处理后的输入信号和多个门控信号。

Z=tanh(W\cdot (x,h))\\ Z_i=\sigma(W_i\cdot (x,h))\\ Z_f=\sigma(W_f\cdot (x,h))\\ Z_o=\sigma(W_o\cdot (x,h))\\

        剩下就是这几个门控信号是如何发挥作用的了。

\\c_t=Z_f\odot c_{t-1}+Z_i\odot Z\\ h_t=Z_o\odot tanh(c_t)\\ Y_t=\sigma(Vh_t)

        总结一下。

        一方面,LSTM中嵌套了两层状态反馈,分别是隐状态h_t和单元状态c_t。和RNN相同地,当前上一个时刻的状态反馈回来,和当前时刻的输入共同决定当前时刻的输出。

        另一方面,LSTM中有三个阶段的选择,它们的实现方式都是由反馈的隐状态h_{t-1},和输入X_t的拼接结果即(x,h),再分别乘以三个权重矩阵W_iW_fW_o,再进行激活,得到三个选择信号Z_iZ_fZ_o。选择信号在对应的位置和响应信号进行逐元素相乘(即哈达玛积),就达到了选择的效果。

        好了,搞懂了LSTM的基本原理之后,我们就很容易明白为什么需要它了。在长序列输入的情况下,普通的RNN只有一种简单的记忆方式,不能对记忆的内容进行选择。而我们的长序列可能有些内容重要,有些内容不重要需要忘记,而LSTM经过上面描述的三个选择,就能够实现这样的需求。这也就是为什么它被称作“长短期记忆网络”(Long Short-Term Memory)吧。

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

        

 

 

 

        

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值