Pytorch 基础之张量合并与分割

829 阅读3分钟

本次介绍一下 Tensor 张量合并与分割常用的一些方法:

1. torch.cat(tensorsdim=0,  *** , out=None) → [Tensor]

方法释义:对除了要合并维度之外,其它维度 shape 都一样的 tensor 序列(数组或列表)进行合并操作

参数介绍:

tensors:要合并的 tensor 列表,shape 除 dim 指定维度外,其它维度必须一样;

dim:具体要合并的维度索引,从 0 开始

示例:

a = torch.rand(3, 4)
b = torch.rand(3, 5)
c = torch.rand(2, 4)
# 对 a, b 两 tensor 在第二个维度进行合并,相当于 3 行 4 列的矩阵 与 3 行 5 列的矩阵,合并成了 3 行 9 列的矩阵
print(torch.cat([a, b], 1).size())  
# 会报错,因为在第一维度上 a, c size 不一致, 和输出结果描述的错误一样
print(torch.cat([a, c], 1).size())   

# 输出结果
torch.Size([3, 9])
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 3 but got size 2 for tensor number 1 in the list.

2. torch.stack(tensorsdim=0,  *** , out=None) → [Tensor]

方法释义:对具有相同 size 的 tensor 序列,在指定维度索引位置插入一个新维度,并合并。

参数介绍:

tensors:相同 size 的 tensor 序列

dim:要插入并合并的索引维度,大小在 [-(dim + 1), dim] 区间

示例:

a = torch.rand(2, 3, 5)
b = torch.rand(2, 3, 5)
# a, b  tensor 分别在维度索引 3 处插入一个维度,即
# a, b 变成 torch.Size([2, 3, 5, 1]),然后再合并成 torch.Size([2, 3, 5, 2])
print(torch.stack([a, b], 3).shape)
# 这个索引 -4 ,是相当于 0 , 负数是从后面 -1 数过来
print(torch.stack([a, b], -4).shape)
# 报错,索引越界了,需要在 [-4, 3] 范围
print(torch.stack([a, b], 4))

# 输出结果
torch.Size([2, 3, 5, 2])
torch.Size([2, 2, 3, 5])
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)

3. torch.split(tensorsplit_size_or_sectionsdim=0) → List[Tensor]

方法释义:将 tensor 拆分成块,每块都与原 tensor 维度保持一致

参数介绍:

tensor:要拆分的 tensor

split_size_or_sections(int or list(int)):单块的 size 或者 list 中指明的各块 size 。

  • 如果是 int 类型,将尽力按照单块 size 去拆分,如果不能整除 size,那么最后一块的 size 可能会小一些;

  • 如果是 list 类型,将拆分成 list 长度的块,每块按照 list 指定的 size 进行拆分。

dim:根据哪个维度索引进行拆分

示例:

a = torch.rand(2, 3, 5)
# 根据第一个维度,每块 size 为 2 进行拆分,因为维度一 size 等于 2,所以只拆分了一块出来
aa = torch.split(a, 2, 0)
print(aa[0].shape)
# 根据维度二,拆分成 2 块,一块维度二 size 等于 1,另一块维度二 size 等于 2
bb = torch.split(a, [1, 2], 1)
print(bb[0].shape)
print(bb[1].shape)
# 报错,因为拆分成 2 块,但这两块的 size 不等于维度二的 3
bb = torch.split(a, [1, 1], 1)

# 输出结果
torch.Size([2, 3, 5])
torch.Size([2, 1, 5])
torch.Size([2, 2, 5])
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 3 (input tensor's size at dimension 1), but got split_sizes=[1, 1]

4. torch.chunk(inputchunksdim=0) → List of Tensors

方法释义:试图将 tensor 拆分为指定的块数,维度和原 tensor 保持一致,拆分出来的块数可能比指定的块数少

参数介绍:

input:要拆分的 tensor

chunks:返回要拆分的块数

dim:根据哪个维度进行拆分

示例:

a = torch.rand(2, 3, 5)
# 将第二维度的 size 拆分成 4 块,由于第二维度为 3,只拆分成了 3 块返回
aa = torch.chunk(a, 4, 1)
print(aa[0].shape)
print(aa[1].shape)
print(aa[2].shape)
# 报错,索引越界了
print(aa[3].shape)

# 输出结果
torch.Size([2, 1, 5])
torch.Size([2, 1, 5])
torch.Size([2, 1, 5])
IndexError: tuple index out of range

5. cat 与 stack 使用场景区分

cat: 更多是在原有的维度上进行合并,适合于一些不需要新增维度的场景,如要将 1 ~ 4 班学生的科目成绩,与 5 ~ 8 班的科目成绩合并统计(学生人数和科目一样)

stack:适用于要新增维度的场景,如上面并没有班级维度的话,在统计时就需要通过 stack 新增一个班级维度进行合并。

总结:这几个 api 是比较经典常用的,官网还有一些新增的,有兴趣的可以去官网查看。