PyTorch中的Tensor变形记:快速掌握切分、变形等方法

223 阅读5分钟

在深度学习中,Tensor作为数据的基本表示形式,其操作的灵活性直接决定了模型构建的便捷性和效率。本文将详细介绍PyTorch中Tensor的连接、切分、索引等高级操作方法,帮助读者快速掌握这些实用技巧。

一、Tensor的连接操作

在深度学习中,经常需要将来自不同来源的数据组合在一起,这种操作称为连接。PyTorch提供了多种连接操作方法,主要包括torch.cattorch.stack

1. torch.cat:沿着指定维度拼接Tensor

torch.cat函数用于将多个Tensor沿着指定维度拼接。其基本语法如下:

Python复制

torch.cat(tensors, dim=0, out=None)
  • tensors:需要拼接的Tensor列表。
  • dim:拼接的维度。

例如,对于两个3x3的矩阵AB,沿着第0维(行)拼接:

Python复制

A = torch.ones(3, 3)
B = 2 * torch.ones(3, 3)
C = torch.cat((A, B), dim=0)
print(C)

输出结果为:

复制

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])

如果沿着第1维(列)拼接:

Python复制

D = torch.cat((A, B), dim=1)
print(D)

输出结果为:

复制

tensor([[1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.]])

2. torch.stack:增加维度进行拼接

torch.stack函数用于在指定维度上增加一个新的维度,并将多个Tensor堆叠在一起。其基本语法如下:

Python复制

torch.stack(inputs, dim=0)
  • inputs:需要堆叠的Tensor列表。
  • dim:新维度的方向。

例如,将两个一维TensorAB堆叠成一个二维Tensor:

Python复制

A = torch.arange(0, 4)
B = torch.arange(5, 9)
C = torch.stack((A, B), dim=0)
print(C)

输出结果为:

复制

tensor([[0, 1, 2, 3],
        [5, 6, 7, 8]])

如果在列方向堆叠:

Python复制

D = torch.stack((A, B), dim=1)
print(D)

输出结果为:

复制

tensor([[0, 5],
        [1, 6],
        [2, 7],
        [3, 8]])

二、Tensor的切分操作

与连接操作相对应,切分操作用于将Tensor拆分成多个子Tensor。PyTorch提供了torch.chunktorch.splittorch.unbind三种切分方法。

1. torch.chunk:按份数切分Tensor

torch.chunk函数将Tensor沿着指定维度尽可能平均地切分成若干份。其基本语法如下:

Python复制

torch.chunk(input, chunks, dim=0)
  • input:待切分的Tensor。
  • chunks:切分的份数。
  • dim:切分的维度。

例如,将一个长度为10的Tensor切分成2份:

Python复制

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.chunk(A, 2, dim=0)
print(B)

输出结果为:

(tensor([1, 2, 3, 4, 5]), tensor([ 6, 7, 8, 9, 10]))

如果不能整除,则最后一份会包含剩余的元素:

Python复制

B = torch.chunk(A, 3, dim=0)
print(B)

输出结果为:

(tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10]))

2. torch.split:按大小切分Tensor

torch.split函数将Tensor按照指定的大小进行切分。其基本语法如下:

Python复制

torch.split(tensor, split_size_or_sections, dim=0)
  • tensor:待切分的Tensor。
  • split_size_or_sections:每份的大小或一个列表,表示每份的具体大小。
  • dim:切分的维度。

例如,将一个4x4的Tensor按行切分,每份大小为2:

Python复制

A = torch.rand(4, 4)
B = torch.split(A, 2, dim=0)
print(B)

输出结果为两个2x4的Tensor。

如果使用列表指定每份大小:

Python复制

A = torch.rand(5, 4)
B = torch.split(A, (2, 3), dim=0)
print(B)

输出结果为两个Tensor,大小分别为2x4和3x4。

3. torch.unbind:按维度拆解Tensor

torch.unbind函数用于将Tensor沿着指定维度拆解成多个子Tensor。其基本语法如下:

Python复制

torch.unbind(input, dim=0)
  • input:待拆解的Tensor。
  • dim:拆解的维度。

例如,将一个4x4的Tensor按行拆解:

Python复制

A = torch.arange(0, 16).view(4, 4)
b = torch.unbind(A, dim=0)
print(b)

输出结果为4个长度为4的Tensor。

三、Tensor的索引操作

在某些情况下,我们只需要从Tensor中提取部分数据,而不是进行整体切分。PyTorch提供了torch.index_selecttorch.masked_select两种索引操作方法。

1. torch.index_select:基于索引选择数据

torch.index_select函数用于从Tensor中选择指定索引位置的数据。其基本语法如下:

Python复制

torch.index_select(tensor, dim, index)
  • tensor:待选择的Tensor。
  • dim:选择的维度。
  • index:一个Tensor,表示需要选择的索引位置。

例如,从一个4x4的Tensor中选择第1行和第3行:

Python复制

A = torch.arange(0, 16).view(4, 4)
B = torch.index_select(A, 0, torch.tensor([1, 3]))
print(B)

输出结果为:

复制

tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])

2. torch.masked_select:基于条件选择数据

torch.masked_select函数用于根据条件选择Tensor中的数据。其基本语法如下:

Python复制

torch.masked_select(input, mask)
  • input:待选择的Tensor。
  • mask:一个布尔Tensor,表示选择的条件。

例如,选择一个Tensor中所有大于0.3的元素:

Python复制

A = torch.rand(5)
C = torch.masked_select(A, A > 0.3)
print(C)

输出结果为满足条件的元素组成的Tensor。

四、总结

通过本文的介绍,我们详细学习了PyTorch中Tensor的连接、切分和索引操作。这些操作在深度学习中非常常见,掌握它们可以帮助我们更灵活地处理数据。以下是这些操作的总结:

  • 连接操作torch.cat用于沿着指定维度拼接Tensor,torch.stack用于增加维度进行堆叠。
  • 切分操作torch.chunk按份数切分Tensor,torch.split按大小切分Tensor,torch.unbind按维度拆解Tensor。
  • 索引操作torch.index_select基于索引选择数据,torch.masked_select基于条件选择数据。

在实际应用中,这些操作的参数(如维度和大小)需要根据具体需求仔细计算,以避免错误。希望本文能够帮助读者更好地理解和使用PyTorch中的Tensor操作。