步入新的学习阶段,Java转CV开始尝试使用Pytorch搭建神经网络过程中遇到的一些问题:
nn.Module (全:torch.nn.Module)
介绍:
torch.nn.Module 这个类的内部有多达 48 个函数,这个类是 PyTorch 中所有 neural network module的基类。nn.Module的nn模块提供了模型构造类,通过继承它来搭建自己的网络层。
以一个自定义的CNN为例:
import torch
from torch import nn
class CNNBlock(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(CNNBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.batchnorm = nn.BatchNorm2d(out_channels)
self.leakyrelu = nn.LeakyReLU(0.1)
def forward(self, x):
return self.leakyrelu(self.batchnorm(self.conv(x)))
疑问 1 :为什么要写forward函数和__init__(),作用是什么?
- 作用:init()主要作用是定义基础的网络层,forward()则是实现各层网络的连接,
疑问 2 :怎么实现调用forward函数的?
简单解释:我们编写的模型所继承的nn.Module类中,其__call__方法内便包含了某种形式的对forward方法的调用,从而使得我们不需要显式地调用forward方法。
详细例子: 在[PyTorch]源码的torch/nn/modules/module.py
[地址:github.com/pytorch/pyt…] 文件中包含:
def _forward_unimplemented(self, *input: Any) -> None:
# Should be overridden by all subclasses.
raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
def _call_impl(self, *input, **kwargs):
...
forward: Callable[..., Any] = _forward_unimplemented
...
...
...
__call__ : Callable[…, Any] = _call_impl
首先从上面的源码可以看到如果继承nn.Module的子类没有重写forward方法,那么就会报错NotImplementedError
具体是怎么执行的呢?(见下面这个例子)
from typing import Callable, Any, List
def _forward_unimplemented(self, *input: Any) -> None:
"Should be overridden by all subclasses"
print("_forward_unimplemented")
raise NotImplementedError
Module中__call__后跟冒号表示类型注解;
Callable表示可调用类型,此处指的是_call_impl;
Any是一种特殊的类型,它与所有类型兼容;
Callable[…, Any]表示_call_impl可接受任意数量的参数并返回Any。
下面__call__实际指向了_call_impl函数
class Module:
def __init__(self):
print("Module初始化")
forward: Callable[..., Any] = _forward_unimplemented
def _call_impl(self, *input, **kwargs):
print("Module._call_impl")
out = self.forward(*input, **kwargs)
return out
__call__: Callable[..., Any] = _call_impl
class NN(Module):
def __init__(self):
print("NN初始化")
super(NN, self).__init__()
def forward(self, x):
print("调用NNforward")
return x
model = NN()
x: List[int] = [1, 2, 3]
print("result:", model(x))
print("test finish")
Result:
NN初始化
Module初始化
Module._call_impl
调用NNforward
result: [1, 2, 3]
test finish
执行流程:
1.model = NN() 首先执行语句print("NN初始化")并输出NN初始化
2.super(NN, self).__init__()初始化父类Module,并执行print("Module初始化"),输出MOdule初始化
- 执行
model(x),发现是先输出Module.__call__impl后输出调用NNforward。说明此时先进入Module中执行__call__: Callable[..., Any] = _call_impl(为什么执行__call__?下一个疑问中解答),然后调用Module中的def _call_impl(self, *input, **kwargs),然后在_call_impl函数中调用子类NN中实现的forward函数,此时输出调用NNforward
(Module中的forward的实现方式与__call__相同,但是_forward_unimplemented函数并没有实现体,调用它会触发Error即NotImplementedError。因此在子类NN中一定要给出forward的具体实现(那么就会调用子类中实现的forward函数),否则调用的将是_forward_unimplemented)
- 返回x,再返回out并输出。
疑问 3:为什么会自动执行__call__函数?
Python中的函数是一级对象。这意味着Python中的函数的引用可以作为输入传递到其他的函数/方法中,并在其中被执行。
而Python中类的实例(对象)可以被当做函数对待。
为了将一个类实例当做函数调用,我们需要在类中实现__call__()方法。也就是我们要在类中实现如下方法:def __call__(self, *args)。
举例:
class X(object):
def __init__(self, a, b, range):
self.a = a
self.b = b
self.range = range
print("初始化")
def __call__(self, a, b):
self.a = a
self.b = b
print('__call__ with ({}, {})'.format(self.a, self.b))
xInstance = X(1, 2, 3)
xInstance(1,2)
Result:
初始化
__call__ with (1, 2)
分析:xInstance = X(1, 2, 3)对类的实例化,此时并不会调用__call__函数。
但是当xInstance(1,2)执行时,实例化的对象作为函数对待,会调用自身的__call__函数。