LSTM

1,352 阅读4分钟

引用 zhuanlan.zhihu.com/p/34203833

LSTM

现在,我们把lstm一步一步拆开来看,深入了解那些细节:

LSTM的第一步是决定我们需要从cell状态中扔掉什么样的信息。这个决策由一个称为“遗忘门(forget gate)”的sigmoid层决定。输入h_{t-1}x_t ,输出一个0和1之间的数。1代表“完全保留这个值”,而0代表“完全扔掉这个值”。

比如对于一个基于上文预测最后一个词的语言模型。cell的状态可能包含当前主题的信息,来预测下一个准确的词。而当我们得到一个新的语言主题的时候,我们会想要遗忘旧的主题的记忆,应用新的语言主题的信息来预测准确的词。

第二步是决定我们需要在cell state里存储什么样的信息。这个问题有两个部分。第一,一个sigmoid层调用“输入门(input gate)”以决定哪些数据是需要更新的。然后,一个tanh层为新的候选值创建一个向量 \tilde{C}_t ,这些值能够加入state中。下一步,我们要将这两个部分合并以创建对state的更新。

比如还是语言模型,可以表示为想要把新的语言主题的信息加入到cell state中,以替代我们要遗忘的旧的记忆信息。

在决定需要遗忘和需要加入的记忆之后,就可以更新旧的cell stateC_{t-1} 到新的cell state C_t 了。在这一步,我们把旧的state C_{t-1}f_t 相乘,遗忘我们先前决定遗忘的东西,然后我们加上i_t * \tilde{C}_t,这可以理解为新的记忆信息,当然,这里体现了对状态值的更新度是有限制的,我们可以把i_t 当成一个权重。

最后,我们需要决定要输出的东西。这个输出基于我们的cell state,但会是一个过滤后的值。首先,我们运行一个sigmoid层,这个也就是输出门(output gate),以决定cell state中的那个部分是我们将要输出的。然后我们把cell state放进tanh(将数值压到-1和1之间),最后将它与sigmoid门的输出相乘,这样我们就只输出了我们想要的部分了。

推导过程

zhuanlan.zhihu.com/p/30465140

百面机器学习

问题一:LSTM是如何实现长短期记忆的功能的?

与传统神经网络相比,LSTM虽然仍然是基于x_th_{t-1}来计算h_t,只不过对内部的结构进行了更加精心的设计,加入了输入门i_t、遗忘门f_t、以及输出门o_t和一个内部记忆单元。输入门控制当前计算的新状态以多大程度更新到记忆单元中。遗忘门控制前一步记忆单元中的信息有多大程度被遗忘掉。输出门控制当前的输出有多大程度取决于当前的记忆单元。

经典LSTM 公式为:

i_t=f(W_ix_t+U_ih_{t-1}+b_i)
f_t=f(W_fx_t+U_fh_{t-1}+b_f)
o_t=f(W_ox_t+U_oh_{t-1}+b_o)
c_{t-new} = Tanh(W_cx_t+U_ch_{t-1})
c_t=f_t\otimes{c_{t-1}}+i_t\otimes{c_{t-new}}
h_t=o_t\otimes{Tanh(c_t)}

在一个训练好的网络中,当输入的序列中没有重要的信息时,LSTM的遗忘门的值接近于1.输入门的值接近于0,此时过去的记忆会被保存,从而实现了长期记忆的功能。当输入的序列中出现重要信息时,LSTM会把它存入记忆中,此时输入门的值会接近1.遗忘门接近于0,这样旧记忆会被遗忘,新的重要信息被记忆。经过这样设计,整个网络更容易学习到序列之间的长期以来。

问题二:LSTM里面的各个模块分别使用什么激活函数?可以使用别的激活函数吗?

  1. 遗忘门和输入门、输出门使用Sigmod函数作为激活函数

c_t=f_t\otimes{c_{t-1}}+i_t\otimes{c_{t-new}}

i_t“输入门”啦,取值0.0~1.0。

为了实现这个取值范围,我们很容易想到使用sigmoid函数作为输入门的激活函数,毕竟sigmoid的输出范围一定是在0.0到1.0之间嘛。符合门控的定义。

  1. 生成候选记忆时,使用Tanh作为激活函数

c_{t-new} = Tanh(W_cx_t+U_ch_{t-1})

之所以此处使用Tanh函数因为,输出实在-1~1之间。这与大多数场景下特征分布是0中心吻合。 此外,Tanh函数在输入为0附件相比Sigmod有更大的梯度,模型收敛更快

这两个函数都是饱和函数,如果使用ReLU等非饱和函数就无法起到门控作用。

但是激活函数的选择也不是一成不变的。例如在原始LSTM中,使用了Sigmod函数的变种。h(x)=2sig(x)-1,g(x)=4sig(x)-2,范围分别是[-1,1]和[2,2] 因为在原始LSTM中没有遗忘门,输入经过输入门直接与记忆相加,所以输入门控g(x)是以0为中心的。

此外,在一些计算能力有限的设备上,使用Sigmod函数求指数需要一定计算能力,所以会用0/1门。