大家好,我是你们的技术伙伴。👋
在2026年的今天,AI已经渗透到了我们生活的方方面面。但是,无论多么高大上的大模型(LLM),其最底层的基石依然是张量(Tensor) 。
很多初学者在学习PyTorch时,往往急于求成直接去跑模型,结果在数据预处理阶段就被各种reshape、view、dim搞得晕头转向。今天,我就带你从零开始,深度剖析PyTorch的核心——张量操作与自动微分。这不仅是面试必考题,更是你写出高效深度学习代码的关键。
🧱 第一章:张量的“变形记”——索引与形状操作
数据进入神经网络前,必须是特定形状的张量。如何像捏橡皮泥一样随意改变张量的形状?这是我们必须掌握的第一课。
1. 索引:不只是简单的切片
PyTorch的索引比NumPy更强大,但也更需要注意细节。
- 基础切片:和Python列表一样,
[start:end:step]。 - 高级索引:这是新手容易懵的地方。我们可以直接传入下标列表来获取特定位置的元素。
import torch
# 创建一个5x5的随机张量
torch.manual_seed(24)
t1 = torch.randint(1, 10, (5, 5))
print("原始张量:\n", t1)
# 需求:获取(0,1)和(1,2)位置的元素
print("特定位置元素:", t1[[0, 1], [1, 2]])
# 布尔索引:筛选数据的灵魂
# 需求:获取第3列大于5的所有行
mask = t1[:, 2] > 5
print("第3列大于5的行:\n", t1[mask])
2. 形状变换:Reshape vs View vs Permute
这三个函数经常让人混淆,我为你总结了一个避坑表格:
| 函数 | 作用 | 核心区别 | 适用场景 |
|---|---|---|---|
reshape() | 改变形状 | 最通用,PyTorch会自动处理连续性问题 | 日常开发首选,简单粗暴 |
view() | 改变形状 | 仅限连续张量,如果张量经过transpose等操作变得不连续,view会报错 | 内存连续时使用,速度极快 |
permute() | 交换维度 | 可以一次性交换多个维度,常用于图像通道变换 | 图像处理(CV)中 HWC转CHW |
实战演示:
# 1. Reshape: 随意变形
t_flat = torch.arange(6)
t_reshape = t_flat.reshape(2, 3) # 变成2行3列
# 2. Permute: 维度大挪移 (常用于图片)
# 假设我们有一个 BCHW 格式的张量 (Batch, Channel, Height, Width)
img = torch.randn(10, 3, 224, 224)
# 如果想转成 BHWC (TensorFlow格式)
img_tf = img.permute(0, 2, 3, 1)
print("TensorFlow格式形状:", img_tf.shape) # [10, 224, 224, 3]
# 3. View的陷阱与Contiguous修复
t_trans = t_reshape.transpose(0, 1) # 转置后,内存可能不连续
print("转置后是否连续:", t_trans.is_contiguous()) # False
# 错误写法: t_trans.view(-1) 会报错
# 正确写法: 先转连续,再view
t_fixed = t_trans.contiguous().view(-1)
print("修复后的扁平化张量:", t_fixed)
🔗 第二章:张量的“合体技”——拼接与拆分
在构建复杂网络(如ResNet、Transformer)时,我们经常需要把不同的特征图拼在一起。
1. Cat vs Stack:一字之差,天地之别
torch.cat():缝合怪。维度不变,只是把数据“接”起来。比如两个(2,3)的矩阵按行拼接,变成(4,3)。torch.stack():生维者。维度+1。它会在指定位置“创造”一个新的维度。比如两个(2,3)的矩阵stack,会变成(2,2,3)。
t1 = torch.randint(1, 10, (2, 3))
t2 = torch.randint(1, 10, (2, 3))
# Cat: 按行拼接 (dim=0)
cat_result = torch.cat([t1, t2], dim=0)
print("Cat结果形状:", cat_result.shape) # torch.Size([4, 3])
# Stack: 堆叠 (dim=0)
stack_result = torch.stack([t1, t2], dim=0)
print("Stack结果形状:", stack_result.shape) # torch.Size([2, 2, 3])
# 解读:现在有了2个“样本”,每个样本是2x3的矩阵
🧮 第三章:数学运算的“阴阳两面”
张量的运算分为点乘(Element-wise) 和矩阵乘(Matrix) ,以及基础的加减乘除。
1. 点乘(*) vs 矩阵乘(@ / matmul)
- 点乘 (
*) :夫妻双双把家还。要求形状完全一致(或可广播),对应位置相乘。 - 矩阵乘 (
@) :线性代数的灵魂。 A(m,n)×B(n,p)=C(m,p)A(m,n)×B(n,p)=C(m,p) 。这是神经网络全连接层的核心。
# 点乘:对应元素相乘
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
print("点乘结果:\n", a * b)
# [[1*5, 2*6], [3*7, 4*8]]
# 矩阵乘:线性变换
mat_result = torch.matmul(a, b)
# 或者 a @ b
print("矩阵乘结果:\n", mat_result)
# [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
2. 聚合运算:Sum, Mean, Max
这些函数有一个非常重要的参数 dim,它决定了“压缩”哪个维度。
dim=0:跨行操作,即把行“压”没,按列计算。dim=1:跨列操作,即把列“压”没,按行计算。
data = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print("按列求和(dim=0):", data.sum(dim=0)) # [5., 7., 9.]
print("按行求和(dim=1):", data.sum(dim=1)) # [6., 15.]
⚡ 第四章:PyTorch的“杀手锏”——自动微分 (Autograd)
这是深度学习框架最核心的功能。为什么要学它? 因为如果不理解它,你永远不知道梯度为什么会消失,权重是如何更新的。
1. 核心概念:requires_grad
只有当你想优化一个参数时,才需要开启自动微分。
# 定义权重w,这是我们要优化的变量
w = torch.tensor(10.0, requires_grad=True)
# 注意:必须是浮点数,且 requires_grad=True
# 定义损失函数 Loss = w^2 + 20
# 我们的目标是找到w,让Loss最小(显然是w=0时)
loss = w ** 2 + 20
2. 反向传播:backward()
这是魔法发生的时刻。PyTorch会自动根据loss计算出w的梯度(导数)。
# 执行反向传播
# 注意:loss必须是标量(只有一个数值)
loss.backward()
# 查看梯度
print("w的梯度:", w.grad)
# 输出:20.0
# 为什么是20?因为 Loss = w^2 的导数是 2w, 当w=10时,梯度=2*10=20
3. 梯度下降:更新参数
拿到了梯度,我们就可以更新权重了。
# 学习率 (Learning Rate)
lr = 0.01
# 梯度下降公式:w_new = w_old - lr * gradient
# 注意:我们操作的是 w.data,而不是创建一个新的w
# 因为创建新w会断开计算图
w.data = w.data - lr * w.grad
print("更新后的w:", w) # 10 - 0.01*20 = 9.8
4. 完整的训练循环(避坑重点)
在实际训练中,梯度是会累加的!所以每次循环必须清零。
w = torch.tensor(10.0, requires_grad=True)
lr = 0.01
for i in range(100):
# 1. 前向传播 (计算Loss)
loss = w ** 2 + 20
# 2. 梯度清零 (关键!否则梯度会累加)
# 如果w.grad是None,第一次运行会报错,所以要判断
if w.grad is not None:
w.grad.zero_()
# 3. 反向传播
loss.backward()
# 4. 更新参数
with torch.no_grad(): # 暂时关闭梯度计算,提高效率并防止内存泄漏
w.data -= lr * w.grad
print(f"训练结束,w趋近于0,值为: {w.item():.4f}")
🔄 第五章:与外界的桥梁——NumPy互转
深度学习往往需要结合传统的数据处理(如OpenCV、Pandas),这就涉及Tensor和NumPy的转换。
核心警告:内存共享机制
torch.from_numpy():共享内存。修改NumPy,Tensor也会变。torch.tensor():拷贝数据。两者互不相干。
import numpy as np
np_data = np.array([1, 2, 3])
# 方式1: 共享内存 (高效,但要注意数据一致性)
t1 = torch.from_numpy(np_data)
# 方式2: 拷贝 (安全,但耗内存)
t2 = torch.tensor(np_data)
# 特殊情况:如何从Tensor取值?
scalar_tensor = torch.tensor(3.14)
# 错误:不能直接当float用
# 正确:使用 .item()
float_pi = scalar_tensor.item()
print(type(float_pi)) # <class 'float'>
📝 结语
恭喜你读到这里!🎉
今天我们进行了一场深度学习底层的“硬核”之旅。我们不仅学会了如何用reshape和permute玩转数据形状,掌握了cat和stack的拼接艺术,更重要的是,我们揭开了自动微分(Autograd) 的神秘面纱,理解了神经网络是如何通过梯度下降来“学习”的。
最后的叮嘱:
- 多动手:代码敲一遍胜过看十遍,特别是
dim参数和view的连续性问题。 - 理解原理:不要死记硬背
w.grad.zero_(),要明白为什么要清零。
如果你觉得这篇长文对你有帮助,请务必点赞、收藏,并关注我。在2026年的AI征途中,让我们一起进阶!有任何疑问,欢迎在评论区留言,我会一一解答。💬