动手学深度学习9.2 LSTM

273 阅读3分钟

这是我参与2022首次更文挑战的第11天,活动详情查看:2022首次更文挑战

本系列更多文章可以看这里:草履虫都能看懂的 白话解析《动手学深度学习》专栏(juejin.cn)

还在更新中…………


在上一节里边,我们已经说到了如何使用门控单元对某些无关内容的忽略。也提到了其实首先是出现了LSTM,后来才出现了GRU,但是因为GRU更简单。所以现在很多都会先讲GRU。上一节已经讲完了GRU那这一节就来讲一讲LSTM。

温馨提示,本节内容会结合GRU进行讲解,所以请务必熟读上一篇的内容。动手学深度学习9.1 GRU - 掘金 (juejin.cn)


Long Short-Term Memory | MIT Press Journals & Magazine | IEEE Xplore

长短期存储器(long short-term memory, LSTM) 它有许多与门控循环单元一样的属性.

但是长短期记忆网络的设计比门控循环单元稍微复杂一些,并且它的诞生比GRU早了二十来年。

因为难易程度的问题,现在很多课程的讲课顺序都是先说GRU再说LSTM。

1 输入门、忘记门、输出门

image.png

和GRU不同,GRU有两个门,LSTM有三个门,它分别是输入门It\mathbf{I}_t、忘记门Ft\mathbf{F}_t和输出门Ot\mathbf{O}_t

假设有 hh 个隐藏单元,批量大小为 nn,输入数为 dd

公式如下:

It=σ(XtWxi+Ht1Whi+bi),Ft=σ(XtWxf+Ht1Whf+bf),Ot=σ(XtWxo+Ht1Who+bo),\begin{aligned} &\mathbf{I}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ &\mathbf{F}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ &\mathbf{O}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned}
  • 其中输入 XtRn×d\mathbf{X}_t \in \mathbb{R}^{n \times d}
  • 前一时间步的隐藏状态为 Ht1Rn×h\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}
  • tt时间步时, 输入门ItRn×h\mathbf{I}_t \in \mathbb{R}^{n \times h},遗忘门FtRn×h\mathbf{F}_t \in \mathbb{R}^{n \times h},输出门OtRn×h\mathbf{O}_t \in \mathbb{R}^{n \times h}
  • Wxi,Wxf,WxoRd×h\mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h}Whi,Whf,WhoRh×h\mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h} 是权重参数
  • bi,bf,boR1×h\mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h} 是偏置参数。
  • 激活函数依旧使用sigmoid

当然也可以合并起来写:

It=σ([Xt,Ht1]Wi+bi),Ft=σ([Xt,Ht1]Wf+bf),Ot=σ([Xt,Ht1]Wo+bo),\begin{aligned} &\mathbf{I}_t = \sigma([\mathbf{X}_t ,\mathbf{H}_{t-1}] \mathbf{W}_{i} + \mathbf{b}_i),\\ &\mathbf{F}_t = \sigma([\mathbf{X}_t ,\mathbf{H}_{t-1}] \mathbf{W}_{f} + \mathbf{b}_f),\\ &\mathbf{O}_t = \sigma([\mathbf{X}_t ,\mathbf{H}_{t-1}] \mathbf{W}_{o} + \mathbf{b}_o), \end{aligned}

再次强调,不懂这里为什么能合并的建议回去补RNN的知识。

2 候选记忆单元

长短期记忆网络引入了存储记忆单元(memory cell),或简称为单元(cell)。有些文献认为存储单元是隐藏状态的一种特殊类型。嗯。

image.png

然后是候选记忆单元C~t\tilde{\mathbf{C}}_t 的计算。LSTM中候选记忆单元是直接进行计算的。这一点和GRU不太相同。GRU这一步是结合遗忘门来进行候选隐藏状态的计算。

候选记忆单元就是将本步的输入和上一步的隐状态进行计算。

候选记忆单元公式如下:

C~t=tanh(XtWxc+Ht1Whc+bc)\tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c)
  • WxcRd×h\mathbf{W}_{xc} \in \mathbb{R}^{d \times h}WhcRh×h\mathbf{W}_{hc} \in \mathbb{R}^{h \times h} 是权重参数。
  • bcR1×h\mathbf{b}_c \in \mathbb{R}^{1 \times h} 是偏置参数。
  • 候选记忆单元使用的激活函数是tanh。

3 记忆单元

先来回顾一下在GRU当中,我们使用重置门来决定是否忽略上一步的隐藏状态。使用更新门来计算新的隐藏状态。而更新门的作用是决定使用多少本步的候选隐藏状态和上一步的隐藏状态。

类似地,在长短期记忆网络中,也有两个门用于这样的目的:输入门 It\mathbf{I}_t 控制采用多少来自 C~t\tilde{\mathbf{C}}_t 的新数据,而遗忘门 Ft\mathbf{F}_t 控制保留了多少旧记忆单元 Ct1Rn×h\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h} 的内容。最后计算结果存储在记忆单元Ct\mathbf{C}_t 中。

image.png

公式如下:

Ct=FtCt1+ItC~t\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t

因为输入门、忘记门他们都使用的sigmoid作为激活函数。因此它们两个的值都是趋近于0或者近于1的。

  • 如果遗忘门为 11 且输入门为 00,则过去的记忆单元 Ct1\mathbf{C}_{t-1} 将随时间被保存并传递到当前时间步。
  • 如果遗忘门为 00 且输入门为 11,则过去的记忆单元 Ct1\mathbf{C}_{t-1} 被丢弃掉,仅使用当前的候选记忆单元C~t\tilde{\mathbf{C}}_t

引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。

4 隐藏单元

输入门遗忘门都介绍了,输出门的作用就在 隐藏单元Ht\mathbf{H}_t 计算这一步。

image.png

公式如下:

Ht=Ottanh(Ct)\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t)
  • 输出门接近 11,我们就能够把我们的记忆单元信息传递下去。
  • 输出门接近 00,我们只保留存储单元内的所有信息。