张量操作

401 阅读3分钟

一、张量拼接与切分

1.1 torch.cat()

torch.cat(tensors,
        dim=0,
        out=None)

功能:将张量按维度dim进行拼接

  • tensors:张量序列
  • dim:要拼接的维度
a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([[5, 6], [7, 8]])
c = torch.cat([a, b], 0)
d = torch.cat([a, b], 1)
print(a)
print(a.shape)
print(b)
print(b.shape)
print(c)
print(c.shape)
print(d)
print(d.shape)
tensor([[1., 2.],
        [3., 4.]])
torch.Size([2, 2])
tensor([[5., 6.],
        [7., 8.]])
torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
torch.Size([4, 2])
tensor([[1., 2., 5., 6.],
        [3., 4., 7., 8.]])
torch.Size([2, 4])

1.2 torch.stack()

torch.stack(tensors,
        dim=0,
        out=None)

功能:在新创建的维度dim上进行拼接

  • tensors:张量序列
  • dim:要拼接的维度
a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([[5, 6], [7, 8]])
e = torch.stack([a, b], 0)
f = torch.stack([a, b], 1)
g = torch.stack([a, b], 2)
print(a)
print(a.shape)
print(b)
print(b.shape)
print(e)
print(e.shape)
print(f)
print(f.shape)
print(g)
print(g.shape)
tensor([[1., 2.],
        [3., 4.]])
torch.Size([2, 2])
tensor([[5., 6.],
        [7., 8.]])
torch.Size([2, 2])
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
torch.Size([2, 2, 2])
tensor([[[1., 2.],
         [5., 6.]],

        [[3., 4.],
         [7., 8.]]])
torch.Size([2, 2, 2])
tensor([[[1., 5.],
         [2., 6.]],

        [[3., 7.],
         [4., 8.]]])
torch.Size([2, 2, 2])

1.3 torch.chunk()

torch.chunk(input,
            chunks,
            dim=0)

功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量

  • input:要切分的张量
  • chunks:要切分的份数
  • dim:要切分的维度
a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
print(a.shape)
# 使用chunks,沿着第0维进行分块,一共分两块,因此分割成两个1x3的Tensor
b = torch.chunk(a, 2, 0)
print(b)
# 沿着第1维进行分块,因此分割成两个Tensor,当不能整除时,最后一个的维数会小于前面的
# 因此第一个Tensor为2x2,第二个为2x1
c = torch.chunk(a, 2, 1)
print(c)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
torch.Size([2, 3])
(tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]))
(tensor([[1., 2.],
        [4., 5.]]), tensor([[3.],
        [6.]]))

1.4 torch.split()

torch.split(tensor,
            split_size_or_sections,
            dim=0)

功能:将张量按维度dim进行切分
返回值:张量列表

  • tensor:要切分的张量
  • split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
  • dim:要切分的维度
a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
print(a.shape)
# 使用split,沿着第0维分块,每一块维度为2,由于第一维维度总共为2,因此相当于没有分割
b = torch.split(a, 2, 0)
print(b)
# 沿着第1维分块,每一块维度为2,因此第一个Tensor为2x2,第二个为2x1
c = torch.split(a, 2, 1)
print(c)
# split也可以根据属于的list进行自动分块,list中的元素代表了每一块占的维度
d = torch.split(a, [1, 2], 1)
print(d)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
torch.Size([2, 3])
(tensor([[1., 2., 3.],
        [4., 5., 6.]]),)
(tensor([[1., 2.],
        [4., 5.]]), tensor([[3.],
        [6.]]))
(tensor([[1.],
        [4.]]), tensor([[2., 3.],
        [5., 6.]]))

二、张量索引

2.1 torch.index_select()

torch.index_select(input,
                    dim,
                    index,
                    out=None)

功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量

  • input:要索引的张量
  • dim:要索引的维度
  • 要索引数据的序号
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)
print(t)
print(t_select)
tensor([[1, 4, 0],
        [4, 8, 0],
        [6, 6, 8]])
tensor([[1, 4, 0],
        [6, 6, 8]])

2.2 torch.masked_select()

torch.masked_select(input,
                    mask,
                    out=None)

功能:按mask中的True进行索引
返回值:一维张量

  • input:要索引的张量
  • mask:与input同形状的布尔类型张量
t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5)
t_select = torch.masked_select(t, mask)
print(t)
print(mask)
print(t_select)
tensor([[6, 5, 8],
        [7, 6, 7],
        [1, 3, 3]])
tensor([[ True,  True,  True],
        [ True,  True,  True],
        [False, False, False]])
tensor([6, 5, 8, 7, 6, 7])

三、张量变换

3.1 torch.reshape()

torch.reshape(input,
                shape)

功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存

  • input:要变换的张量
  • shape:新张量的形状
t = torch.randperm(8)
print(t.shape)
t_reshape = torch.reshape(t,(2,4))
print(t_reshape.shape)
torch.Size([8])
torch.Size([2, 4])

3.2 torch.transpose()

torch.transpose(input,
                dim0,
                dim1)

功能:交换张量的两个维度

  • input:要变换的张量
  • dim0:要交换的维度
  • dim1:要交换的维度
t = torch.rand((2, 3, 4))
t_transpose = torch.transpose(t, dim0=1, dim1=2)
print(t.shape)
print(t_transpose.shape)
torch.Size([2, 3, 4])
torch.Size([2, 4, 3])

3.3 torch.t()

torch.t(input)

功能:2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)

3.4 torch.squeeze()

torch.squeeze(input,
                dim=None,
                out=None)

功能:压缩长度为1的维度(轴)

  • dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;
t = torch.rand((1, 2, 3, 1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)
torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])

3.5 torch.unsqueeze()

torch.unsqueeze(input,
                dim,
                out=None)

功能:依据dim扩展维度

  • dim:扩展的维度