人工神经网络算法入门

avatar
前端工程师 @字节跳动

本文作者:阳羡

背景

本文的目的是希望大家对“人工神经网络算法”有个初步的认知,同时希望有能力的同学可以借助本文和有关示例代码完成一套最简单的人工神经网络算法。

科学就是一个找规律的旅程,很多规律并不像 F=ma 那样一目了然,而是一种隐式的、复杂的、难以量化的规律。

比如投篮,如何把篮球投进篮筐?通过篮球的重量、出手的角度和力气求出还是凭借“手感”?我相信大部分普通人还是凭借“手感”来打篮球的。

大概如下图所示:

图片

把上图再次抽象,使其数学化,方便分析:

图片

我们有些输入值x1,x2,... xn,通过某种计算得到y1, y2, ... yn,具体如何计算我们不知道,事实上我们也不关心,只要它的输入输出符合预期就行,因此我们叫它“隐藏层”。

这样让计算机“找规律”的算法就叫做“人工神经网络算法”(Neural Network)(NN)

前向传播

把上文使用数学语言表示一下,即:

image.png

这个函数,我们叫它“前向传播算法”。

那么,前向传播算法到底是个啥?

我们希望这个函数的定义域与值域能够尽可能的覆盖实数轴上的绝大部分,或者说,希望覆盖我们希望它能够覆盖的绝大部分,从而保障它的通用性。

那么具体什么函数能满足这个条件呢?其实有很多啊,比如:

三角函数:

image.png

多项式函数:

image.png

可惜的是,今天的主角不是这两个,而是我们小学二年级就学过的:

image.png

为什么?因为简单啊,因为简单,所以方便求导,这个后文会提;而且在实际的过程中可以转换为矩阵运算,底层容易做优化

其中,w 的意思是 weight,即权重;b 的意思是 bias,即偏移

有同学说,那不对啊,这个函数太简单太线性了,现实问题这么复杂,仅使用它能够准确描述吗?

这个问题很好解决,那让这个函数再过一下另一个函数就行。这个函数叫做“激活函数”。

因为之前的函数是线性函数,所以激活函数通常选择非线性函数来互补。

对于萌新来说,通常选择 sigmoid 算法会比较合适

image.png

图片

它的图像如图所示,除了简单外,最大的优点是把 x 映射到 (0, 1) 的区间中,从而不会发生超上限的问题。

这个点特别重要,比如说

image.png

当 a > 1 的时候,其取值范围为负无穷到正无穷,在实际代码中,很容易超出上限。学术上把这种情况叫做“梯度爆炸”。当 0 < a < 1的时候,取值范围为 

image.png

只要保障 x 的取值范围受控,即可保障在上限之内。\

也因此,很多情况下会把输出限制在 0 ~ 1 中,那么这个数其实就是个概率,所以我开篇的例子也是使用概率这个概念。

Sigmoid 函数能够解决“梯度爆炸”问题,但是容易引发“梯度消失”的问题,即最后表达式运算容易变成:

image.png

导致每次迭代更新 weight 的速率缓慢,降低拟合速度。

image.png

为了加快拟合速度,从而演示方便,本次演示环节使用的是 Leaky ReLU 带泄露线性整流函数

它能够较好地解决梯度消失问题,但是需要人工对参数和结果进行干预才能解决梯度爆炸问题。

现在我们已经有了两个函数,我们分别称之为“矩阵乘法函数 Matrix”和“激活函数 Active”,通过这两个函数似乎可以表达我们想要表达的绝大多数情况了,但是,现实不是真空中的球形🐔,如果情况是这两个函数无法表达的,该怎么办?

受人脑的神经元启发,可以简单的多层多节点排布,从而在不提高运算复杂性的情况下,提高系统复杂性,也因此得名“人工神经网络算法”。

图片

其中,每一个节点都是对上一层内容做矩阵运算+激活函数,并且把结果传递给下一个节点,用代码表示一个节点则如下:

function forward(inputArray: number[]) {
this.value = this.active(inputArray.reduce(
(acc, input, index) => acc + this.weightArray[index] * input,        0      ) + this.bias)}

反向传播

正向传播中与 x 无关的变量都会在初始化时随机生成,(虽然也可以手动制定)

因此,如果不对这些变量(w、b)校准,那么最终计算出来的结果就是毫无意义的随机值。

校准变量的算法叫做“反向传播”

