PyTorch 张量操作与梯度学习:从 view/reshape 到梯度下降

49 阅读2分钟
  • Sigmoid 函数

    • 优点:输出在 [0,1],天然适合二分类。
    • 缺点:导数在两端很小,容易出现梯度消失 → 训练慢甚至卡住。
  • 导数的意义

    • 高中里:dy/dx​ 表示 x 变一点,y 变多少。
    • 机器学习里:∂L/∂w​ 表示参数 w 变一点,损失 L 变多少。
    • 正负号:决定调 w 的方向(往哪走能让损失变小)。
    • 大小:表示对 w 的敏感程度(导数太小 → 调整不动,太大 → 调整容易震荡)。
  • 为什么是 ∂L/∂w而不是 ∂L/∂x

    • 高中:x 是变量,k 已知 → 研究 y 随 x 的变化。
    • 机器学习:x 是数据(固定的),w 是参数(需要训练) → 研究 L 随 w 的变化。
  • 梯度下降的核心逻辑

    • 损失函数 L 想要最小。
    • 计算梯度(∂L/∂w)。
    • 用梯度的正负号决定 w 增大还是减小。
    • 用梯度的大小 + 学习率 决定 w 调整的幅度。
import numpy as np

def tensor_shapes():
    # 基本张量操作
    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print("原始 shape:", data.shape)
    new_data = data.reshape(6, 1, 1)
    print("reshape 后 shape:", new_data.shape)
    
    data3d = torch.tensor(np.random.randint(0, 10, [3, 4, 5]))
    print("原始 3D 张量 shape:", data3d.shape)
    transposed = torch.transpose(data3d, 1, 2)
    print("transpose 后 shape:", transposed.shape)
    permuted = torch.permute(data3d, [1, 2, 0])
    print("permute 后 shape:", permuted.shape)

def squeeze_unsqueeze():
    data = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5]))
    print("原始 shape:", data.shape)
    data = data.squeeze(0)
    print("squeeze 后 shape:", data.shape)
    data = data.unsqueeze(-1)
    print("unsqueeze 后 shape:", data.shape)

def tensor_statistics():
    data = torch.randint(0, 10, [2, 3], dtype=torch.float64)
    print("数据:\n", data)
    print("mean:", data.mean())
    print("sum:", data.sum())

def gradient_demo():
    x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)
    for _ in range(3):
        f = (x ** 2 + 20).mean()
        f.backward()
        print("f =", f.item(), " | x.grad =", x.grad)
        x.grad.zero_()  # 清零梯度
        print('-'*30)

def gradient_descent_demo():
    x = torch.tensor(10., requires_grad=True, dtype=torch.float64)
    lr = 0.001
    for step in range(1, 5001):
        f = x ** 2
        x.grad = None
        f.backward()
        x.data -= lr * x.grad
        if step % 500 == 0:
            print(f"Step {step}: x = {x.item():.6f}")

if __name__ == "__main__":
    print("=== 张量形状操作 ===")
    tensor_shapes()
    print("\n=== squeeze & unsqueeze ===")
    squeeze_unsqueeze()
    print("\n=== 张量统计 ===")
    tensor_statistics()
    print("\n=== 梯度计算示例 ===")
    gradient_demo()
    print("\n=== 简单梯度下降 ===")
    gradient_descent_demo()