最快理解梯度下降算法!

219 阅读2分钟

关于梯度下降算法

现在我们用最简单的函数来使用梯度下降算法找到函数的最低点。只要上过高中,就能听懂🐶

这是 wiki 的解释: 要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对应梯度(或者是近似梯度)的反方向的规定步长距离点进行迭代搜索。

我们用最简单的例子来解释此概念。

我们有一个函数: f(x) = x² - 4x + 1, 他的图像为:

image.png 看到这个图,熟悉的高中味道是不是就来了~

现在我想找到这个函数的最低点。那么我们可以看图,可以跟高中一样,求导函数,然后再去导函数为 0 的点,就是原函数最低点。

那么梯度下降法如何做?

从一个初始点开始搜索,每次我们都试图去寻找函数下降的下一个点。

初始点

我们取一个初始点,比如我们取(5 , 6)这个点。从这个点开始搜索更低的位置(就是 f(x)函数值最低的位置)

梯度

函数在某个点的斜率,或者说有一个山坡,山坡的某个点的坡度。我们称这个坡度为梯度。那么梯度为正的就是我们即将“上坡”,梯度为负就是我们即将“下坡”。

至于为什么梯度正就是增长?为什梯度为负就是减少?我问了群友~

image 1.png

算法流程

如何从当前点找到下一个相对较低的点呢?

image 2.png

so,公式如下:

xn+1=xnηf(xn)\begin{equation}x_{n+1} = x_n - \eta \nabla f(x_n)\end{equation}

η 是学习率,所谓学习率就是步长。如果步长太大,那么可能一直反复横跳,一直找不到最低点。步长太小,导致找的太慢,效率太低。

  1. 找到一个初始点,比如说 x=5
  2. 计算该点的梯度,我们的函数是 f(x)= x^2- 4*x + 1 , 他的导函数是 2x - 4,他在 x=5 的导数是 6
  3. 向梯度相反方向走,假设我们的学习率是 0.1。则下一个 x 坐标:5 - 0.1 * 6 = 4.4
  4. 重复点 2 和 点 3

那么假设我们设置了合适的学习率和迭代次数,那么运气好的我们就达到了函数的最低点。我们就知道了他的最低点。

image 3.png

上图就是我们起始点是 5,学习率是 0.1,使用梯度下降法的迭代过程。

接下来,我们将学习率设置为 0.8

image-20241229141316083

可以看到搜索过程左右横跳。

接下来再看一个函数的图。

image-20241229141454155

这张图如果寻找最低点,那么就有两个,所谓局部的最低点。通过这几个图,我们能直观的感受到初始点和学习率的不同,所造成的影响。