PyTorch deep learning project made easy.
## Requirements
- Python 3.x
- PyTorch
## Features
- Clear folder structure which is suitable for many deep learning projects.
- `.json` config file support for more convenient parameter tuning.
- Checkpoint saving and resuming.
- Abstract base classes for faster development:
  - `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more.
  - `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting.
  - `BaseModel` provides basic model summary.
## Folder Structure
```
pytorch-template/
│
├── train.py - example main
├── config.json - example config file
│
├── base/ - abstract base classes
│   ├── base_data_loader.py - abstract base class for data loaders
│   ├── base_model.py - abstract base class for models
│   └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│   └── data_loaders.py
│
├── datasets/ - default datasets folder
│
├── logger/ - for training process logging
│   └── logger.py
│
├── model/ - models, losses, and metrics
│   ├── modules/ - submodules of your model
│   ├── loss.py
│   ├── metric.py
│   └── model.py
│
├── saved/ - default checkpoints folder
│
├── trainer/ - trainers
│   └── trainer.py
│
└── utils/
    ├── util.py
    └── ...
```
## Usage
The code in this repo is an MNIST example of the template.
### Config file format
Config files are in `.json` format:

```
{
  "name": "Mnist_LeNet",            // training session name
  "cuda": true,                     // use cuda
  "data_loader": {
    "data_dir": "datasets/",        // dataset path
    "batch_size": 32,               // batch size
    "shuffle": true                 // shuffle data each time calling __iter__()
  },
  "validation": {
    "validation_split": 0.1,        // validation data ratio
    "shuffle": true                 // shuffle training data before splitting
  },
  "optimizer_type": "Adam",
  "optimizer": {
    "lr": 0.001,                    // (optional) learning rate
    "weight_decay": 0               // (optional) weight decay
  },
  "loss": "my_loss",                // loss
  "metrics": [                      // metrics
    "my_metric",
    "my_metric2"
  ],
  "trainer": {
    "epochs": 1000,                 // number of training epochs
    "save_dir": "saved/",           // checkpoints are saved in save_dir/name
    "save_freq": 1,                 // save checkpoints every save_freq epochs
    "verbosity": 2,                 // 0: quiet, 1: per epoch, 2: full
    "monitor": "val_loss",          // monitor value for best model
    "monitor_mode": "min"           // "min" if a lower monitor value is better, otherwise "max"
  },
  "arch": "MnistModel",             // model architecture
  "model": {}                       // model configs
}
```
Add additional configurations if you need them.
### Using config files
Modify the configurations in `.json` config files, then run:

```
python train.py --config config.json
```
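For reference, here is a minimal sketch of what the config-loading part of `train.py` might look like. The flag names mirror the commands in this section; everything else is illustrative. Note also that Python's `json` module does not accept the `//` comments shown in the example above, so the real `config.json` must be plain JSON.

```python
import argparse
import json

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('--config', type=str, help='path to the .json config file')
    parser.add_argument('--resume', type=str, help='path to a checkpoint to resume from')
    args = parser.parse_args()

    # load the training configuration; must be comment-free JSON
    with open(args.config) as f:
        config = json.load(f)

    print(config['name'], config['trainer']['epochs'])
```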
### Resuming from checkpoints
You can resume from a previously saved checkpoint by running:

```
python train.py --resume path/to/checkpoint
```
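Under the hood, resuming amounts to restoring the saved states. A hedged sketch, assuming only the checkpoint keys documented in the Checkpoints section below (the model and optimizer here are placeholders for the ones `train.py` builds from the config):

```python
import torch
import torch.nn as nn

# placeholders standing in for the model/optimizer built from the config
model = nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# keys follow the checkpoint layout shown in the Checkpoints section
checkpoint = torch.load('path/to/checkpoint')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
```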
## Customization
### Data Loader
- **Writing your own data loader**
  1. **Inherit `BaseDataLoader`**

     `BaseDataLoader` is similar to `torch.utils.data.DataLoader`; you can use either of them.

     `BaseDataLoader` handles:
     - Generating the next batch
     - Data shuffling
     - Generating a validation data loader by calling `BaseDataLoader.split_validation()`

  2. **Implementing abstract methods**

     You need to implement these abstract methods (a sketch follows this list):
     - `_pack_data()`: pack data members into a list of tuples
     - `_unpack_data()`: unpack packed data
     - `_update_data()`: update data members
     - `_n_samples()`: total number of samples

- **DataLoader Usage**

  `BaseDataLoader` is an iterator; to iterate through batches:

  ```python
  for batch_idx, (x_batch, y_batch) in data_loader:
      pass
  ```

- **Example**

  Please refer to `data_loader/data_loaders.py` for an MNIST data loading example.
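A schematic sketch of such a subclass is shown below. Only the four method names come from this README; the constructor signature, the `(input, target)` tuple convention, and the import path are assumptions:

```python
from base import BaseDataLoader


class MyDataLoader(BaseDataLoader):
    """Schematic only: constructor signature and return conventions are guesses."""

    def __init__(self, config):
        super(MyDataLoader, self).__init__(config)
        self.x = []  # load your inputs here
        self.y = []  # load your targets here

    def _pack_data(self):
        # pack data members into a list of (input, target) tuples
        return list(zip(self.x, self.y))

    def _unpack_data(self, packed):
        # unpack packed data back into per-member sequences
        return [list(seq) for seq in zip(*packed)]

    def _update_data(self, unpacked):
        # update data members, e.g. after shuffling or splitting
        self.x, self.y = unpacked

    def _n_samples(self):
        # total number of samples
        return len(self.x)
```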
### Trainer
- **Writing your own trainer**
  1. **Inherit `BaseTrainer`**

     `BaseTrainer` handles:
     - Training process logging
     - Checkpoint saving
     - Checkpoint resuming
     - Reconfigurable monitored value for saving the current best model
       - Controlled by the configs `monitor` and `monitor_mode`: if `monitor_mode == 'min'`, the trainer saves a checkpoint `model_best.pth.tar` whenever `monitor` reaches a new minimum

  2. **Implementing abstract methods**

     You need to implement `_train_epoch()` for your training process (a sketch follows this list); if you need validation, you can implement `_valid_epoch()` as in `trainer/trainer.py`.

- **Example**

  Please refer to `trainer/trainer.py` for MNIST training.
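Below is a schematic sketch of a custom trainer. Only `_train_epoch()` is named by this README; the attribute names (`self.model`, `self.optimizer`, `self.data_loader`, `self.loss`) are assumptions about what `BaseTrainer` sets up from the config:

```python
from base import BaseTrainer


class MyTrainer(BaseTrainer):
    """Schematic only: attribute names are assumptions, not the template's verbatim API."""

    def _train_epoch(self, epoch):
        self.model.train()
        total_loss, n_batches = 0.0, 0
        # BaseDataLoader yields (batch_idx, (x, y)), as shown in the Data Loader section
        for batch_idx, (data, target) in self.data_loader:
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # whatever this dict contains is picked up by the logging machinery
        return {'loss': total_loss / n_batches}
```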
### Model
- **Writing your own model**
  1. **Inherit `BaseModel`**

     `BaseModel` handles:
     - Inheritance from `torch.nn.Module`
     - `summary()`: model summary

  2. **Implementing abstract methods**

     Implement the forward pass method `forward()` (a sketch follows this list).

- **Example**

  Please refer to `model/model.py` for a LeNet example.
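A minimal sketch of a `BaseModel` subclass; the layer sizes are illustrative, and whether the constructor receives the `"model"` config dict is an assumption (see `model/model.py` for the actual LeNet example):

```python
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel


class MyModel(BaseModel):
    """Minimal sketch; constructor signature is an assumption."""

    def __init__(self, config):
        super(MyModel, self).__init__(config)
        self.fc = nn.Linear(784, 10)  # illustrative single layer

    def forward(self, x):
        # flatten images, classify, return log-probabilities
        return F.log_softmax(self.fc(x.view(x.size(0), -1)), dim=1)
```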
### Loss and metrics
If you need to change the loss function or metrics, first import those functions in `train.py`, then modify `"loss"` and `"metrics"` in the `.json` config file.
#### Multiple metrics
You can add multiple metrics in your config files:

```
"metrics": ["my_metric", "my_metric2"],
```
### Additional logging
If you have additional information to be logged, merge it with `log` in `_train_epoch()` of your trainer class before returning, as shown below:
```python
additional_log = {"gradient_norm": g, "sensitivity": s}
log = {**log, **additional_log}
return log
```
### Validation data
To split validation data from a data loader, call `BaseDataLoader.split_validation()`; it returns a validation data loader whose number of samples follows the ratio specified in your config file.

**Note**: the `split_validation()` method will modify the original data loader.
**Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0`.
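A hypothetical usage sketch; the class name `MnistDataLoader` is a guess at what `data_loader/data_loaders.py` exports:

```python
import json
from data_loader.data_loaders import MnistDataLoader  # hypothetical class name

with open('config.json') as f:
    config = json.load(f)

data_loader = MnistDataLoader(config)
valid_data_loader = data_loader.split_validation()

if valid_data_loader is None:
    # happens when "validation_split" is 0 in the config
    pass
```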
### Checkpoints
You can specify the name of the training session in config files:

```
"name": "MNIST_LeNet",
```
The checkpoints will be saved in `save_dir/name`; the config file is saved in the same folder.
Note: checkpoints contain:
```python
{
    'arch': arch,
    'epoch': epoch,
    'logger': self.train_logger,
    'state_dict': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
    'monitor_best': self.monitor_best,
    'config': self.config
}
```
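A saved checkpoint can therefore be inspected directly, for example (the path is hypothetical, following the `save_dir/name` convention and the `model_best.pth.tar` name mentioned above):

```python
import torch

# load a checkpoint and read its metadata; keys match the dict above
checkpoint = torch.load('saved/Mnist_LeNet/model_best.pth.tar')
print(checkpoint['arch'], checkpoint['epoch'], checkpoint['monitor_best'])
```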
## Contributing
Feel free to contribute any kind of feature or enhancement; the coding style here follows PEP8.
## TODOs
- Multi-GPU support
- `TensorboardX` or `visdom` support
- Support iteration-based training (instead of epoch-based)
- Configurable logging layout, checkpoint naming
- Load settings from `config` files
## License
This project is licensed under the MIT License. See `LICENSE` for more details.
## Acknowledgments
This project is inspired by the project Tensorflow-Project-Template by Mahmoud Gemy