白话神经网络-长短期记忆网络LSTM

306 阅读9分钟

一 背景

既然有了RNN,为何又需要LSTM呢?

循环神经网络RNN的网络结构使得它可以使用历史信息来帮助当前的决策。例如使用之前出现的单词来加强对当前文字的理解。可以解决传统神经网络模型不能充分利用上下文的信息增益的问题,但是同时,这也带来了更大的技术挑战 -- 长期依赖问题(long-term dependencies)

  • 在某些场景中,模型需要短期内的信息来进行增强,比如模型尝试去预测短语“天空非常蓝”,并不需要记忆这个短语之前更长的上下文,因为相关的信息与待预测的词的空间非常接近。

  • 在某些场景中,模型需要长期的信息来进行增强,比如当模型尝试去预测“xxx市建设了大量的化工厂,水污染非常严重,已经严重的影响了当地的生态环境,同时空气污染十分严重 ....... 天空都是灰色的”的最后一个单词灰色的时候,仅仅依赖短期的信息无法进行准确的预测,模型需要记忆更长的信息上下文来进行判断。

在这种长期上下文依赖的情况下,RNN的网路结构会丧失学习到如此长距离信息的能力

那么如何解决呢?

解决方法我们可以联想我们的记忆能力,其实我们的记忆是有选择的,有些重要的事情印象特别深刻,有些非重要的事情就已经模糊,这样才可以在有限的脑力的情况下记忆尽量重要的事情,总结下来就是记忆+忘记,只有忘记才能记忆

将这个思想运用到模型里面,就引出了本章要讲述的内容LSTM(long short term memory,LSTM)。

二 LSTM

长短时记忆网络(Long Short Term Memory Network)LSTM,是一种改进之后的循环神经网络,通过门控机制有选择的记忆重要的内容,可以解决RNN无法处理长距离的依赖的问题,目前比较流行。LSTM结构(图右)和普通RNN的主要输入输出区别如下所示。

相比RNN只有一个传输状态hth_t,LSTM有两个传输状态:

  • 一个状态是ctc_t(cell state)
  • 一个状态是hth_t(hidden state)。

其中对于传输状态的ctc_t改变很慢,而hth_t则在不同时刻下往往会有很大的区别。

LSTM是有三个gate,当外界某个neural的output想要被写到memory cell里面的时候,必须通过一个input Gate,那这个input Gate要被打开的时候,你才能把值写到memory cell里面去,如果把这个关起来的话,就没有办法把值写进去。至于input Gate是打开还是关起来,这个是neural network自己学的(它可以自己学说,它什么时候要把input Gate打开,什么时候要把input Gate关起来)。那么输出的地方也有一个output Gate,这个output Gate会决定说,外界其他的neural可不可以从这个memory里面把值读出来(把output Gate关闭的时候是没有办法把值读出来,output Gate打开的时候,才可以把值读出来)。那跟input Gate一样,output Gate什么时候打开什么时候关闭,network是自己学到的。那第三个gate叫做forget Gate,forget Gate决定说:什么时候memory cell要把过去记得的东西忘掉。这个forget Gate什么时候会把存在memory的值忘掉,什么时候会把存在memory里面的值继续保留下来),这也是network自己学到的。

同时相对于RNN的neuron,LSTM的neuron要做的事情其实就是将原来简单的neuron换成LSTM。input[ht,xt]input[h_t, x_t ]会乘以不同的weight当做LSTM不同的输入(假设我们这个hidden layer只有两个neuron,但实际上是有很多的neuron)。

  • input[ht,xt]input[h_t, x_t ]乘以不同的weight当做forget gate。
  • input[ht,xt]input[h_t, x_t ]乘以不同的weight操控input gate。
  • input[ht,xt]input[h_t, x_t ]乘以不同的weight会去操控output gate。
  • input[ht,xt]input[h_t, x_t ]乘以不同的weight当做底下的input。

第二个LSTM也是一样的。所以LSTM是有四个input(一个是想要被存在memory cell的值(但它不一定存的进去)还有操控input Gate的讯号,操控output Gate的讯号,操控forget Gate的讯号,有着四个input)跟一个output,所以LSTM需要的参数量(假设你现在用的neural的数目跟LSTM是一样的)是一般neural network的四倍。这个跟Recurrent Neural Network 的关系是什么,这个看起来好像不一样,所以我们要画另外一张图来表示。

2.1 Sigmoid函数

Sigmoid函数是深度学习神经网络常用的激活函数,他的特点值在0和1之间。用于LSTM的门槛,那么离0越近表示忘记,离1越近表示保留。

2.2 Forget Gate

遗忘门主要是对上一个时刻传进来的输入进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”,是针对长期依赖的不重要的信息的遗忘。

算法流程

  1. 时刻t的输入XtX_t与上一个时刻t1t-1的隐藏层状态ht1h_{t-1}进行向量拼接。
  2. 拼接后的向量经过SigmodSigmod函数(门控),进行忘记门控的计算。
  3. 该门控后续会和上个时刻的状态Ct1C_{t-1}进行权重控制(矩阵星乘),决定哪些信息信息需要忘记,哪些信息需要保留。

