Pytorch 搭建神经网络使用 nn.Module 遇到的一些疑问和解答

704 阅读3分钟

步入新的学习阶段,Java转CV开始尝试使用Pytorch搭建神经网络过程中遇到的一些问题:

nn.Module (全:torch.nn.Module)

介绍:

torch.nn.Module 这个类的内部有多达 48 个函数,这个类是 PyTorch 中所有 neural network module的基类。nn.Module的nn模块提供了模型构造类,通过继承它来搭建自己的网络层。

以一个自定义的CNN为例:

import torch
from torch import nn

class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.leakyrelu(self.batchnorm(self.conv(x)))

疑问 1 :为什么要写forward函数和__init__(),作用是什么?

  • 作用:init()主要作用是定义基础的网络层,forward()则是实现各层网络的连接,

疑问 2 :怎么实现调用forward函数的?

简单解释:我们编写的模型所继承的nn.Module类中,其__call__方法内便包含了某种形式的对forward方法的调用,从而使得我们不需要显式地调用forward方法。

详细例子: 在[PyTorch]源码的torch/nn/modules/module.py

[地址:github.com/pytorch/pyt…] 文件中包含:


def _forward_unimplemented(self, *input: Any) -> None:
    # Should be overridden by all subclasses.
    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")

def _call_impl(self, *input, **kwargs):
...
      
forward: Callable[..., Any] = _forward_unimplemented
...
...
...
__call__ : Callable[…, Any] = _call_impl
    

首先从上面的源码可以看到如果继承nn.Module的子类没有重写forward方法,那么就会报错NotImplementedError

具体是怎么执行的呢?(见下面这个例子)

from typing import Callable, Any, List

def _forward_unimplemented(self, *input: Any) -> None:
    "Should be overridden by all subclasses"
    print("_forward_unimplemented")
    raise NotImplementedError
    
Module中__call__后跟冒号表示类型注解;
Callable表示可调用类型,此处指的是_call_impl;
Any是一种特殊的类型,它与所有类型兼容;
Callable[…, Any]表示_call_impl可接受任意数量的参数并返回Any。
下面__call__实际指向了_call_impl函数   
    
class Module:

    def __init__(self):
        print("Module初始化")
    
    forward: Callable[..., Any] = _forward_unimplemented
        
    def _call_impl(self, *input, **kwargs):
        print("Module._call_impl")
        out = self.forward(*input, **kwargs)
        return out

    __call__: Callable[..., Any] = _call_impl

    
class NN(Module):

    def __init__(self):
        print("NN初始化")
        super(NN, self).__init__()

    def forward(self, x):
        print("调用NNforward")
        return x

model = NN()
x: List[int] = [1, 2, 3]
print("result:", model(x))
print("test finish")


Result:

NN初始化
Module初始化
Module._call_impl
调用NNforward
result: [1, 2, 3]
test finish

执行流程:

1.model = NN() 首先执行语句print("NN初始化")并输出NN初始化

2.super(NN, self).__init__()初始化父类Module,并执行print("Module初始化"),输出MOdule初始化

  1. 执行model(x),发现是先输出Module.__call__impl后输出调用NNforward。说明此时先进入Module中执行__call__: Callable[..., Any] = _call_impl为什么执行__call__?下一个疑问中解答),然后调用Module中的def _call_impl(self, *input, **kwargs),然后在_call_impl函数中调用子类NN中实现的forward函数,此时输出调用NNforward

Module中的forward的实现方式与__call__相同,但是_forward_unimplemented函数并没有实现体,调用它会触发Error即NotImplementedError。因此在子类NN中一定要给出forward的具体实现(那么就会调用子类中实现的forward函数),否则调用的将是_forward_unimplemented

  1. 返回x,再返回out并输出。

疑问 3:为什么会自动执行__call__函数?

Python中的函数是一级对象。这意味着Python中的函数的引用可以作为输入传递到其他的函数/方法中,并在其中被执行。
而Python中类的实例(对象)可以被当做函数对待。 为了将一个类实例当做函数调用,我们需要在类中实现__call__()方法。也就是我们要在类中实现如下方法:def __call__(self, *args)

举例:

class X(object):
	def __init__(self, a, b, range):
		self.a = a
		self.b = b
		self.range = range
                print("初始化")
	def __call__(self, a, b):
		self.a = a
		self.b = b
		print('__call__ with ({}, {})'.format(self.a, self.b))


xInstance = X(1, 2, 3)
xInstance(1,2) 

Result:

初始化
__call__ with1, 2

分析:xInstance = X(1, 2, 3)对类的实例化,此时并不会调用__call__函数。 但是当xInstance(1,2)执行时,实例化的对象作为函数对待,会调用自身的__call__函数。

参考文章:

1 Python中的__init__()和__call__()函数

2 PyTorch中nn.Module类中__call__方法介绍