一分钟读懂反向传播(代码解析)

136 阅读1分钟

torch.backward()

  • 该函数的目的是为了进行求导
  • backward函数能够对目标进行求导,并方向传播

一个例子

假设函数 y=2x**2关于列向量求导,首先创建一个变量并分配一个初始值

import torcch

x = torch.arange(4.0)
x

tensor([0.,1.,2.,3.,])

在计算y关于x的梯度之前,需要一个地方来存储梯度

x.requires_grad_(True) #等价于x = torch.arange(4.,requires_grad=True)
x.grad #默认值是None

现在计算y值

y = 2 *torch.dot(x,x)
y

tensor(28., grad_fn=<MulBackward0>)

x是一个长度为4的向量,计算x和x的点击,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度

y.backward()
x.grad

backward()函数将梯度反向传递后并更新了x原有的梯度