PyTorch 张量操作的“完整兵器库”

76 阅读5分钟

我们来完整、系统、通俗地讲清楚 PyTorch 张量的所有常用操作,包括但不限于:

unsqueeze / squeeze → 加维度 / 减维度
view / reshape → 变形(数据不变)
transpose / permute → 维度交换 / 任意重排
contiguous → 内存整理
✅ 切片 t[:, 0::2] → 隔着取
* / @ / matmul → 乘法三兄弟
expand / repeat → 扩展 / 重复
cat / stack → 拼接
chunk / split → 拆分
index_select / gather → 高级索引


🧩 总览图(建议收藏!)

张量操作
├── 形状变换
│   ├── unsqueeze(dim)   → 加1维
│   ├── squeeze(dim)     → 减1维(只减size=1的)
│   ├── view(*shape)     → 重构形状(数据不变)
│   ├── reshape(*shape)  → 同view,更安全
│   ├── transpose(d0,d1) → 交换两维
│   ├── permute(*dims)   → 任意重排维度
│   └── contiguous()     → 整理内存(为view服务)
│
├── 切片 & 索引
│   ├── t[start:end:step] → 基础切片
│   ├── t[:, 0::2]        → 隔一个取
│   ├── index_select      → 按索引选
│   └── gather            → 高级收集
│
├── 数学运算
│   ├── a * b             → 逐元素乘
│   ├── a @ b / matmul    → 矩阵乘
│   ├── add, sub, div, pow → 四则运算
│   └── sum, mean, max, min → 聚合
│
├── 扩展 & 重复
│   ├── expand(*sizes)    → 广播扩展(不复制数据)
│   └── repeat(*sizes)    → 重复复制(真复制)
│
├── 拼接 & 拆分
│   ├── cat(tensors, dim) → 沿dim拼接
│   ├── stack(tensors, dim) → 新建dim拼接
│   ├── chunk(chunks, dim) → 均分
│   └── split(sizes, dim)  → 按指定大小分
│
└── 其他
    ├── flatten           → 展平
    ├── numel()           → 总元素数
    └── is_contiguous()   → 是否内存连续

🧱 一、形状变换操作

1. unsqueeze(dim) —— “加壳”

x = torch.tensor([1, 2, 3])        # [3]
y = x.unsqueeze(0)                 # [1, 3]
z = x.unsqueeze(1)                 # [3, 1]

→ 常用于:增加 batch 维、通道维


2. squeeze(dim) —— “脱壳”

squeeze 含义是 挤压

x = torch.tensor([[[1, 2, 3]]])    # [1, 1, 3]
y = x.squeeze(0)                   # [1, 3] ← 脱第0维
z = x.squeeze()                    # [3] ← 脱所有size=1的维

→ 注意:只对 size=1 的维度有效!

x = torch.randn(2, 1, 3)
y = x.squeeze(1)  # ✅ [2, 3]
z = x.squeeze(0)  # ❌ 第0维size=2,squeeze无效 → 还是 [2, 1, 3]

3. view(*shape) / reshape(*shape) —— “重新装盒”

x = torch.arange(6)                # [6] → [0,1,2,3,4,5]
y = x.view(2, 3)                   # [2, 3]
# [[0, 1, 2],
#  [3, 4, 5]]

z = x.reshape(3, 2)                # [3, 2] ← reshape更安全

reshapeview 的安全版(自动处理不连续内存)


4. transpose(dim0, dim1) —— “交换两层书架”

transpose 含义是 转置

x = torch.randn(2, 3, 4)           # [batch, height, width]
y = x.transpose(1, 2)              # [batch, width, height]

→ 常用于:图像 [B,C,H,W] → [B,H,W,C]


5. permute(*dims) —— “任意重排书架层”

permute 含义是 置换

x = torch.randn(2, 3, 4, 5)        # [A, B, C, D]
y = x.permute(3, 0, 2, 1)          # [D, A, C, B]

→ 比 transpose 更灵活(可重排任意多维)


6. contiguous() —— “整理书架”

contiguous 含义是 相邻的

x = torch.randn(2, 3, 4)
y = x.transpose(1, 2)              # 内存不连续
z = y.contiguous().view(2, 12)     # ✅ 先整理,再变形

什么时候需要?

  • transpose / permute 后想用 view
  • 报错:view size is not compatible...

✂️ 二、切片 & 索引操作

1. 基础切片 t[start:end:step]

t = torch.tensor([0,1,2,3,4,5,6,7,8,9])

t[1:5]     # [1,2,3,4] ← 从1到5前
t[::2]     # [0,2,4,6,8] ← 从0开始,步长2
t[1::2]    # [1,3,5,7,9] ← 从1开始,步长2
t[::-1]    # [9,8,7,6,5,4,3,2,1,0] ← 反转

2. 多维切片 t[:, 0::2]

t = torch.tensor([[1,2,3,4],
                  [5,6,7,8]])

t[:, ::2]  # 所有行,列步长2 → [[1,3], [5,7]]
t[0, 1:3]  # 第0行,第1~3列 → [2,3]

