一文理解LSTM

1,622 阅读6分钟

这是我参与8月更文挑战的第20天,活动详情查看:8月更文挑战

LSTM

前置知识——RNN

RNN即Recurrent Neural Network,循环神经网络。RNN是一种用于处理序列数据(像时序数据)的神经网络。

我们先来了解下普通的RNN。

不同于普通的神经网络,RNN多了一个Memory,即RNN会把隐藏层的值存储在内存中。这样使得RNN可以记忆上一层隐藏层的输出,具有记忆力。上图理解:

RNN.jpg

RNN的输入有两个值,一个是当前状态的时序数据 xx,一个是上一隐藏层的值 hh。也就是说,在下一状态或下一时刻输入 x2x^2 时,也要把上一时刻的Memory里的值也一起考虑,同时隐藏层输出新的值 h2h^2 会替换原先的 h1h^1 ,即同时更新Memory里的值。最后的输出 yy 是要通过一层softmax得到概率值。

简单来说,对于有序的数据集,RNN把之前分析数据 x1x^1 的结果存入Memory,分析 x2x^2 的时候调用之前 x1x^1 的结果一起分析。即累积记忆,一起分析。

值得注意的是,因为RNN处理的是时序的数据,为了更直观的展示时序之间的关系,这张图里,把不同时刻的RNN放在一起,所以图里不是有三个RNN,而是只有一个RNN。

除此之外,RNN还有很多的扩展,比如Elman Network和Jordan Network,两者的区别在于Memory存储的值不同。

  • Elman Network:Memory里存储的是隐藏层的值
  • Jordan Network:Memory里存储的最终的输出值

Elman_Network.jpg

此外,还有双向RNN(Bidirectional RNN)。双向的意思即同时正向训练RNN和逆向训练RNN,如下图

rnn-bi.png

什么是LSTM

LSTM全称Long Short-Term Memory,即长短期记忆网络。是RNN的进阶版。

前面我们了解了普通RNN,其对Memory的读取更新是没有限制的,也就是每一时刻神经网络输入新的数据时,Memory都会被更新。

LSTM则是在RNN的基础上,增加了三个 Gate,用于控制全局的Memory。

三个gate即:

  • input gate: 输入控制,只有打开才可以把值写进Memory里。
  • output gate: 输出控制, 只有打开才会可以从Memory里读取值。
  • forget gate: 遗忘控制,决定是要忘记清空,还是保存Memory里的值。

LSTM.png

可以看到,LSTM的一个Memory cell有4个input 、1个output。output的就是Memory里的值,input的就是想要写进Memory中的值(其实就是隐藏层的输出值)以及控制三个gate的门控信号。

结构

下面深入的理解下LSTM的结构,看下图

lstm2.png

从表达式的角度来看Memory cell,zz 是想存储进Memory的值,ziz_i 控制input gate,zoz_o 控制output gate,zfz_f 控制forget gate。函数 ff 表示激活函数,一般选用sigmoid,可以控制在 [0,1][0,1],表示gate打开的程度。

对于input gate,容易看到,f(zi)=1f(z_i)=1 表示门控打开,g(z)g(z) 可以写进去;反之 g(z)g(z) 就变为0。

对于forget gate,是 f(zf)=1f(z_f)=1,表示保存当前Memory中的值 cc。而 f(zf)=0f(z_f)=0,才是表示要遗忘当前Memory中的值 cc,也就是清空当前值 cccc' 就是Memory更新后的值。c=g(z)f(zi)+cf(zf)c'=g(z)f(z_i)+cf(z_f) 这个式子应该很好理解。

最后的输出 aa 就由output gate f(zo)f(z_o) 控制,显然 f(zo)=1f(z_o)=1 时才可以读取Memory的值 h(c)h(c')。这里也是加了个激活函数 hh

明白了Memory cell的具体结构,下面我们来看完整的LSTM的结构。下图是同一个LSTM在两个相邻时刻的情况。

LSTM_example.png

我们观察上图,tt 时刻输入样本数据 xtx^t,它会先乘上一个矩阵,转换成 z,zf,zi,zoz,z^f,z^i,z_o 这4个向量(因为我们从整个时序来考虑,所以是向量,实际上每个时刻输入的门控信号是一个值)。这4个向量就是控制信号。

此外在输入端还包括Memory cell里存储的值 ct1c^{t-1}。上一时刻隐藏层的输出值 ht1h^{t-1}。注意,这里有些朋友可能有些疑惑,输入为什么要考虑 ct1c^{t-1}ht1h^{t-1} 两个值呢?RNN只需要考虑 ht1h^{t-1}啊。其实,是因为有outgate,它决定了Memory是否可以读取,就算Memory cell的值更新了,output gate不给读取,那此时隐藏层输出的值肯定和cell的值不一样。

优势

LSTM有什么优势呢?

普通RNN存在的问题:

  • 梯度消失:误差反向传递时,每一步会乘以权重w,w<1,乘完越来越小
  • 梯度爆炸:w>1,乘完越来越大

LSTM的优势在于其可以解决梯度消失。不过LSTM不能解决梯度爆炸。

为什么LSTM可以解决梯度消失?

我们先回顾一下LSTM,在面对Memory cell时,cell里的值更新式子: c=g(z)f(zi)+cf(zf)c'=g(z)f(z_i)+cf(z_f)

从式子中我们发现,LSTM在每个时间点,是把原来Memory cell的值和input值相加起来,而RNN是每个时间点的Memory都会被更新覆盖,这样有个很大的区别就是,LSTM的Memory中的值是和过去cell里的值以及输入有关的,除非被forget gate控制遗忘掉原来的值。

也就是说,RNN里权值weight对Memory的影响在每个时间点都会被清除,而LSTM里面,除非forget gate控制遗忘掉原来Memory的值,否则权值weight对Memory的影响不会被清除,而是会一直累加保留,因此它不会有梯度消失的问题。

此外,因为LSTM可以解决梯度消失的问题,所以在训练的时候我们可以把学习率设置小一点。

小结

现在我们回过头来理解下Long short-term Memory,其实这个名字的意思是比较长的短期记忆,因为普通RNN的时候在每次新的input进来时会清除Memory,是很short-term的,而LSTM增加了forget gate,是把这个short-term变长一些。

普通RNN因为权值对Memory的影响每次都会被清除,所以容易出现梯度消失或梯度爆炸,而LSTM可以解决RNN梯度消失的问题。但不能解决梯度爆炸。因为LSTM里Memory中的值是和过去cell里的值以及输入有关的,除非被forget gate控制遗忘掉原来的值,否则权值weight对Memory的影响不会被清除。

但是LSTM也有个明显的缺点,就是参数太多,训练难度加大,且容易过拟合。所以有人提出了一种参数更少但效果和LSTM相当的RNN网络——GRU。简单来说,GRU会把input gate和forget gate连接起来,input gate打开,forget gate会控制遗忘掉Memory的值,也就是说你要存储新的值时要把旧的值先遗忘掉。

参考

  1. 台大李宏毅教授的机器学习课程 B站视频