PyTorch 中常用的数据基础操作

174 阅读1分钟

1.torch.arange()生成一个一维张量

x=torch.arange(12,dtype=torch.float32)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

2.shape 可以返回一个tensor的维度

x.shape

torch.Size([12])

3.reshape 改变一个tensor的维度

x=x.reshape((2,6))

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.]])

4.sum()对张量中的所有元素进行求和,返回一个具有一个元素的张量

sum=x.sum()
tensor(66.)

5.torch.zeros() 生成一个全为0的tensor

6.torch.cat(x,y,dim=0) 将两个张量的维度叠加

x=torch.zeros((2,3))
y=torch.ones((2,4))
out=torch.cat((x,y),dim=1)

tensor([[0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.]])

#dim=1是在列上的维度进行叠加,dim=0实在行上的维度进行叠加

6.广播机制 就比如一一个维度为(1,2)的张量与(2,1)的张量相加 会变成(2,2)维度的

x=torch.arange(2)
y=torch.arange(2).reshape(2,1)
z=x+y

tensor([[0, 1],
        [1, 2]])