创建文件夹命名为 oneNet ,然后在文件夹下创建一个 main.py 文件,One night in beijing, 当然不是这个意思,也不是 one night 写出来的,这里 one net 是简单简单的一个网络,呵呵,就这么简单吗?
如下代码
if __name__ == "__main__":
print("hello One Net")
设计一个好的应用我们需要考虑的事情很多,我们需要加载数据集、定义损失函数、选择优化器以及设置必要超参数。所以我们需要管家帮助我们管理这件事。我们希望这些做的灵活一些,同时有希望他们都遵守一定规则便于后期横向对比和调试。所以可是定义接口
创建个 src 文件夹,在 src 下创建一个 engine
src
---engine
---utils
这个类主要协助我们更流畅地完成训练,是训练哦,因为不同网络存在差异所以这个定义为一个抽象类。
from abc import ABCMeta, abstractmethod
class BaseEngine(metaclass=ABCMeta):
def __init__(self,
model,
optimizer=None,
work_dir=None):
pass
train(),val() 和 save_checkpoint() 这些方法需要具体类去实现。
@abstractmethod
def train(self):
pass
@abstractmethod
def val(self):
pass
@abstractmethod
def save_checkpoint(self,out_dir,save_optimizer=True,meta=None):
pass
有了大概想法,我们还是先具体然后再抽象吧,我们去弄一个 mobileNetv3 ,现在网络越来越复杂,需要配置内容越来越多,所以设计一个配置样式显得格外重要,我们这里参考 open-mmlab 方式来配置
configs
---base
------datasets
------models
------schedules
---deeplabv3plus
有的时候想的太多,反而无法推进,我们还是以一种倒叙方式来,先把最后想要效果写出来然后一个一个实现。
if __name__ == "__main__":
print("hello One Net")
# 这些参数用户在调用 main.py 时传入的
args = parse_args()
import os
import os.path as osp
import argparse
import copy
def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('--config', default='configs/deeplabv3plus/deeplabv3plus.py',help='train config file path')
args = parser.parse_args()
return args
if __name__ == "__main__":
print("hello One Net")
# 这些参数用户在调用 main.py 时传入的
args = parse_args()
这里我们加载一个配置文件,这个配置文件放置 onfigs/deeplabv3plus 目录下
configs
---base
------datasets
------models
---------deeplabv3plus.py
------schedules
---deeplabv3plus
------deeplabv3plus.py
我们现在就去实现一个 Config 类,将这个类定义在 src.utils 下 config 文件,然后将其导入到 main.py 文件
from src.utils import Config
...
if __name__ == "__main__":
print("hello One Net")
# 这些参数用户在调用 main.py 时传入的
args = parse_args()
cfg = Config.fromfile(args.config)
import os
import os.path as osp
import tempfile
class Config:
@staticmethod
def fromfile(filename):
Config._file2dict(filename)
@staticmethod
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
# print(osp.expanduser(filename))
# print(filename)
fileExtname = osp.splitext(filename)[1]
# 感觉现在支持类型也需要控制
if fileExtname not in ['.py']:
raise IOError('Only py type are supported now!')
filename = osp.abspath(osp.expanduser(filename))
- 首先获取配置文件绝对路径
- 然后获取扩展名,通过扩展名来判断配置文件类型是否正确
该模块创建临时文件和目录,同时支持 linux 和 windows 平台。提供支持自动清理的 TemporaryFile, NamedTemporaryFile、TemporaryDirectory 和 SpooledTemporaryFile 的高级接口,可以作为上下文管理器使用。
@staticmethod
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
# print(osp.expanduser(filename))
# print(filename)
fileExtname = osp.splitext(filename)[1]
# 感觉现在支持类型也需要控制
if fileExtname not in ['.py']:
raise IOError('Only py type are supported now!')
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix=fileExtname)
if platform.system() == 'Windows':
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
print(temp_config_name)
今天暂时写到这里,待续。