OneNet 诞生记(1)

259 阅读3分钟

创建文件夹命名为 oneNet ,然后在文件夹下创建一个 main.py 文件,One night in beijing, 当然不是这个意思,也不是 one night 写出来的,这里 one net 是简单简单的一个网络,呵呵,就这么简单吗?

如下代码

if __name__ == "__main__":
    print("hello One Net")

设计一个好的应用我们需要考虑的事情很多,我们需要加载数据集、定义损失函数、选择优化器以及设置必要超参数。所以我们需要管家帮助我们管理这件事。我们希望这些做的灵活一些,同时有希望他们都遵守一定规则便于后期横向对比和调试。所以可是定义接口

创建个 src 文件夹,在 src 下创建一个 engine

src
---engine
---utils

这个类主要协助我们更流畅地完成训练,是训练哦,因为不同网络存在差异所以这个定义为一个抽象类。

from abc import ABCMeta, abstractmethod

class BaseEngine(metaclass=ABCMeta):

    def __init__(self,
        model,
        optimizer=None,
        work_dir=None):
        pass

train(),val() 和 save_checkpoint() 这些方法需要具体类去实现。

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def save_checkpoint(self,out_dir,save_optimizer=True,meta=None):
        pass

有了大概想法,我们还是先具体然后再抽象吧,我们去弄一个 mobileNetv3 ,现在网络越来越复杂,需要配置内容越来越多,所以设计一个配置样式显得格外重要,我们这里参考 open-mmlab 方式来配置

configs
---base
------datasets
------models
------schedules
---deeplabv3plus

有的时候想的太多,反而无法推进,我们还是以一种倒叙方式来,先把最后想要效果写出来然后一个一个实现。

if __name__ == "__main__":
    print("hello One Net")
    # 这些参数用户在调用 main.py 时传入的
    args = parse_args()
import os
import os.path as osp
import argparse
import copy

def parse_args():

    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('--config', default='configs/deeplabv3plus/deeplabv3plus.py',help='train config file path')

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    print("hello One Net")
    # 这些参数用户在调用 main.py 时传入的
    args = parse_args()

这里我们加载一个配置文件,这个配置文件放置 onfigs/deeplabv3plus 目录下

configs
---base
------datasets
------models
---------deeplabv3plus.py
------schedules
---deeplabv3plus
------deeplabv3plus.py

我们现在就去实现一个 Config 类,将这个类定义在 src.utils 下 config 文件,然后将其导入到 main.py 文件


from src.utils import Config

...

if __name__ == "__main__":
    print("hello One Net")
    # 这些参数用户在调用 main.py 时传入的
    args = parse_args()
    cfg = Config.fromfile(args.config)
import os
import os.path as osp

import tempfile

class Config:
    
    @staticmethod
    def fromfile(filename):
        Config._file2dict(filename)

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        # print(osp.expanduser(filename))
        # print(filename)

        fileExtname = osp.splitext(filename)[1]

        # 感觉现在支持类型也需要控制
        if fileExtname not in ['.py']:
            raise IOError('Only py type are supported now!')

filename = osp.abspath(osp.expanduser(filename))
  • 首先获取配置文件绝对路径
  • 然后获取扩展名,通过扩展名来判断配置文件类型是否正确

该模块创建临时文件和目录,同时支持 linux 和 windows 平台。提供支持自动清理的 TemporaryFile, NamedTemporaryFile、TemporaryDirectory 和 SpooledTemporaryFile 的高级接口,可以作为上下文管理器使用。

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        # print(osp.expanduser(filename))
        # print(filename)

        fileExtname = osp.splitext(filename)[1]

        # 感觉现在支持类型也需要控制
        if fileExtname not in ['.py']:
            raise IOError('Only py type are supported now!')
        with tempfile.TemporaryDirectory() as temp_config_dir:
            temp_config_file = tempfile.NamedTemporaryFile(
                dir=temp_config_dir, suffix=fileExtname)
            if platform.system() == 'Windows':
                temp_config_file.close()

            temp_config_name = osp.basename(temp_config_file.name)
            print(temp_config_name)

今天暂时写到这里,待续。