我们来完整、系统、通俗地讲清楚 PyTorch 张量的所有常用操作,包括但不限于:
✅
unsqueeze/squeeze→ 加维度 / 减维度
✅view/reshape→ 变形(数据不变)
✅transpose/permute→ 维度交换 / 任意重排
✅contiguous→ 内存整理
✅ 切片t[:, 0::2]→ 隔着取
✅*/@/matmul→ 乘法三兄弟
✅expand/repeat→ 扩展 / 重复
✅cat/stack→ 拼接
✅chunk/split→ 拆分
✅index_select/gather→ 高级索引
🧩 总览图(建议收藏!)
张量操作
├── 形状变换
│ ├── unsqueeze(dim) → 加1维
│ ├── squeeze(dim) → 减1维(只减size=1的)
│ ├── view(*shape) → 重构形状(数据不变)
│ ├── reshape(*shape) → 同view,更安全
│ ├── transpose(d0,d1) → 交换两维
│ ├── permute(*dims) → 任意重排维度
│ └── contiguous() → 整理内存(为view服务)
│
├── 切片 & 索引
│ ├── t[start:end:step] → 基础切片
│ ├── t[:, 0::2] → 隔一个取
│ ├── index_select → 按索引选
│ └── gather → 高级收集
│
├── 数学运算
│ ├── a * b → 逐元素乘
│ ├── a @ b / matmul → 矩阵乘
│ ├── add, sub, div, pow → 四则运算
│ └── sum, mean, max, min → 聚合
│
├── 扩展 & 重复
│ ├── expand(*sizes) → 广播扩展(不复制数据)
│ └── repeat(*sizes) → 重复复制(真复制)
│
├── 拼接 & 拆分
│ ├── cat(tensors, dim) → 沿dim拼接
│ ├── stack(tensors, dim) → 新建dim拼接
│ ├── chunk(chunks, dim) → 均分
│ └── split(sizes, dim) → 按指定大小分
│
└── 其他
├── flatten → 展平
├── numel() → 总元素数
└── is_contiguous() → 是否内存连续
🧱 一、形状变换操作
1. unsqueeze(dim) —— “加壳”
x = torch.tensor([1, 2, 3]) # [3]
y = x.unsqueeze(0) # [1, 3]
z = x.unsqueeze(1) # [3, 1]
→ 常用于:增加 batch 维、通道维
2. squeeze(dim) —— “脱壳”
squeeze 含义是 挤压
x = torch.tensor([[[1, 2, 3]]]) # [1, 1, 3]
y = x.squeeze(0) # [1, 3] ← 脱第0维
z = x.squeeze() # [3] ← 脱所有size=1的维
→ 注意:只对 size=1 的维度有效!
x = torch.randn(2, 1, 3)
y = x.squeeze(1) # ✅ [2, 3]
z = x.squeeze(0) # ❌ 第0维size=2,squeeze无效 → 还是 [2, 1, 3]
3. view(*shape) / reshape(*shape) —— “重新装盒”
x = torch.arange(6) # [6] → [0,1,2,3,4,5]
y = x.view(2, 3) # [2, 3]
# [[0, 1, 2],
# [3, 4, 5]]
z = x.reshape(3, 2) # [3, 2] ← reshape更安全
→ reshape 是 view 的安全版(自动处理不连续内存)
4. transpose(dim0, dim1) —— “交换两层书架”
transpose 含义是 转置
x = torch.randn(2, 3, 4) # [batch, height, width]
y = x.transpose(1, 2) # [batch, width, height]
→ 常用于:图像 [B,C,H,W] → [B,H,W,C]
5. permute(*dims) —— “任意重排书架层”
permute 含义是 置换
x = torch.randn(2, 3, 4, 5) # [A, B, C, D]
y = x.permute(3, 0, 2, 1) # [D, A, C, B]
→ 比 transpose 更灵活(可重排任意多维)
6. contiguous() —— “整理书架”
contiguous 含义是 相邻的
x = torch.randn(2, 3, 4)
y = x.transpose(1, 2) # 内存不连续
z = y.contiguous().view(2, 12) # ✅ 先整理,再变形
→ 什么时候需要?
transpose/permute后想用view- 报错:
view size is not compatible...
✂️ 二、切片 & 索引操作
1. 基础切片 t[start:end:step]
t = torch.tensor([0,1,2,3,4,5,6,7,8,9])
t[1:5] # [1,2,3,4] ← 从1到5前
t[::2] # [0,2,4,6,8] ← 从0开始,步长2
t[1::2] # [1,3,5,7,9] ← 从1开始,步长2
t[::-1] # [9,8,7,6,5,4,3,2,1,0] ← 反转
2. 多维切片 t[:, 0::2]
t = torch.tensor([[1,2,3,4],
[5,6,7,8]])
t[:, ::2] # 所有行,列步长2 → [[1,3], [5,7]]
t[0, 1:3] # 第0行,第1~3列 → [2,3]
3. index_select(dim, index) —— “按名单选人”
x = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9]])
index = torch.tensor([0, 2])
y = torch.index_select(x, 0, index) # 按第0维(行)选第0和第2行
# [[1,2,3],
# [7,8,9]]
4. gather(dim, index) —— “按位置收集”
x = torch.tensor([[1,2,3],
[4,5,6]])
index = torch.tensor([[0, 1],
[2, 0]])
y = torch.gather(x, 1, index) # 按第1维(列)收集
# 第0行:取第0列(1)、第1列(2) → [1,2]
# 第1行:取第2列(6)、第0列(4) → [6,4]
# 结果:[[1,2],
# [6,4]]
→ 常用于:根据预测索引收集概率
🧮 三、数学运算
1. a * b —— 逐元素乘
a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])
c = a * b # [4,10,18]
2. a @ b / torch.matmul(a, b) —— 矩阵乘
a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = a @ b # [2, 4]
3. 四则运算
a.add(b) # a + b
a.sub(b) # a - b
a.mul(b) # a * b
a.div(b) # a / b
a.pow(2) # a²
4. 聚合操作
x = torch.tensor([[1,2,3],
[4,5,6]])
x.sum() # 21
x.sum(dim=0) # [5,7,9] ← 按行求和
x.mean(dim=1) # [2,5] ← 按列求平均
x.max(dim=1) # (values=[3,6], indices=[2,2])
📦 四、扩展 & 重复
1. expand(*sizes) —— “广播扩展”(不复制数据)
x = torch.tensor([1, 2, 3]) # [3]
y = x.expand(2, 3) # [2, 3] ← 广播
# [[1,2,3],
# [1,2,3]] ← 数据没复制,只是“假装”有两行
→ 内存高效,但不能 view
2. repeat(*sizes) —— “真重复”(复制数据)
x = torch.tensor([1, 2, 3]) # [3]
y = x.repeat(2, 1) # [2, 3] ← 真复制
# [[1,2,3],
# [1,2,3]] ← 数据真被复制了
→ 可以 view,但占内存
🧩 五、拼接 & 拆分
1. torch.cat(tensors, dim) —— “沿现有维度拼接”
a = torch.tensor([[1,2]])
b = torch.tensor([[3,4]])
c = torch.cat([a, b], dim=0) # [2, 2]
# [[1,2],
# [3,4]]
2. torch.stack(tensors, dim) —— “新建维度拼接”
a = torch.tensor([1,2])
b = torch.tensor([3,4])
c = torch.stack([a, b], dim=0) # [2, 2]
# [[1,2],
# [3,4]]
d = torch.stack([a, b], dim=1) # [2, 2]
# [[1,3],
# [2,4]]
3. chunk(chunks, dim) —— “均分”
x = torch.arange(10) # [10]
y = torch.chunk(x, 3, dim=0) # 分3块 → [4], [3], [3]
4. split(sizes, dim) —— “按指定大小分”
x = torch.arange(10)
y = torch.split(x, [2,3,5], dim=0) # 分三块:大小2,3,5
🧰 六、其他实用操作
1. flatten(start_dim, end_dim) —— “展平”
x = torch.randn(2, 3, 4)
y = x.flatten(1) # [2, 12] ← 从第1维开始展平
2. numel() —— “总元素数”
x = torch.randn(2, 3, 4)
print(x.numel()) # 24
3. is_contiguous() —— “是否内存连续”
x = torch.randn(2, 3)
y = x.transpose(0, 1)
print(y.is_contiguous()) # False
4. clone() —— “深拷贝”
x = torch.tensor([1,2,3])
y = x.clone() # y 是 x 的副本,修改y不影响x
5. detach() —— “脱离计算图”
x = torch.tensor([1.0], requires_grad=True)
y = x.detach() # y 不参与梯度计算
✅ 总结卡片:
| 操作类型 | 核心函数 | 用途 | 示例 |
|---|---|---|---|
| 加/减维度 | unsqueeze, squeeze | 调整维度数 | x.unsqueeze(0) |
| 重构形状 | view, reshape | 改变形状 | x.view(2,3) |
| 维度重排 | transpose, permute | 交换/重排维度 | x.permute(2,0,1) |
| 切片 | t[start:end:step] | 提取子集 | t[:, ::2] |
| 索引 | index_select, gather | 高级选择 | torch.gather(...) |
| 数学 | *, @, sum, mean | 计算 | a @ b |
| 扩展 | expand, repeat | 广播/复制 | x.repeat(2,1) |
| 拼接 | cat, stack | 合并张量 | torch.cat([a,b]) |
| 拆分 | chunk, split | 分割张量 | torch.split(x, [2,3]) |
| 工具 | contiguous, flatten, numel | 辅助操作 | x.contiguous().view(...) |
🧠 记忆口诀:
“加壳unsqueeze,脱壳squeeze;
view改形状,permute任重排;
切片步长取,gather按位收;
cat沿维拼,stack新建轴;
expand假扩展,repeat真复制;
不连续要整理,contiguous来救急!”
现在你掌握了 PyTorch 张量操作的“完整兵器库”!
这些是构建任何深度学习模型的基石 —— 熟练使用,你就能像搭乐高一样组合出任何复杂结构 🧱🚀