Pytorch常见函数

134 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第6天,点击查看活动详情

Pytorch

PyTorch是一个开源的Python机器学习库,基于Torch,可用于计算机视觉、自然语言处理等任务。
PyTorch是相当简洁且高效快速的框架,设计追求最少的封装,符合人类思维,让用户尽可能地专注于实现自己的想法。
在初次阅读论文的源代码的时候,常常出现了很多基础的函数却不了解它们的功能,本文记录了一些代码阅读中的遇到的函数的功能和参数等,以及一些概念,以便后续阅读。

基本函数

*args和**kwargs

  • *args(arguments):位置参数,不定长元组tuple
  • **kwargs(keyword arguments):关键字参数,不定长字典dict
  • 位置参数必须放在关键字参数的前面,即*args 放在 **kwargs 前面

列表前面加星号(*)

功能:将list解开成几个独立的参数,传入函数

# 示例
nn.Sequential(*layers)

add(a, b)
list = [1, 2]
c = add(*list)
# ->返回3。即list[0], list[1]分别作为add的两个参数

functools.partial()

功能:基于函数构造新函数;基于第一个参数“函数1”和第二个参数来构造新函数,使用变量func即可使用新函数

# 示例
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
func = functools.partial(函数1, 函数1的参数)

torch.cat()

功能:将两个tensor拼接在一起;0表示按行竖着拼接,A上B下;1表示按行横着拼接,A左B右

# 示例
x2 = torch.cat((x, x1), 1)

torch.cat((tensor1, tensor2), 拼接方式)

import_module(str)

功能:将str指向的文件导入为模块

# 示例
importlib.import_module(str)

getattr()

功能:获取object中的attr属性,未获取到则返回default 注:module中的属性为class等,class中的属性为类和变量等

getattr(object, attr[, default])

文件路径处理

import os.path
# 获取当前文件的绝对路径
realpath(__file__)
abspath(__file__)
# 获取str文件的文件目录:最后一个/之前的所有内容
dirname(str)
# 获取最后一个/之后的所有内容
basename(str)
# 分离文件名与扩展名
list = splitext(str)

基本概念

分布式(distributed)并行

  • node:物理节点,即机器,内部可有多个GPU
  • rank:整个分布式任务中的进程序号
  • local-rank:一个node(机器)上进程的相对序号

nn.Module

nn.Module:所有神经网络模块的基类,在定义自已的网络的时候,需要继承于此类,并重新实现构造函数__init__()和模型功能forward()两个方法。

python中的可调用对象

  • 调用运算符:() 使用object()可以调用此对象
  • 类:调用时运行类的 new ()创建一个实例,然后运行 init (),初始化实例,最后把实例返回给调用方。
  • 类的实例:如果类定义了 call(),那么它的实例可作为函数调用,调用的就是类中的__call__()
# 调用Net类,创建类的实例net
net = Net() 
# 调用类的实例net,本质为调用Net类中的__call__()
net() 

forword()如何被调用?

  • Net类继承于nn.Module类
  • net() = 调用net的__call__(),若net的__call__()没有显式定义,则使用它的父类nn.Module的__call__()
  • nn.Module的__call__()中调用了forward(),又因net类中定义了forward(),所以使用net类中重写的forward方法,即调用net中的forward()。