持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第16天,点击查看活动详情
LSTM(Long Short-term Memory),如字面意思一样,被称为长短记忆网络,比较擅长处理序列数据,是人工神经网络中较为流行的模块,在LSTM中主要包含输入门、输出门、遗忘门,激活函数、全连接层等主要结构。下面通过一张图直观的了解LSTM的主要结构。
既然提到了LSTM,那么下面的公式我们需要对其有一定的了解,
其中其中,ht为t时刻的隐藏状态,Ct为t时刻的单元状态,xt为t时刻的输入,h(t-1)为t-1时刻层的隐藏状态或o时刻的初始隐藏状态,it、ft、gt、Ot分别为输入门、忘记门、单元门、输出门。sigma是sigmoid函数。
LSTM模块中,当前时刻的输入不仅与当前时间有关,而且还依赖于上一时刻的输出、输入以及隐含层的输出,所以从此处来看,LSTM模块的输入等同于普通的全连接层的四倍,这也导致计算量大大增加。
LSTM中提到了许多的门结构,在数学层面的理解就是将输入的计算值经过sigmoid函数后得到一个概率值,然后由这个生成的概率决定当前输入值的一个强弱程度。然后这个概率会和输入进行矩阵乘法然后才得到经过门处理后的实际值。
目前pytorch对LSTM已经进行了集成,但是具体如何实现的,还要在下面的代码中进行详细的理解,也就是对上述提到的公式进行计算的一个过程
import torch
import numpy as np
from torch import nn
def lstm(inputs, state, params):
[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q] = params
(H, C) = state
outputs = []
for X in inputs:
# 输入门
I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)
# 遗忘门
F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)
# 输出门
O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)
# 候选记忆细胞
C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
# 计算公式中的c
C = F * C + I * C_tilda
# 公式中的h
H = O * C.tanh()
# 输出层
Y = torch.matmul(H, W_hq) + b_q
outputs.append(Y)
return outputs
对LSTM进行测试,可以得到下面的输出。LSTM在自然语言处理时应用场景较多,在其它的地方也有应用,其思想逻辑非常的严谨,适合进行思虑的借鉴。