3. index_select(dim, index) —— “按名单选人”

x = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])

index = torch.tensor([0, 2])
y = torch.index_select(x, 0, index)  # 按第0维(行)选第0和第2行
# [[1,2,3],
#  [7,8,9]]

4. gather(dim, index) —— “按位置收集”

x = torch.tensor([[1,2,3],
                  [4,5,6]])

index = torch.tensor([[0, 1],
                      [2, 0]])

y = torch.gather(x, 1, index)  # 按第1维(列)收集
# 第0行:取第0列(1)、第1列(2) → [1,2]
# 第1行:取第2列(6)、第0列(4) → [6,4]
# 结果:[[1,2],
#        [6,4]]

→ 常用于:根据预测索引收集概率


🧮 三、数学运算

1. a * b —— 逐元素乘

a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])
c = a * b  # [4,10,18]

2. a @ b / torch.matmul(a, b) —— 矩阵乘

a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = a @ b  # [2, 4]

3. 四则运算

a.add(b)   # a + b
a.sub(b)   # a - b
a.mul(b)   # a * b
a.div(b)   # a / b
a.pow(2)   # a²

4. 聚合操作

x = torch.tensor([[1,2,3],
                  [4,5,6]])

x.sum()        # 21
x.sum(dim=0)   # [5,7,9] ← 按行求和
x.mean(dim=1)  # [2,5] ← 按列求平均
x.max(dim=1)   # (values=[3,6], indices=[2,2])

📦 四、扩展 & 重复

1. expand(*sizes) —— “广播扩展”(不复制数据)

x = torch.tensor([1, 2, 3])        # [3]
y = x.expand(2, 3)                 # [2, 3] ← 广播
# [[1,2,3],
#  [1,2,3]] ← 数据没复制,只是“假装”有两行

→ 内存高效,但不能 view

2. repeat(*sizes) —— “真重复”(复制数据)

x = torch.tensor([1, 2, 3])        # [3]
y = x.repeat(2, 1)                 # [2, 3] ← 真复制
# [[1,2,3],
#  [1,2,3]] ← 数据真被复制了

→ 可以 view,但占内存


🧩 五、拼接 & 拆分

1. torch.cat(tensors, dim) —— “沿现有维度拼接”

a = torch.tensor([[1,2]])
b = torch.tensor([[3,4]])
c = torch.cat([a, b], dim=0)       # [2, 2]
# [[1,2],
#  [3,4]]

2. torch.stack(tensors, dim) —— “新建维度拼接”

a = torch.tensor([1,2])
b = torch.tensor([3,4])
c = torch.stack([a, b], dim=0)     # [2, 2]
# [[1,2],
#  [3,4]]

d = torch.stack([a, b], dim=1)     # [2, 2]
# [[1,3],
#  [2,4]]

3. chunk(chunks, dim) —— “均分”

x = torch.arange(10)               # [10]
y = torch.chunk(x, 3, dim=0)       # 分3块 → [4], [3], [3]

4. split(sizes, dim) —— “按指定大小分”

x = torch.arange(10)
y = torch.split(x, [2,3,5], dim=0) # 分三块:大小2,3,5

🧰 六、其他实用操作

1. flatten(start_dim, end_dim) —— “展平”

x = torch.randn(2, 3, 4)
y = x.flatten(1)                   # [2, 12] ← 从第1维开始展平

2. numel() —— “总元素数”

x = torch.randn(2, 3, 4)
print(x.numel())                   # 24

3. is_contiguous() —— “是否内存连续”

x = torch.randn(2, 3)
y = x.transpose(0, 1)
print(y.is_contiguous())           # False

4. clone() —— “深拷贝”

x = torch.tensor([1,2,3])
y = x.clone()                      # y 是 x 的副本,修改y不影响x

5. detach() —— “脱离计算图”

x = torch.tensor([1.0], requires_grad=True)
y = x.detach()                     # y 不参与梯度计算

✅ 总结卡片:

操作类型核心函数用途示例
加/减维度unsqueeze, squeeze调整维度数x.unsqueeze(0)
重构形状view, reshape改变形状x.view(2,3)
维度重排transpose, permute交换/重排维度x.permute(2,0,1)
切片t[start:end:step]提取子集t[:, ::2]
索引index_select, gather高级选择torch.gather(...)
数学*, @, sum, mean计算a @ b
扩展expand, repeat广播/复制x.repeat(2,1)
拼接cat, stack合并张量torch.cat([a,b])
拆分chunk, split分割张量torch.split(x, [2,3])
工具contiguous, flatten, numel辅助操作x.contiguous().view(...)

🧠 记忆口诀:

“加壳unsqueeze,脱壳squeeze;
view改形状,permute任重排;
切片步长取,gather按位收;
cat沿维拼,stack新建轴;
expand假扩展,repeat真复制;
不连续要整理,contiguous来救急!”


现在你掌握了 PyTorch 张量操作的“完整兵器库”!

这些是构建任何深度学习模型的基石 —— 熟练使用,你就能像搭乐高一样组合出任何复杂结构 🧱🚀