那么在校准前的第一步,我们要对结果进行打分,来判断结果与预期的差距,这个打分函数叫做“代价函数 Cost”或者“损失函数”,通常可以选用:

image.png

其中 x 是指计算值,t 是指期望值。

这个值越低越好,当它是 0 的时候,代表人工神经网络已经完成了它的学习

为什么不是单纯的减法?因为线性+不收敛,求导没意义

为什么不是减法后取绝对值?因为求导复杂

为什么要平方后乘1/2?因为求导可以约掉

三个问题的答案都是求导,感觉这个函数就是为求导而生的hhh。

事实上也的确如此,这个函数就如同考试成绩,成绩其实不重要,重要的是通过这次考试能够了解到学会了哪些内容,哪些内容需要进一步提高,而获取到这些信息的数学表达就是求导。

如果把 cost 看作某个 w 的函数,则它的图像可能如下所示:

图片

我们希望 cost 越低越好,那么在上面任取一点,那要想达到最低点它应该往导数的相反方向走。

而且只能走一点,不能走多了,不然可能会越过最低点,下次迭代再越过来,这种现象叫做“震荡”,形容走的那一点点的那个参数我们叫它“学习率”。

image.png

有同学会问了,那会不会落入极小值而非最小值呢?在这个例子中不会,因为除了 cost 函数有最小值,其他函数都是单调的。因为简单,所以方便计算和理解。如果使用了其他复杂函数,因为本身节点多,系统复杂度高,所以还是容易陷入一个比较好的结果,通常 cost 在 

image.png

 这个量级就非常不错了。如果的确陷入了不能接受的极小值,那么就需要进行“调参”了。

把上文的导数展开,得到一下这个数学表达式:

image.png

代价函数对某个 w 求偏导,等于代价函数对激活函数求导 * 激活函数对矩阵乘法求导 * 矩阵乘法对某个 w 求导。

其中,矩阵乘法对某个 w 求导就是与该 w 相乘的 x,这个最好理解

激活函数对矩阵乘法求导也很简单,比如 sigmoid 算法的导数就是:

image.png

代价函数对激活函数求偏导就需要分情况讨论了,如果是最后一层,即输出层,那么结果就是代价函数的导数

image.png 如果是隐藏层,则公式为:

image.png

其中L为右边一层,l是L中的节点

该层的代价函数对激活函数的偏导等于,右边每一层的:

• 代价函数对激活函数的偏导 —— 右边那层已经算过了

• 激活函数对矩阵乘法的偏导 —— 右边那层已经算过了

• 矩阵乘法对j层激活函数的偏导 —— 说这么玄乎,其实就是链接这两层的 w 权重

这上述东西的积 的 和

然后就可以计算新的 w:

image.png

对于一个节点的反向传播算法,代码如下:

backward(expect: number) {
  /**
   * 激活函数的偏导
   */
  this.activeValue = this.activeDerivative(this.getValue());

  if (this.nextLayer) {
    /**
     * 隐藏层。
     * 公式是:【右边一层的代价函数偏导结果 乘 右边一层的激活函数偏导结果 乘 这两个节点之前的权重】的和
     */
    this.costValue = this.nextLayer.nodeArray.reduce<number>((acc, node) => {
      return (
        acc +
        node.costValue! * node.activeValue! * node.weightArray[this.index]
      );
    }, 0);
  } else {
    /**
     * 最后一层,即输出层。
     * 公式是代价函数的偏导
     */
    this.costValue = this.costDerivative({ expect, actual: this.getValue() });
  }

  this.weightDeltaArray = this.weightArray.map((_, index) => {
    /**
     * 左一层的值
     */
    const prevValue = this.prevLayer.getValue()[index];

    return (
      -1 * this.learnRate * this.costValue! * this.activeValue! * prevValue
    );
  });
}

这里留一个小练习:请推导出应该如何更新 bias?

现在我们只要不断运行前向传播算法和反向传播算法,w和b就会越来越接近我们想要的值,cost就会越来越低,最后可能变成0,此时这个人工神经网络就训练完成了,虽然我们不知道其中的w和b具体是啥意思,但是对于我们给定的输入,能够给出符合预期的输出就够了。

展示环节

地址:github.com/wu-yu-xuan/…

大家提出想要学习的函数然后阿特自己,由我来当场演示。规则:

• 只是个 demo,所以不要太复杂。

• 定义域、值域 最好保持在 0 ~ 1 内,上文说过原因

加减法

image.png

编写方程组可以参考这个图:

图片

图片