2.3 Input Gate

输入门主要是将这个时刻的输入有选择性地进行记忆。通过输入门控,对于当前时刻的输入信息进行控制,更多的记录下重要的信息,对于不重要的信息则少记一些。

算法流程

  1. 时刻t的输入xtx_t与上一个时刻t1t-1的隐藏层状态ht1h_{t-1}进行向量拼接。
  2. 拼接后的向量分别进入两个全连接网络
    1. 门控网络:单层全连接,通过激活函数SigmoidSigmoid计算门控值得到iti_t
    2. 输入网络:单层全连接,通过激活函数tanhtanh计算得到值candidate。
  3. 经过SigmodSigmod函数(门控),进行忘记门控的计算。
  4. 该门控会和当前时刻的状态CondidateCondidate进行权重控制(矩阵星乘),首先,我们将之前的隐藏状态ht1h_{t-1}和输入xtx_t,经过全连接网络使用激活函数sigmoid。它通过将值转换为0到1之间来决定哪些值将被更新。0表示不重要,1表示重要。同时还将之前的隐藏状态ht1h_{t-1}和输入xtx_t,经过全连接网络传使用激活函数tanh,以压缩-1和1之间的值,以帮助调节网络。然后将tanh输出和sigmoid输出相乘。sigmoid输出将决定哪些信息是重要的,而不是tanh输出。生成向量 itcii_t * c_i,如下图

2.4 Cell State

下面我们看下当前时刻的ctc_t是如何计算的,状态ctc_t作为LSTM循环体中传输的状态之一,它的状态的改变对于整个LSTM的中长期记忆与遗忘起到关键的作用。

算法流程

  1. 首先,时刻t的遗忘门值ftf_t与上一个时刻t1t-1的状态ct1c_{t-1}进行向量星乘,生成ftct1f_t * c_{t-1}
  2. 然后,将上面输入门计算的结果 itcti_t * c_t与上一步生成的结果ftct1f_t * c_{t-1}进行矩阵加法。
  3. 最后,更新当前时刻t的状态ct=itct+ftct1c_t = i_t * c_t + f_t * c_{t-1},同时作为下一个时刻的状态输入。

2.5 Output Gate

最后介绍输出门,输出门有两个作用,一个是LSTM循环体的输出(也可以理解为RNN循环体的y值),另外一个就是下一个时刻的输入(同状态ctc_t一样,hth_tcthtc_t和h_t作为LSTM循环体的双输入)。下面我们看下当前时刻的oto_t是如何计算的,输出门oto_t作为LSTM循环体中传输的状态之一,它的状态的改变对于整个LSTM的中短期记忆与遗忘起到关键的作用。

算法流程:

  1. 时刻t的输入xtx_t与上一个时刻t1t-1的隐藏层状态ht1h_{t-1}进行向量拼接。

  2. 拼接后的向量分别进入单层全连接,通过激活函数SigmoidSigmoid计算门控值得到oto_t

  3. 当前时刻的状态ctc_t通过tanh函数后与当前的门控oto_t进行矩阵乘法,生成最后的输出值hth_t

2.6 多层LSTM

在实践中经常会使用多层LSTM,所以这里也简单介绍下,主要看下面的图就好。

三 伪码

通过上面的讲解与分析,相信大家对LSTM应该已经有了全面的理解,下面附上代码,大家可以通过代码再进行下加深理解。

四 参考链接

五 番外篇

个人介绍:杜宝坤,隐私计算行业从业者,从0到1带领团队构建了京东的联邦学习解决方案9N-FL,同时主导了联邦学习框架与联邦开门红业务。 框架层面:实现了电商营销领域支持超大规模的工业化联邦学习解决方案,支持超大规模样本PSI隐私对齐、安全的树模型与神经网络模型等众多模型支持。 业务层面:实现了业务侧的开门红业务落地,开创了新的业务增长点,产生了显著的业务经济效益。 个人比较喜欢学习新东西,乐于钻研技术。基于从全链路思考与决策技术规划的考量,研究的领域比较多,从工程架构、大数据到机器学习算法与算法框架均有涉及。欢迎喜欢技术的同学和我交流,邮箱:baokun06@163.com

六 公众号导读

自己撰写博客已经很长一段时间了,由于个人涉猎的技术领域比较多,所以对高并发与高性能、分布式、传统机器学习算法与框架、深度学习算法与框架、密码安全、隐私计算、联邦学习、大数据等都有涉及。主导过多个大项目包括零售的联邦学习,社区做过多次分享,另外自己坚持写原创博客,多篇文章有过万的阅读。公众号秃顶的码农大家可以按照话题进行连续阅读,里面的章节我都做过按照学习路线的排序,话题就是公众号里面下面的标红的这个,大家点击去就可以看本话题下的多篇文章了,比如下图(话题分为:一、隐私计算 二、联邦学习 三、机器学习框架 四、机器学习算法 五、高性能计算 六、广告算法 七、程序人生),知乎号同理关注专利即可。

一切有为法,如梦幻泡影,如露亦如电,应作如是观。