构建RNN代码过程中遇到函数总结1

136 阅读2分钟

引言

  • 本文为个人记录学习使用pytorch来构建RNN代码过程中遇到的一些函数用法总结
  • 本文主要包含pytorch中torch.stack(),以及numpy中np.linspace()

torch.stack()

  • stack()函数在pytorch中的主要作用是进行张量拼接。

    • 拼接的张量形状需要都是一致的,且只允许是序列,函数作用就是会沿着一个新维度对输入的张量序列进行拼接。
    • 通俗来讲就是,把多个2维的凑成一个3维的,以此类推,不懂没关系,看下面我的测试结果即可
  • outputs = torch.stack(inputs, dim=0)

    • inputs : 待连接的张量序列
    • dim : 新的维度, 注意维度不要超过界限
  • 测试代码以及测试结果如下所示

import torch
T1 = torch.tensor([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
                 [40, 50, 60],
                 [70, 80, 90]])
T3 = torch.stack((T1,T2),dim=0)
print(T3,T3.shape)

np.linspace()

  • 该方法的功能是生成一个指定大小,指定数据区间的均匀分布序列
  • numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)
    • start:生成序列中数据的上界。
    • end:生成序列中数据的下界。
    • num:生成序列包含元素的个数;默认为50个。
    • endpoint:取True时,序列包含最大值end;否则不包含;其值默认为True。
    • retstep:该值取True时,生成的序列中显示间距;反正不显示;其值默认为false。
    • dtype:数据类型,可以指定生成序列的数据类型;当为None时,根据其他输入推断数据类型。
    • 返回值:是一个数组。
  • 具体测试代码以及测试结果如下图所示
import numpy as np
# 参数num的值默认为50
T1 = np.linspace(1, 10)
print(T1, len(T1), type(T1))

# 生成[1,10]之间元素个数为10的序列,参数endpoint默认为true,而参数retstep默认为false
T2 = np.linspace(1, 10,10)
print(T2, len(T2), type(T2))

# 设置参数retstep为true
T3 = np.linspace(1, 10, 10, endpoint=False)
print(T3, len(T3), type(T3))

# 生成[1,10)之间元素个数为10的整数序列
T4 = np.linspace(1, 10, 10, dtype=int)
print(T4, len(T4), type(T4))
  • 本文只是对这两个函数的简单总结,仅供参考