深度学习项目架构

1,671 阅读2分钟

深度学习也有架构模板噢!
github项目:github.com/MrGemy95/Te…

项目结构


文件结构

┃━━ base
┃     ┃━━ base_model.py          - model的抽象类
┃     ┗━━ base_train.py          - trainer的抽象类
┃
┃━━ model                        - 存放你所有model类(网络结构)
┃     ┗━━ your_model.py          
┃
┃━━ trainer                      - 存放你项目的训练程序
┃     ┗━━ your_trainer.py
┃
┃━━ mains                        - 存放各种入口程序
┃     ┗━━ your_main.py           
┃
┃━━ data_loader                  - 存放数据加载和预处理程序
┃     ┗━━ data_generator.py
┃
┃━━ configs                      - 存放配置信息文件     
┃     ┗━━ config.yml
┃
┃━━ checkpoints                  - 存放检查点模型和参数文件
┃     
┃
┃━━ prediction                   - 存放预测程序     
┃
┗━━ util                         - 存放所有工具类,例如日志(tensorboard)、加载配置参数类等
      ┃━━ logger.py
      ┃━━ config.py
      ┗━━ utils.py

主要组件

Models

  • Base model

    Base是一个抽象类,你需要创建一个类继承它。之所以有一个Base model抽象类是因为所有模型之间有许多相同的特性(方法)。base model包括以下特性(方法):

    • Save - 保存检查点(checkpoints)。
    • Load - 加载检查点。
    • Init_Saver - 抽象方法,用于初始化保存和加载检查点。注意:需要重写(override)该方法。
    • Bulid_model - 抽象方法,用于定义模型。注意:需要重写(override)该方法。
  • Your model

    你在这里实现你的模型。

    • 创建你的model类并继承base_model类。
    • 重写bulid_model方法。
    • 重写init_save方法。
    • 构造方法里调用bulid_model和init_save。

Trainer

  • Base Trainer

    Base trainer是一个抽象类。

  • Your Trainer

    在这里实现你的训练过程。

    • 创建你的trainer类并继承base_trainer类
    • 重写base_trainer类的方法

Data Loader

这个类负责所有数据的加载和预处理,并提供简单接口给trainer类。


Logger

这个类用于tensorboard summary,在你的trainer类中将创建一个所有你想summarize的tensorflow变量的目录,并将这个路径给logger.summarize()


Configuration

编写所有配置文件,然后使用utils中的config.py进行解析。


Main

在这里编写各种入口程序

训练

  • 传递配置信息
  • 创建tensorflow session
  • 创建Model、Data_Generator、logger实例,并将配置参数传递给它们。
  • 创建一个Trainer实例,将前面的所有对象传递给它。
  • 调用Trainer.train()开始训练。

预测

  • 传递配置信息
  • 加载Model、Data_Generator,将配置信息传递给它们。
  • 创建一个prediction实例,将前面的对象传递给它
  • 调用prediction.predict()开始预测。

参考

深度学习工程模板

github.com/MrGemy95/Te…