常见运算的计算图

549 阅读4分钟

本文介绍常见运算的计算图。

计算图直观地表示了计算过程。通过观察反向传播的梯度流动,可以帮助我们理解反向传播的推导过程。

我们会利用计算图来实现自动求导工具。首先我们看一下常见运算操作的计算图。

加法

image-20211215222515421

z=x+yz = x + y求这个运算的梯度比较简单,易得zx=1,zy=1\frac{\partial z}{\partial x}=1,\frac{\partial z}{\partial y}=1

Lz\frac{\partial L}{\partial z}为经过反向传播传递到zz节点上的梯度。

减法

image-20211215223356415

z=xyz = x - y 可得zx=1,zy=1\frac{\partial z}{\partial x}=1,\frac{\partial z}{\partial y}=-1

乘法

image-20211215222555765

z=x×yz = x \times y 的梯度也比较简单,易得zx=y,zy=x\frac{\partial z}{\partial x}=y,\frac{\partial z}{\partial y}=x

此时,反向传播时会将上游传来的梯度乘以当前路径上计算出来的梯度。

除法

image-20211215224225069

z=xyz=\frac{x}{y}的梯度稍微有点复杂,zx=1y,zy=xy2\frac{\partial z}{\partial x} = \frac{1}{y}, \frac{\partial z}{\partial y} =-\frac{x}{y^2}

我们现在看到的都是单变量,其实也可以是多变量(向量、张量或矩阵)。在多变量时,只需要独立计算向量中各个元素,即,向量的各个元素独立于其他元素进行对应元素的计算。在下文的矩阵乘法时会详细介绍。

分支

严格来说,分支并不是我们常见运算的一种。但是有些情况下很有用,比如进行广播操作时。

image-20211215222854231

分支是最简单的复制形式,它的反向传播是上游传来的梯度之和。

Repeat

上面的分支操作有两个副本(或者分支),也可以扩展为NN个副本,此时称为复制(Repeat)。

image-20211215234024522

如上图,将长度为DD的数组复制了NN份,这个复制操作可以看成是NN个分支操作,所以它的反向传播可以通过NN个梯度的总和。

image-20211215234537203

如果通过Numpy实现的化:

import numpy as np

D, N = 8, 7
x = np.random.randn(1,D)
y = np.repeat(x, N, axis=0) # axis=0 沿着行的方向复制N份,变成了(N,D)
# 上面是正向传播
# 下面是梯度
dy = np.random.randn(N,D) # y的梯度一定和y的维度保持一致
dx = np.sum(dy, axis=0, keepdims=True) # 同理,x的梯度也和x保持一致,这里变成了(1,D)

出自predictivehacks.com

上图是简单介绍一下Numpy中axis的概念。当数组是1D的时候,只有一个轴,所以0轴的方向和2D的不同,要注意一下。

Numpy中的广播会复制数组的元素,可以通过这里的复制操作来表示。

Sum

Sum(求和)也是我们在深度学习中常用的运算。加法操作可以看成是求和的特殊形式。

考虑对一个N×DN \times D对数组沿着第行的方向求和,此时正向传播和反向传播如下所示。

image-20211216165318882

和加法一样,反向传播时将梯度(拷贝)分配到所有的箭头上,Sum操作是上面介绍的复制操作的逆向操作。即Sum的正向传播相当于复制操作的反向传播;Sum的反向传播相当于复制操作的正向传播。

我们也看一下通过Numpy实现的例子。

D, N = 8, 7
# 正向传播
x = np.random.randn(N, D)
y = np.sum(x, axis=0, keepdims=True) # 变成了(1,D)
# 反向传播
dy = np.random.randn(1, D) # 维度和y保持一致
dx = np.repeat(dy, N, axis=0) # 复制成了(N,D)

Matmul

Matmul是矩阵乘法(Matrix Multiply),比如,考虑y=xWy=xW这个运算。x,W,yx,W,y的形状分别是1×D1 \times DD×HD \times H1×H1 \times H

Matmul前向传播

它的反向传播稍微有点复杂。我们先来了解下雅可比矩阵(Jacobian matrix)。

雅可比矩阵

用每个yy对每个xx计算偏微分,计算得到的矩阵高度是yy的个数,宽度是xx的个数。

y=xWy=xW展开得:

[y1,y2,,yH]=[x1,x2,,xD][W11W12W1HW21W22W2HWD1WD2WDH]\left [y_1,y_2,\cdots,y_H \right] = \left[x_1,x_2, \cdots,x_D \right] \begin{bmatrix} W_{11} & W_{12}&\cdots &W_{1H} \\ W_{21}&W_{22}&\cdots&W_{2H} \\ \vdots & \vdots&\ddots& \vdots \\ W_{D1}&W_{D2}&\cdots&W_{DH} \end{bmatrix}

这里假设我们要计算LLxx的导数Lx\frac{\partial L}{\partial x}

我们先计算Ly=[Ly1,Ly2,,LyH]\frac{\partial L}{\partial y}=\left[\frac{\partial L}{\partial y_1},\frac{\partial L}{\partial y_2},\cdots,\frac{\partial L}{\partial y_H}\right]

接着计算yyxx的导数yx\frac{\partial y}{\partial x},根据雅克比矩阵,有

yx=[y1x1y1x2y1xDy2x1y2x2y2xDyHx1yHx2yHxD]\frac{\partial y}{\partial x}= \begin{bmatrix} \frac{\partial y_1}{\partial x_1 } & \frac{\partial y_1}{\partial x_2} & \cdots & \frac{\partial y_1}{\partial x_D} \\ \frac{\partial y_2}{\partial x_1 } & \frac{\partial y_2}{\partial x_2} & \cdots & \frac{\partial y_2}{\partial x_D} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial y_H}{\partial x_1 } & \frac{\partial y_H}{\partial x_2} & \cdots & \frac{\partial y_H}{\partial x_D} \end{bmatrix}

看起来挺复杂,但是如果我们先把yy中第jj个元素yjy_j的等式写出来,就会很简单,如:

yj=x1W1j+x2W2j++xiWij++xDWDjy_j = x_1 \cdot W_{1j} + x_2 \cdot W_{2j} + \cdots +x_i \cdot W_{ij}+ \cdots + x_D \cdot W_{Dj}

所以 yjxi=Wij\frac{\partial y_j}{\partial x_i} = W_{ij},把yx\frac{\partial y}{\partial x}完整的写出来,有

yx=[W11W21WD1W12W22WD2W1HW2HWDH]=WT\frac{\partial y}{\partial x} =\begin{bmatrix} W_{11} & W_{21} & \cdots &W_{D1} \\ W_{12} & W_{22} & \cdots &W_{D2} \\ \vdots & \vdots & \ddots & \vdots \\ W_{1H} & W_{2H} & \cdots & W_{DH} \end{bmatrix} = W^T

所以yx=WT\frac{\partial y}{\partial x} = W^T​ ,这就解释了为什么计算矩阵乘法的反向传播时,有个参数需要转置的。

Lx=Lyyx=LyWT\frac{\partial L}{\partial x}= \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} W^T

xx的形状是1×D1 \times DLx\frac{\partial L}{\partial x}的形状和它保持一致,也是1×D1 \times D

Ly\frac{\partial L}{\partial y}的形状和yy一样,是1×H1 \times H

WTW^T的形状是H×DH \times D

在推导上面的公式时,不要被写法的复杂所迷惑了,只要我们展开把等式写出来,或者用一个简单的比如2×32 \times 3的矩阵自己去推,就可以知道规律。

上面把yj=x1W1j+x2W2j++xiWij++xDWDjy_j = x_1 \cdot W_{1j} + x_2 \cdot W_{2j} + \cdots +x_i \cdot W_{ij}+ \cdots + x_D \cdot W_{Dj}写出来后,计算yjxi\frac{\partial y_j}{\partial x_i}就很简单了,因为此时只与xix_i有关,对于xx剩下的元素的导数都是0,变成了yjxi=0+0++Wij++0\frac{y_j}{x_i} = 0 + 0 + \cdots + W_{ij} + \cdots + 0

下面介绍几个简单的一元操作。

Pow

计算y=xcy= x^c,我们把xx看成是变量,cc看成是常数。只有一个变量,因此定义为一元操作。yx=cx(c1)\frac{\partial y}{\partial x}=c\cdot x^{(c-1)},一元操作比较简单,因此正向传播和反向传播画到一张图里面。

image-20211218092406811

Log

取对数(Log),一般指的是以指数ee为底。y=logxy = \log x,那么yx=1x\frac{\partial y}{\partial x}=\frac{1}{x}

image-20211218093159223

Exp

指数函数最简单了y=exy = e^xyx=ex\frac{\partial y}{\partial x} = e^x,原样返回。

image-20211218093353738

Neg

Neg是取负数的意思,y=xy=-xyx=1\frac{\partial y}{\partial x} = -1,可以理解为y=1xy = -1 \cdot x

image-20211218093713387