【UniAD】python装饰器和mmcv的Registry注册机制

178 阅读4分钟

一、python装饰器的功能

先举一个简单的python装饰器的例子。在python环境下输入:

def deco(cls):
    # 为函数开辟了空间,并保存了函数指针
    cls.new_method = lambda self: print("add a new method by decorator")
    # 但现在还不能打印 cls.n
    return cls
    
@deco
class cls1(object):
    def __init__(self):
        self.n = "class cls1"
    
obj = cls1()
obj.new_method()

执行代码,输出:add a new method by decorator

装饰器:@deco,可以作用在函数或类名之前。

装饰器的功能是,为函数F或类C外部套一个壳,在原有FC的外部增加一些新的功能,同时又不改变FC原有的功能。当调用FC时,调用的是原有功能+增量功能,但调用的名字不发生变化。

二、python装饰器的运行机制探讨

在上述例子里,第3行被输入时,就已经给new_method函数创建了空间。

当第7行被输入后,python解释器就已经识别到了deco其实是一个装饰器(也是一个函数),然后就去执行第1行的deco。同时,携带了第8行cls1类定义的指针。

然后,第5行返回时,把类定义指针重新赋值给第8行的cls1,使cls1先有了第一个new_method函数。

当第9行被输入完成,cls1类增加了第二个函数--初始化函数。

举一个Openmmlab官方的例子:

_module_dict = dict()

# 定义装饰器函数
def register_module(name):
    def _register(cls):
        _module_dict[name] = cls
        return cls

    return _register

# 装饰器用法
@register_module('one_class')
class OneTest(object):
    pass

@register_module('two_class')
class TwoTest(object):
    pass
    

if __name__ == '__main__':
    # 通过注册类名实现自动实例化功能
    one_test = _module_dict['one_class']()
    print(one_test)

# 输出
<__main__.OneTest object at 0x7f1d7c5acee0>

第12行:输入装饰器,将执行第4行,并把'one_class'作为参数赋值给name,装饰器返回的是,函数_register(OneTest)的返回值。

在函数_register(OneTest)中,执行了_module_dict['one_class'] = OneTest,给字典增加元素了。

现在_module_dict的内容是{'one_class': OneTest},其中OneTest是类定义的指针。装饰器返回的也是类定义的指针。

第16行的作用和第12行完全相同,给字典新增了元素。此时_module_dict的内容是{'one_class': OneTest,'two_class': TwoTest}

三、mmcv的注册机制Registry

mmcv注册机制的核心方法是使用了python装饰器。

mmcv的目的,是为了节省大模型或深度学习模型建模的时间,避免大家重复造轮子或修轮子的bug。科学家和工程师应当把更多的精力放在模型设计和验证的工作上。

使用了注册机制,就方便把mmcv的轮子管理起来。使模型设计和代码实现,过程尽可能解耦。

注册机制最核心的代码是Registry类定义。这个类的本质是维护了一个字典,这个字典里包含神经网络模块的名称和对应的类指针。

class Registry:
    def __init__(self, name):
        # 可实现注册类细分功能
        self._name = name 
        # 内部核心内容,维护所有的已经注册好的 class
        self._module_dict = dict()

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        # 最核心代码
        self._module_dict[module_name] = module_class

    # 装饰器函数
    def register_module(self, name=None, force=False, module=None):
        if module is not None:
            # 如果已经是 module,那就知道 增加到字典中即可
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # 最标准用法
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls
        return _register

官方文档给了一个简单的例子:

CONVERTERS = Registry('converter')

@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_from_cfg(converter_cfg,CONVERTERS)

第1行:构造一个名叫CONVERTERS的字典,并增加一个名为'converter'的模块。name被赋值'converter'。

第3行:执行Rigistry类的第22行,带上cls=Converter1的参数。然后在CONVERTERS字典里增加元素{'converter': Converter1},其中,Converter1是类定义指针。

第4行:神经网络模块Converter1的定义。

第5行:神经网络的层次定义。(本例子里没有神经网络的结构定义、前向传播方法等)

第10行:根据配置文件和模块字典,构造神经网络。

因为通过字典可以索引到模块的指针,所以只需要修改配置文件,想用什么模块,就在配置文件里直接改。整个过程字典不变。模型需要调整结构的时候,也能快速调整而不用修bug,达到了即插即用的目的。