torch.backward()
- 该函数的目的是为了进行求导
- backward函数能够对目标进行求导,并方向传播
一个例子
假设函数 y=2x**2关于列向量求导,首先创建一个变量并分配一个初始值
import torcch
x = torch.arange(4.0)
x
tensor([0.,1.,2.,3.,])
在计算y关于x的梯度之前,需要一个地方来存储梯度
x.requires_grad_(True) #等价于x = torch.arange(4.,requires_grad=True)
x.grad #默认值是None
现在计算y值
y = 2 *torch.dot(x,x)
y
tensor(28., grad_fn=<MulBackward0>)
x是一个长度为4的向量,计算x和x的点击,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度
y.backward()
x.grad
backward()函数将梯度反向传递后并更新了x原有的梯度