
PyTorch Cheat Sheet for NumPy Users

Original link: github.com


The PyTorch version of the "Torch for Numpy users" cheat sheet: NumPy functions and their PyTorch equivalents.
We assume you are using the latest versions of PyTorch and NumPy.

Types

NumPy        PyTorch
np.ndarray   torch.Tensor (create one with torch.tensor)
np.float32   torch.FloatTensor or torch.tensor(x, dtype=torch.float)
np.float64   torch.DoubleTensor or torch.tensor(x, dtype=torch.double)
np.float16   torch.HalfTensor or torch.tensor(x, dtype=torch.half)
np.int8      torch.CharTensor or torch.tensor(x, dtype=torch.int8)
np.uint8     torch.ByteTensor or torch.tensor(x, dtype=torch.uint8)
np.int16     torch.ShortTensor or torch.tensor(x, dtype=torch.short)
np.int32     torch.IntTensor or torch.tensor(x, dtype=torch.int)
np.int64     torch.LongTensor or torch.tensor(x, dtype=torch.long)
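The sketch below shows how data moves between the two type systems. It assumes NumPy and a recent PyTorch are both installed. Two behaviors are worth noticing: torch.from_numpy keeps the array's dtype and shares its memory, while torch.tensor always copies and lets you override the dtype.

```python
import numpy as np
import torch

a = np.arange(4, dtype=np.float32)

# from_numpy keeps the NumPy dtype (float32 -> torch.float32) and shares memory with `a`
t_shared = torch.from_numpy(a)

# torch.tensor always copies; the dtype can be overridden explicitly
t_double = torch.tensor(a, dtype=torch.double)

# .numpy() converts back (CPU tensors only); float64 maps to np.float64
back = t_double.numpy()

# The change is visible through t_shared (shared buffer) but not through t_double (a copy)
a[0] = 100.0
print(t_shared.dtype, t_double.dtype, back.dtype)
```

Use torch.from_numpy when you want zero-copy interop. Use torch.tensor when the tensor must not change if the source array is later modified.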

Constructors

Ones and zeros

NumPy              PyTorch
np.empty((2, 3))   torch.empty((2, 3))
np.empty_like(x)   torch.empty_like(x)
np.eye             torch.eye
np.identity        torch.eye
np.ones            torch.ones
np.ones_like       torch.ones_like
np.zeros           torch.zeros
np.zeros_like      torch.zeros_like
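The constructors take the same shape arguments in both libraries, but their default floating-point dtypes differ: NumPy defaults to float64 and PyTorch to float32. The sketch below assumes PyTorch's default dtype has not been changed with torch.set_default_dtype.

```python
import numpy as np
import torch

# Same shape tuple works in both libraries
n_zeros = np.zeros((2, 3))
t_zeros = torch.zeros((2, 3))

# Both np.eye(n) and np.identity(n) map to torch.eye(n)
t_eye = torch.eye(3)

# *_like variants copy shape and dtype from an existing tensor
t_ones = torch.ones_like(t_zeros)

# Default dtypes differ: float64 in NumPy, float32 in PyTorch
print(n_zeros.dtype, t_zeros.dtype)
```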

From existing data

NumPy                         PyTorch
np.array([[1, 2], [3, 4]])    torch.tensor([[1, 2], [3, 4]])
x.copy()                      x.clone()
np.fromfile(file)             torch.from_numpy(np.fromfile(file))
np.frombuffer                 torch.frombuffer
np.fromfunction               (none)
np.fromiter                   (none)
np.fromstring                 (none)
np.load                       torch.load (reads PyTorch's own format, not .npy)
np.loadtxt                    (none)
np.concatenate                torch.cat
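The next example creates a tensor from a nested list, takes an independent copy with clone, and concatenates the two along the first dimension. NumPy's result is used as a cross-check. It is a minimal sketch, assuming a recent PyTorch.

```python
import numpy as np
import torch

x = torch.tensor([[1, 2], [3, 4]])   # like np.array([[1, 2], [3, 4]])
y = x.clone()                         # independent copy, like x.copy()
y[0, 0] = 99                          # modifying the copy leaves x untouched

# np.concatenate(..., axis=0) corresponds to torch.cat(..., dim=0)
stacked = torch.cat([x, y], dim=0)
reference = np.concatenate([x.numpy(), y.numpy()], axis=0)
```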

Numerical ranges

NumPy                  PyTorch
np.arange(10)          torch.arange(10)
np.arange(2, 3, 0.1)   torch.arange(2, 3, 0.1)
np.linspace            torch.linspace
np.logspace            torch.logspace
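The range functions behave much like their NumPy counterparts. One difference: np.linspace defaults to 50 points, while PyTorch expects the number of points (steps) to be passed. This sketch always passes steps explicitly and assumes a recent PyTorch.

```python
import numpy as np
import torch

r = torch.arange(10)                 # integer range, like np.arange(10)
f = torch.arange(2, 3, 0.1)          # fractional steps work as in NumPy
lin = torch.linspace(0, 1, steps=5)  # 5 evenly spaced points in [0, 1]
log = torch.logspace(0, 2, steps=3)  # 10**0, 10**1, 10**2
```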

Building matrices

NumPy     PyTorch
np.diag   torch.diag
np.tril   torch.tril
np.triu   torch.triu
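As in NumPy, torch.diag works in both directions: it extracts a diagonal from a 2-D input and builds a diagonal matrix from a 1-D input. The diagonal offset of tril/triu is spelled differently in the two libraries, k in NumPy and diagonal in PyTorch. The sketch checks each result against NumPy.

```python
import numpy as np
import torch

m = torch.arange(1, 10).reshape(3, 3)

d = torch.diag(m)                         # 2-D input -> 1-D diagonal
D = torch.diag(torch.tensor([1, 2, 3]))   # 1-D input -> diagonal matrix
lower = torch.tril(m)                     # zero everything above the main diagonal
upper = torch.triu(m, diagonal=1)         # strictly upper triangle (np.triu(m, k=1))
```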

Attributes

NumPy       PyTorch
x.shape     x.shape
x.strides   x.stride() (counted in elements, not bytes)
x.ndim      x.dim()
x.data      x.data (a tensor detached from autograd, not a raw buffer)
x.size      x.nelement() or x.numel()
x.dtype     x.dtype
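Most attributes match one-to-one, but strides use different units. NumPy reports strides in bytes, while PyTorch reports them in elements. The sketch below compares a float32 array and tensor of the same shape.

```python
import numpy as np
import torch

a = np.zeros((2, 3), dtype=np.float32)
t = torch.zeros((2, 3), dtype=torch.float32)

print(a.strides)             # (12, 4): bytes, since each float32 is 4 bytes
print(t.stride())            # (3, 1): elements
print(a.ndim, t.dim())       # 2 2
print(a.size, t.nelement())  # 6 6
```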

Indexing

NumPy                 PyTorch
x[0]                  x[0]
x[:, 0]               x[:, 0]
x[indices]            x[indices]
np.take(x, indices)   torch.take(x, torch.LongTensor(indices))
x[x != 0]             x[x != 0]
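Basic, integer-list, and boolean-mask indexing all carry over unchanged. torch.take, like np.take called without an axis, indexes into the flattened input. The example below uses a small 2-D tensor.

```python
import numpy as np
import torch

x = torch.tensor([[0, 1, 2], [3, 0, 5]])

row = x[0]            # first row
col = x[:, 0]         # first column
picked = x[[1, 0]]    # integer-list indexing reorders rows

# torch.take treats x as flattened 1-D: [0, 1, 2, 3, 0, 5]
flat = torch.take(x, torch.tensor([0, 5]))

# Boolean mask selects matching elements into a 1-D tensor
nonzero_vals = x[x != 0]
```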

Shape manipulation

NumPy                                      PyTorch
x.reshape                                  x.reshape or x.view
x.resize()                                 x.resize_
(none)                                     x.resize_as_
x.transpose                                x.permute (x.transpose swaps only two dims)
x.flatten                                  x.flatten or x.view(-1)
x.squeeze()                                x.squeeze()
x[:, np.newaxis] or np.expand_dims(x, 1)   x.unsqueeze(1)
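NumPy's x.transpose(axes) reorders every axis, which corresponds to permute in PyTorch. PyTorch's transpose instead swaps exactly two dimensions, like np.swapaxes. The sketch also shows why view(-1) is not a drop-in flatten: view needs memory laid out compatibly with the new shape, which a permuted tensor does not have.

```python
import numpy as np
import torch

x = torch.arange(24).reshape(2, 3, 4)

p = x.permute(2, 0, 1)   # like x.numpy().transpose(2, 0, 1)
s = x.transpose(0, 2)    # swaps dims 0 and 2, like np.swapaxes

flat = x.flatten()                  # x.view(-1) also works here: x is contiguous
col = torch.arange(3).unsqueeze(1)  # like a[:, np.newaxis]

# view on a non-contiguous (permuted) tensor raises; reshape copies if needed
try:
    p.view(-1)
    view_ok = True
except RuntimeError:
    view_ok = False
```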

Item selection and manipulation

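NumPy        PyTorch
np.put       (none)
x.put        x.put_
x.repeat     x.repeat_interleave
np.tile      x.repeat or torch.tile
np.choose    (none)
np.sort      sorted, indices = torch.sort(x, [dim])
np.argsort   torch.argsort(x, [dim]) or torch.sort(x, [dim])[1]
np.nonzero   torch.nonzero(x, as_tuple=True)
np.where     torch.where
x[::-1]      torch.flip(x, [0]) (tensors do not support negative slice steps)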
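The names repeat and tile are a common trap. NumPy's repeat repeats each element, which in PyTorch is repeat_interleave. PyTorch's Tensor.repeat instead tiles the whole tensor, like np.tile. The sketch checks each PyTorch call against NumPy on a small 1-D tensor with no ties, so the sort order is deterministic.

```python
import numpy as np
import torch

x = torch.tensor([3, 1, 2])
a = x.numpy()

rep = x.repeat_interleave(2)   # [3, 3, 1, 1, 2, 2], same as np.repeat(a, 2)
tiled = x.repeat(2)            # [3, 1, 2, 3, 1, 2], same as np.tile(a, 2)

values, indices = torch.sort(x)  # indices equal np.argsort(a)
order = torch.argsort(x)

rev = torch.flip(x, dims=[0])    # replacement for a[::-1]

# as_tuple=True returns a tuple of index tensors, like np.nonzero
nz = torch.nonzero(x > 1, as_tuple=True)
```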

Calculation

NumPy         PyTorch
x.min         x.min
x.argmin      x.argmin
x.max         x.max
x.argmax      x.argmax
x.clip        x.clamp
x.round       x.round
np.floor(x)   torch.floor(x) or x.floor()
np.ceil(x)    torch.ceil(x) or x.ceil()
x.trace       x.trace
x.sum         x.sum
x.cumsum      x.cumsum(dim) (dim is required)
x.mean        x.mean
x.std         x.std(unbiased=False) (PyTorch defaults to the unbiased estimator)
x.prod        x.prod
x.cumprod     x.cumprod(dim) (dim is required)
x.all         x.all()
x.any         x.any()
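Two details in the reductions can change results silently. First, max/min with a dim argument return both values and indices. Second, np.std defaults to ddof=0, while torch.std defaults to the unbiased (N - 1) estimator. The sketch assumes a recent PyTorch and compares against NumPy with a small tolerance, since float32 is used throughout.

```python
import numpy as np
import torch

x = torch.tensor([[1.0, 5.0], [-2.0, 8.0]])
a = x.numpy()

clipped = x.clamp(0, 6)               # same as np.clip(a, 0, 6)
vals, idx = x.max(dim=1)              # per-row maxima and their positions
cum = x.cumsum(dim=1)                 # dim must be given, unlike a.cumsum()

std_np_style = x.std(unbiased=False)  # matches np.std(a) (ddof=0)
std_default = x.std()                 # matches np.std(a, ddof=1)

all_nonzero = bool((x != 0).all())    # like np.all(a): every element nonzero
any_negative = bool((x < 0).any())
```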

Arithmetic and comparison operations

NumPy              PyTorch
np.less            x.lt
np.less_equal      x.le
np.greater         x.gt
np.greater_equal   x.ge
np.equal           x.eq
np.not_equal       x.ne
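Each comparison method also exists as a torch.* function and as an operator (x.lt(y), torch.lt(x, y), x < y), and all of them return bool tensors element-wise. For a single yes/no answer on whole-tensor equality, torch.equal returns a Python bool. The example checks one result against NumPy.

```python
import numpy as np
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([3, 2, 1])

lt = x.lt(y)          # same as x < y and torch.lt(x, y)
ge = torch.ge(x, 2)   # scalars broadcast, as with np.greater_equal
eq = x.eq(y)          # element-wise equality

same = torch.equal(x, x)     # whole-tensor comparison -> Python bool
differ = torch.equal(x, y)
```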