一、张量拼接与切分
1.1 torch.cat()
torch.cat(tensors,
dim=0,
out=None)
功能:将张量按维度dim进行拼接
- tensors:张量序列
- dim:要拼接的维度
a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([[5, 6], [7, 8]])
c = torch.cat([a, b], 0)
d = torch.cat([a, b], 1)
print(a)
print(a.shape)
print(b)
print(b.shape)
print(c)
print(c.shape)
print(d)
print(d.shape)
tensor([[1., 2.],
[3., 4.]])
torch.Size([2, 2])
tensor([[5., 6.],
[7., 8.]])
torch.Size([2, 2])
tensor([[1., 2.],
[3., 4.],
[5., 6.],
[7., 8.]])
torch.Size([4, 2])
tensor([[1., 2., 5., 6.],
[3., 4., 7., 8.]])
torch.Size([2, 4])
1.2 torch.stack()
torch.stack(tensors,
dim=0,
out=None)
功能:在新创建的维度dim上进行拼接
- tensors:张量序列
- dim:要拼接的维度
a = torch.Tensor([[1, 2], [3, 4]])
b = torch.Tensor([[5, 6], [7, 8]])
e = torch.stack([a, b], 0)
f = torch.stack([a, b], 1)
g = torch.stack([a, b], 2)
print(a)
print(a.shape)
print(b)
print(b.shape)
print(e)
print(e.shape)
print(f)
print(f.shape)
print(g)
print(g.shape)
tensor([[1., 2.],
[3., 4.]])
torch.Size([2, 2])
tensor([[5., 6.],
[7., 8.]])
torch.Size([2, 2])
tensor([[[1., 2.],
[3., 4.]],
[[5., 6.],
[7., 8.]]])
torch.Size([2, 2, 2])
tensor([[[1., 2.],
[5., 6.]],
[[3., 4.],
[7., 8.]]])
torch.Size([2, 2, 2])
tensor([[[1., 5.],
[2., 6.]],
[[3., 7.],
[4., 8.]]])
torch.Size([2, 2, 2])
1.3 torch.chunk()
torch.chunk(input,
chunks,
dim=0)
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
- input:要切分的张量
- chunks:要切分的份数
- dim:要切分的维度
a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
print(a.shape)
# 使用chunks,沿着第0维进行分块,一共分两块,因此分割成两个1x3的Tensor
b = torch.chunk(a, 2, 0)
print(b)
# 沿着第1维进行分块,因此分割成两个Tensor,当不能整除时,最后一个的维数会小于前面的
# 因此第一个Tensor为2x2,第二个为2x1
c = torch.chunk(a, 2, 1)
print(c)
tensor([[1., 2., 3.],
[4., 5., 6.]])
torch.Size([2, 3])
(tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]))
(tensor([[1., 2.],
[4., 5.]]), tensor([[3.],
[6.]]))
1.4 torch.split()
torch.split(tensor,
split_size_or_sections,
dim=0)
功能:将张量按维度dim进行切分
返回值:张量列表
- tensor:要切分的张量
- split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
- dim:要切分的维度
a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
print(a.shape)
# 使用split,沿着第0维分块,每一块维度为2,由于第一维维度总共为2,因此相当于没有分割
b = torch.split(a, 2, 0)
print(b)
# 沿着第1维分块,每一块维度为2,因此第一个Tensor为2x2,第二个为2x1
c = torch.split(a, 2, 1)
print(c)
# split也可以根据属于的list进行自动分块,list中的元素代表了每一块占的维度
d = torch.split(a, [1, 2], 1)
print(d)
tensor([[1., 2., 3.],
[4., 5., 6.]])
torch.Size([2, 3])
(tensor([[1., 2., 3.],
[4., 5., 6.]]),)
(tensor([[1., 2.],
[4., 5.]]), tensor([[3.],
[6.]]))
(tensor([[1.],
[4.]]), tensor([[2., 3.],
[5., 6.]]))
二、张量索引
2.1 torch.index_select()
torch.index_select(input,
dim,
index,
out=None)
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
- input:要索引的张量
- dim:要索引的维度
- 要索引数据的序号
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)
print(t)
print(t_select)
tensor([[1, 4, 0],
[4, 8, 0],
[6, 6, 8]])
tensor([[1, 4, 0],
[6, 6, 8]])
2.2 torch.masked_select()
torch.masked_select(input,
mask,
out=None)
功能:按mask中的True进行索引
返回值:一维张量
- input:要索引的张量
- mask:与input同形状的布尔类型张量
t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5)
t_select = torch.masked_select(t, mask)
print(t)
print(mask)
print(t_select)
tensor([[6, 5, 8],
[7, 6, 7],
[1, 3, 3]])
tensor([[ True, True, True],
[ True, True, True],
[False, False, False]])
tensor([6, 5, 8, 7, 6, 7])
三、张量变换
3.1 torch.reshape()
torch.reshape(input,
shape)
功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存
- input:要变换的张量
- shape:新张量的形状
t = torch.randperm(8)
print(t.shape)
t_reshape = torch.reshape(t,(2,4))
print(t_reshape.shape)
torch.Size([8])
torch.Size([2, 4])
3.2 torch.transpose()
torch.transpose(input,
dim0,
dim1)
功能:交换张量的两个维度
- input:要变换的张量
- dim0:要交换的维度
- dim1:要交换的维度
t = torch.rand((2, 3, 4))
t_transpose = torch.transpose(t, dim0=1, dim1=2)
print(t.shape)
print(t_transpose.shape)
torch.Size([2, 3, 4])
torch.Size([2, 4, 3])
3.3 torch.t()
torch.t(input)
功能:2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)
3.4 torch.squeeze()
torch.squeeze(input,
dim=None,
out=None)
功能:压缩长度为1的维度(轴)
- dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;
t = torch.rand((1, 2, 3, 1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)
torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])
3.5 torch.unsqueeze()
torch.unsqueeze(input,
dim,
out=None)
功能:依据dim扩展维度
- dim:扩展的维度