Explanation of torch.nn.RNN Parameters


torch.nn.RNN

1. Mathematical Derivation

The structure resembles the diagram below; it is the structure obtained by unrolling the recurrence over time.

```mermaid
graph LR
	X -->|U| H;
	H -->|V| O;
	H -->|W| H;
```
$$H^{(t)} = \phi(UX^{(t)} + WH^{(t-1)} + b_0)$$
$$O^{(t)} = \phi(VH^{(t)} + b_1)$$
```mermaid
graph LR
	H_prev --> mul0(Mul);
	W_h --> mul0;

	X --> mul1;
	W_x --> mul1(Mul);

	mul0 --> Add;
	mul1 --> Add;
	b --> Add;

	Add --> phi0(Phi);
	phi0(Phi) --> H_next;

	H_next -.-> H_prev;
	H_next --> mul2(Mul);
	V --> mul2(Mul);

	mul2 --> phi1(Phi) --> O
```
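
Putting the two equations and the computation graph together, here is a minimal single-step sketch; the sizes and the choice $\phi = \tanh$ are illustrative assumptions, not values from the original text:

```python
import torch

in_dim, hid, out_dim = 4, 3, 2
U = torch.randn(hid, in_dim)   # input-to-hidden weights
W = torch.randn(hid, hid)      # hidden-to-hidden weights
V = torch.randn(out_dim, hid)  # hidden-to-output weights
b0, b1 = torch.randn(hid), torch.randn(out_dim)

x_t = torch.randn(in_dim)      # X^(t)
h_prev = torch.zeros(hid)      # H^(t-1); here H_0 = 0
h_t = torch.tanh(U @ x_t + W @ h_prev + b0)  # H^(t) = phi(U X^(t) + W H^(t-1) + b_0)
o_t = torch.tanh(V @ h_t + b1)               # O^(t) = phi(V H^(t) + b_1)
```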

Let the loss at step $t$ be $L^{(t)}(O^{(t)}, y^{(t)})$; then $L = \sum\limits_{t=1}^{n} L^{(t)}$.

$$\frac{\partial L}{\partial V} = \sum\limits_{t=1}^{n} \frac{\partial L^{(t)}}{\partial O^{(t)}} \frac{\partial O^{(t)}}{\partial V}$$
$$\frac{\partial L^{(t)}}{\partial W} = \sum\limits_{i=1}^{t} \frac{\partial L^{(t)}}{\partial O^{(t)}} \frac{\partial O^{(t)}}{\partial H^{(t)}} \left(\prod\limits_{j=i+1}^{t}\frac{\partial H^{(j)}}{\partial H^{(j-1)}}\right)\frac{\partial H^{(i)}}{\partial W}$$
$$\frac{\partial L^{(t)}}{\partial U} = \sum\limits_{i=1}^{t} \frac{\partial L^{(t)}}{\partial O^{(t)}} \frac{\partial O^{(t)}}{\partial H^{(t)}} \left(\prod\limits_{j=i+1}^{t}\frac{\partial H^{(j)}}{\partial H^{(j-1)}}\right)\frac{\partial H^{(i)}}{\partial U}$$

In other words, the gradient accumulates over every path through which a weight reaches the loss. In implementations, an extra term $H_0$ (an all-zero matrix) is usually added as the initial hidden state to simplify the computation, where

$$\frac{\partial H^{(t)}}{\partial H^{(t-1)}} = \phi'^{(t)} W$$
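
As a sanity check (not from the original text), this relation can be verified against autograd. A minimal sketch, assuming $\phi = \tanh$ so the element-wise derivative $\phi'^{(t)}$ enters as a diagonal matrix:

```python
import torch

hid = 3
W = torch.randn(hid, hid)
ux_plus_b = torch.randn(hid)   # U X^(t) + b_0, held fixed for this check

def step(h_prev):
    return torch.tanh(ux_plus_b + W @ h_prev)

h_prev = torch.randn(hid)
h_t = step(h_prev)
J = torch.autograd.functional.jacobian(step, h_prev)  # dH^(t) / dH^(t-1)
manual = torch.diag(1 - h_t ** 2) @ W                 # diag(phi'^(t)) W, since (tanh)' = 1 - tanh^2
print(torch.allclose(J, manual, atol=1e-6))           # expected: True
```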

The repeated product easily leads to exploding or vanishing gradients, so the Truncated BPTT method is usually adopted: during backpropagation the chain is cut off after a chosen truncation length.
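
One common way to realize this is to detach the hidden state every k steps so that gradients never flow back further than k time steps. A minimal sketch under assumed sizes, using nn.RNNCell rather than the module discussed here:

```python
import torch
import torch.nn as nn

seq_len, batch, in_dim, hid, k = 40, 2, 4, 3, 10  # illustrative sizes
cell = nn.RNNCell(in_dim, hid)
head = nn.Linear(hid, 1)
opt = torch.optim.Adam(list(cell.parameters()) + list(head.parameters()))
xs = torch.randn(seq_len, batch, in_dim)
ys = torch.randn(seq_len, batch, 1)

h = torch.zeros(batch, hid)
losses = []
for t in range(seq_len):
    h = cell(xs[t], h)
    losses.append(nn.functional.mse_loss(head(h), ys[t]))
    if (t + 1) % k == 0:                      # end of a truncation window
        opt.zero_grad()
        torch.stack(losses).mean().backward()
        opt.step()
        losses.clear()
        h = h.detach()                        # cut the graph: BPTT stops at this step
```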

2. torch.nn.RNN Parameters

  1. input_size: the number of features in each input element.
  2. hidden_size: the number of features in the hidden state.
  3. num_layers (default: 1): the number of RNN layers. A multi-layer RNN is built by stacking single-layer RNNs.
  4. nonlinearity (default: 'tanh'): the activation function to use; either 'tanh' or 'relu'.
  5. bias (default: True): if True, bias terms are used in the RNN layers.
  6. batch_first (default: False): if True, the input and output tensors have shape (batch, seq, feature); by default the shape is (seq, batch, feature).
  7. dropout (default: 0): if non-zero, a Dropout layer is inserted after each RNN layer except the last one; the value is the dropout probability.
  8. bidirectional (default: False): if True, a bidirectional RNN is constructed.
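
For reference, a minimal usage sketch of torch.nn.RNN with these parameters; the concrete sizes below are arbitrary assumptions for illustration:

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=2,
             nonlinearity='tanh', bias=True,
             batch_first=False, dropout=0.0, bidirectional=False)

x = torch.randn(5, 2, 4)   # (seq_len, batch, input_size), since batch_first=False
h0 = torch.zeros(2, 2, 3)  # (num_layers * num_directions, batch, hidden_size)
output, h_n = rnn(x, h0)
print(output.shape)        # torch.Size([5, 2, 3]): top-layer hidden state at every step
print(h_n.shape)           # torch.Size([2, 2, 3]): final hidden state of each layer
```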

3. Manual Implementation

In general, the input tensor of an RNN has shape (sequence_length, batch_size, input_size); in practice the common size relation is sequence_length > batch_size > channel_size. This article uses the (sequence_length, batch_size, input_size) layout.

```python
import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, hidden_size)   # U: input -> hidden
        self.h2h = nn.Linear(hidden_size, hidden_size)  # W: hidden -> hidden
        self.h2o = nn.Linear(hidden_size, output_size)  # V: hidden -> output
        self.activate = nn.Tanh()                       # phi
        self.h2o_list = list()
        self.h2h_list = list()

    def forward(self, x):
        hp = self.i2h(x)
        # H_0: all-zero initial hidden state, shape (1, batch_size, hidden_size)
        self.h2h_list.append(torch.zeros(1, x.size(1), self.hidden_size))

        # hp: size (seq_len, batch_size, hidden_size) -- the input projection U X (+ bias)
        # hp[i]: size (batch_size, hidden_size)

        seq_len = hp.size(0)
        for i in range(seq_len):
            h2h = self.h2h_list[-1].squeeze(0)  # H^(t-1), shape (batch_size, hidden_size)
            h2h = self.h2h(h2h)                 # W H^(t-1) (plus its bias)
            hc = self.activate(hp[i] + h2h)     # H^(t) = phi(U X^(t) + W H^(t-1) + b_0)
            self.h2h_list.append(hc.unsqueeze(0))

            oc = self.h2o(hc)                   # V H^(t) + b_1, from the current hidden state
            oc = self.activate(oc)              # O^(t) = phi(V H^(t) + b_1)
            self.h2o_list.append(oc.unsqueeze(0))

        h2o_out = torch.concat(self.h2o_list, dim=0)      # (seq_len, batch_size, output_size)
        h2h_hid = torch.concat(self.h2h_list[1:], dim=0)  # (seq_len, batch_size, hidden_size), H_0 dropped
        self.h2o_list.clear()
        self.h2h_list.clear()
        return h2o_out, h2h_hid


x = torch.randn(5, 2, 4)                   # (seq_len, batch_size, input_size)
tag = torch.sigmoid(torch.randn(5, 2, 2))  # random targets in (0, 1), shape (seq_len, batch_size, output_size)

model = RNN(4, 3, 2)
optim = torch.optim.Adam(model.parameters())
criterion = nn.MSELoss()

for i in range(1, 200):
    model.zero_grad()
    output, _ = model(x)
    loss = criterion(output, tag)
    loss.backward()
    optim.step()
    if i % 10 == 0:
        print(f"{loss.item():.3f}")
```