持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第6天,点击查看活动详情
Pytorch
PyTorch是一个开源的Python机器学习库,基于Torch,可用于计算机视觉、自然语言处理等任务。
PyTorch是相当简洁且高效快速的框架,设计追求最少的封装,符合人类思维,让用户尽可能地专注于实现自己的想法。
在初次阅读论文的源代码的时候,常常出现了很多基础的函数却不了解它们的功能,本文记录了一些代码阅读中的遇到的函数的功能和参数等,以及一些概念,以便后续阅读。
基本函数
*args和**kwargs
- *args(arguments):位置参数,不定长元组tuple
- **kwargs(keyword arguments):关键字参数,不定长字典dict
- 位置参数必须放在关键字参数的前面,即*args 放在 **kwargs 前面
列表前面加星号(*)
功能:将list解开成几个独立的参数,传入函数
# 示例
nn.Sequential(*layers)
add(a, b)
list = [1, 2]
c = add(*list)
# ->返回3。即list[0], list[1]分别作为add的两个参数
functools.partial()
功能:基于函数构造新函数;基于第一个参数“函数1”和第二个参数来构造新函数,使用变量func即可使用新函数
# 示例
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
func = functools.partial(函数1, 函数1的参数)
torch.cat()
功能:将两个tensor拼接在一起;0表示按行竖着拼接,A上B下;1表示按行横着拼接,A左B右
# 示例
x2 = torch.cat((x, x1), 1)
torch.cat((tensor1, tensor2), 拼接方式)
import_module(str)
功能:将str指向的文件导入为模块
# 示例
importlib.import_module(str)
getattr()
功能:获取object中的attr属性,未获取到则返回default 注:module中的属性为class等,class中的属性为类和变量等
getattr(object, attr[, default])
文件路径处理
import os.path
# 获取当前文件的绝对路径
realpath(__file__)
abspath(__file__)
# 获取str文件的文件目录:最后一个/之前的所有内容
dirname(str)
# 获取最后一个/之后的所有内容
basename(str)
# 分离文件名与扩展名
list = splitext(str)
基本概念
分布式(distributed)并行
- node:物理节点,即机器,内部可有多个GPU
- rank:整个分布式任务中的进程序号
- local-rank:一个node(机器)上进程的相对序号
nn.Module
nn.Module:所有神经网络模块的基类,在定义自已的网络的时候,需要继承于此类,并重新实现构造函数__init__()和模型功能forward()两个方法。
python中的可调用对象
- 调用运算符:() 使用object()可以调用此对象
- 类:调用时运行类的 new ()创建一个实例,然后运行 init (),初始化实例,最后把实例返回给调用方。
- 类的实例:如果类定义了 call(),那么它的实例可作为函数调用,调用的就是类中的__call__()
# 调用Net类,创建类的实例net
net = Net()
# 调用类的实例net,本质为调用Net类中的__call__()
net()
forword()如何被调用?
- Net类继承于nn.Module类
- net() = 调用net的__call__(),若net的__call__()没有显式定义,则使用它的父类nn.Module的__call__()
- nn.Module的__call__()中调用了forward(),又因net类中定义了forward(),所以使用net类中重写的forward方法,即调用net中的forward()。