深度学习也有架构模板噢!
github项目:github.com/MrGemy95/Te…
项目结构
文件结构
┃━━ base
┃ ┃━━ base_model.py - model的抽象类
┃ ┗━━ base_train.py - trainer的抽象类
┃
┃━━ model - 存放你所有model类(网络结构)
┃ ┗━━ your_model.py
┃
┃━━ trainer - 存放你项目的训练程序
┃ ┗━━ your_trainer.py
┃
┃━━ mains - 存放各种入口程序
┃ ┗━━ your_main.py
┃
┃━━ data_loader - 存放数据加载和预处理程序
┃ ┗━━ data_generator.py
┃
┃━━ configs - 存放配置信息文件
┃ ┗━━ config.yml
┃
┃━━ checkpoints - 存放检查点模型和参数文件
┃
┃
┃━━ prediction - 存放预测程序
┃
┗━━ util - 存放所有工具类,例如日志(tensorboard)、加载配置参数类等
┃━━ logger.py
┃━━ config.py
┗━━ utils.py
主要组件
Models
-
Base model
Base是一个抽象类,你需要创建一个类继承它。之所以有一个Base model抽象类是因为所有模型之间有许多相同的特性(方法)。base model包括以下特性(方法):
- Save - 保存检查点(checkpoints)。
- Load - 加载检查点。
- Init_Saver - 抽象方法,用于初始化保存和加载检查点。注意:需要重写(override)该方法。
- Bulid_model - 抽象方法,用于定义模型。注意:需要重写(override)该方法。
-
Your model
你在这里实现你的模型。
- 创建你的model类并继承base_model类。
- 重写bulid_model方法。
- 重写init_save方法。
- 构造方法里调用bulid_model和init_save。
Trainer
-
Base Trainer
Base trainer是一个抽象类。
-
Your Trainer
在这里实现你的训练过程。
- 创建你的trainer类并继承base_trainer类
- 重写base_trainer类的方法
Data Loader
这个类负责所有数据的加载和预处理,并提供简单接口给trainer类。
Logger
这个类用于tensorboard summary,在你的trainer类中将创建一个所有你想summarize的tensorflow变量的目录,并将这个路径给logger.summarize()
Configuration
编写所有配置文件,然后使用utils中的config.py进行解析。
Main
在这里编写各种入口程序
训练
- 传递配置信息
- 创建tensorflow session
- 创建Model、Data_Generator、logger实例,并将配置参数传递给它们。
- 创建一个Trainer实例,将前面的所有对象传递给它。
- 调用Trainer.train()开始训练。
预测
- 传递配置信息
- 加载Model、Data_Generator,将配置信息传递给它们。
- 创建一个prediction实例,将前面的对象传递给它
- 调用prediction.predict()开始预测。
参考