一张图片胜过千言万语。但是您知道一张照片可以挽救一千条生命吗?全世界每天有数百万流浪动物在街头受苦或在收容所被安乐死。您可能会期望拥有迷人照片的宠物会引起更多兴趣并更快地被收养。但什么是好照片呢?在数据科学的帮助下,或许可以准确确定宠物照片的吸引力,自动提高照片质量并推荐构图改进。流浪狗和猫可以更快地找到它们的“毛茸茸”的家。
1 环境准备
1.1 安装依赖库
依赖库如下:
- python-box:这是一个Python库,用于创建不可变的记录类型,类似于其他语言中的结构体或类。
- timm:这是一个图像模型库,提供了大量的预训练模型,可以用于图像分类、分割等任务。
- pytorch-lightning==1.4.0:这是一个PyTorch的高级封装库,旨在简化PyTorch模型的训练过程,使其更加简洁和易于管理。指定了版本1.4.0。
- grad-cam:这是一种可视化技术,用于解释深度学习模型的决策过程,特别是卷积神经网络。
!pip install python-box timm pytorch-lightning==1.4.0 grad-cam ttach
1.2 构建python环境
设置一个用于深度学习的Python环境,并且导入了多个库来构建和训练模型。库和模型如下:
- os:用于与操作系统交互,例如文件路径操作。
- warnings:用于控制警告消息的显示。
- pprint:用于美化打印复杂的数据结构。
- glob:用于文件路径模式匹配,方便地获取文件列表。
- tqdm:一个快速,可扩展的Python进度条库。
- torch:PyTorch库,用于深度学习。
- torch.optim:PyTorch的优化器模块。
- torch.nn:PyTorch的神经网络模块。
- torch.nn.functional:PyTorch的函数式接口,提供许多用于构建神经网络的函数。
- numpy:用于科学计算的库。
- pandas:用于数据分析和操作的库。
- matplotlib.pyplot:用于数据可视化的库。
- torchvision.transforms:用于图像预处理的库。
- Box:用于创建不可变的数据结构。
- timm:图像模型库,提供预训练模型。
- StratifiedKFold:用于数据集的分层K折交叉验证。
- read_image:用于读取图像文件。
- DataLoader:PyTorch的数据处理工具,用于加载数据集。
- Dataset:PyTorch的自定义数据集基类。
- GradCAMPlusPlus:用于可视化深度学习模型的决策过程。
- show_cam_on_image:用于在图像上显示CAM(Class Activation Mapping)。
- pytorch_lightning:一个简化PyTorch模型训练的库。
- seed_everything:用于设置随机种子,确保实验的可重复性。
- callbacks:PyTorch Lightning的回调函数模块。
- ProgressBarBase:自定义进度条的基类。
- EarlyStopping:提前停止训练的回调函数。
- TensorBoardLogger:用于记录训练过程,方便在TensorBoard中查看。
- LightningDataModule 和 LightningModule:PyTorch Lightning的数据模块和模型模块,用于组织数据加载和模型结构。
import os
import warnings
from pprint import pprint
from glob import glob
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torchvision.transforms as T
from box import Box
from timm import create_model
from sklearn.model_selection import StratifiedKFold
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningDataModule, LightningModule
warnings.filterwarnings("ignore")
2 字典配置
定义了一个名为 config
的字典,它包含了用于配置深度学习训练的各种参数。然后,使用 Box
类将这个字典转换成了一个不可变的配置对象。这样做的好处是,可以在代码中安全地访问配置参数,而不必担心意外修改它们。
下面是配置参数的简要说明:
seed
:随机种子,用于确保实验的可重复性。root
:数据集的根目录。n_splits
:交叉验证的分割数。epoch
:训练的总轮数。trainer
:PyTorch Lightning训练器的配置,包括GPU使用、梯度累积批次、进度条刷新率等。transform
:数据预处理的配置,包括预处理函数的名称和图像大小。train_loader
和val_loader
:训练和验证数据加载器的配置,包括批次大小、是否打乱数据、工作线程数等。model
:模型的配置,包括模型名称和输出维度。optimizer
:优化器的配置,包括优化器名称和参数。scheduler
:学习率调度器的配置,包括调度器名称和参数。loss
:损失函数的配置。
使用 Box
类将字典转换为不可变对象后,可以通过属性的方式访问配置参数,例如 config.seed
或 config.model.name
。
config = {'seed': 2021,
'root': '/input/petfinder-pawpularity-score/',
'n_splits': 5,
'epoch': 20,
'trainer': {
'gpus': 1,
'accumulate_grad_batches': 1,
'progress_bar_refresh_rate': 1,
'fast_dev_run': False,
'num_sanity_val_steps': 0,
'resume_from_checkpoint': None,
},
'transform':{
'name': 'get_default_transforms',
'image_size': 224
},
'train_loader':{
'batch_size': 64,
'shuffle': True,
'num_workers': 4,
'pin_memory': False,
'drop_last': True,
},
'val_loader': {
'batch_size': 64,
'shuffle': False,
'num_workers': 4,
'pin_memory': False,
'drop_last': False
},
'model':{
'name': 'swin_tiny_patch4_window7_224',
'output_dim': 1
},
'optimizer':{
'name': 'optim.AdamW',
'params':{
'lr': 1e-5
},
},
'scheduler':{
'name': 'optim.lr_scheduler.CosineAnnealingWarmRestarts',
'params':{
'T_0': 20,
'eta_min': 1e-4,
}
},
'loss': 'nn.BCEWithLogitsLoss',
}
config = Box(config)
pprint(config)
3 自定义数据集
定义了两个类,PetfinderDataset
和 PetfinderDataModule
,它们是用于PyTorch Lightning的自定义数据集和数据模块。
PetfinderDataset 类
这个类继承自 PyTorch 的 Dataset
类,用于加载和处理 Petfinder 数据集的图像。它的主要功能包括:
- 在初始化方法
__init__
中,它接收一个包含图像ID的DataFrame (df
) 和一个图像大小(默认为224)。它将图像ID存储在self._X
中,如果DataFrame中包含 "Pawpularity" 列,则将标签存储在self._y
中。同时,它定义了一个图像转换操作,将图像大小调整为指定的尺寸。 - 方法
__len__
返回数据集中的图像数量。 - 方法
__getitem__
根据索引idx
加载图像,应用转换,并返回图像及其对应的标签(如果存在)。
PetfinderDataModule 类
这个类继承自 PyTorch Lightning 的 LightningDataModule
类,用于管理数据加载和拆分过程。它的主要功能包括:
- 在初始化方法
__init__
中,它接收训练集和验证集的DataFrame (train_df
和val_df
),以及配置对象cfg
。 - 私有方法
__create_dataset
根据传入的布尔值train
决定创建训练集还是验证集的数据集实例。 - 方法
train_dataloader
创建并返回训练数据的数据加载器 (DataLoader
),使用配置中的训练加载器参数。 - 方法
val_dataloader
创建并返回验证数据的数据加载器,使用配置中的验证加载器参数。
这些类的设计使得数据加载和预处理过程与模型训练过程解耦,便于管理和维护。可以使用 PetfinderDataModule
来初始化数据加载器,并将其传递给 PyTorch Lightning 的训练器 (Trainer
)。
class PetfinderDataset(Dataset):
def __init__(self, df, image_size=224):
self._X = df["Id"].values
self._y = None
if "Pawpularity" in df.keys():
self._y = df["Pawpularity"].values
self._transform = T.Resize([image_size, image_size])
def __len__(self):
return len(self._X)
def __getitem__(self, idx):
image_path = self._X[idx]
image = read_image(image_path)
image = self._transform(image)
if self._y is not None:
label = self._y[idx]
return image, label
return image
class PetfinderDataModule(LightningDataModule):
def __init__(
self,
train_df,
val_df,
cfg,
):
super().__init__()
self._train_df = train_df
self._val_df = val_df
self._cfg = cfg
def __create_dataset(self, train=True):
return (
PetfinderDataset(self._train_df, self._cfg.transform.image_size)
if train
else PetfinderDataset(self._val_df, self._cfg.transform.image_size)
)
def train_dataloader(self):
dataset = self.__create_dataset(True)
return DataLoader(dataset, **self._cfg.train_loader)
def val_dataloader(self):
dataset = self.__create_dataset(False)
return DataLoader(dataset, **self._cfg.val_loader)
4 数据可视化
- 设置异常检测:通过
torch.autograd.set_detect_anomaly(True)
开启了PyTorch的异常检测,这有助于在训练过程中发现潜在的梯度计算错误。 - 设置随机种子:使用
seed_everything(config.seed)
来确保实验的可重复性,通过设置随机种子,确保每次运行代码时都能得到相同的结果。 - 读取数据:使用
pandas
读取了位于config.root
路径下的train.csv
文件,并将其存储在df
DataFrame中。 - 处理图像路径:将
df
中的 "Id" 列中的每个ID转换为完整的图像路径,并将这些路径存储回 "Id" 列。 - 创建数据模块和数据加载器:实例化了
PetfinderDataModule
类,传入训练数据集df
、验证数据集df
和配置对象config
。然后,调用val_dataloader
方法来获取验证数据的数据加载器。 - 获取样本数据:通过迭代数据加载器并调用
next()
函数来获取一批图像和标签。 - 可视化图像:使用
matplotlib
库来创建一个4x4的子图网格,显示前16张图像,并在每个子图的标题中显示对应的 "Pawpularity" 标签。
torch.autograd.set_detect_anomaly(True)
seed_everything(config.seed)
df = pd.read_csv(os.path.join(config.root, "train.csv"))
df["Id"] = df["Id"].apply(lambda x: os.path.join(config.root, "train", x + ".jpg"))
sample_dataloader = PetfinderDataModule(df, df, config).val_dataloader()
images, labels = iter(sample_dataloader).next()
plt.figure(figsize=(12, 12))
for it, (image, label) in enumerate(zip(images[:16], labels[:16])):
plt.subplot(4, 4, it+1)
plt.imshow(image.permute(1, 2, 0))
plt.axis('off')
plt.title(f'Pawpularity: {int(label)}')
5 数据增强
定义了一个函数 get_default_transforms
,它返回一个包含训练和验证数据集的图像预处理转换操作的字典。这些转换操作是深度学习中常用的数据增强和标准化技术。下面是每个转换操作的简要说明:
- T.RandomHorizontalFlip() :随机水平翻转图像,概率为0.5。
- T.RandomVerticalFlip() :随机垂直翻转图像,概率为0.5。
- T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)) :对图像进行随机仿射变换,包括旋转(最大15度)、平移(图像宽度和高度的10%)和缩放(0.9到1.1倍)。
- T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1) :随机调整图像的亮度、对比度和饱和度。
- T.ConvertImageDtype(torch.float) :将图像转换为浮点类型,这是标准化步骤之前的必要步骤。
- T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) :将图像标准化到ImageNet数据集的均值和标准差,这是预训练模型的常见做法。
训练转换("train"
)包括了所有的数据增强操作,而验证转换("val"
)只包括了类型转换和标准化,这是因为在验证阶段,我们通常不需要进行数据增强。
IMAGENET_MEAN = [0.485, 0.456, 0.406] # RGB
IMAGENET_STD = [0.229, 0.224, 0.225] # RGB
def get_default_transforms():
transform = {
"train": T.Compose(
[
T.RandomHorizontalFlip(),
T.RandomVerticalFlip(),
T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
),
"val": T.Compose(
[
T.ConvertImageDtype(torch.float),
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
),
}
return transform
6 模型定义
定义了一个 mixup
函数和一个 Model
类,它们是深度学习训练过程中常用的技术。
mixup 函数
mixup
函数实现了数据增强技术 "mixup",它通过在训练过程中随机混合两个样本及其标签来生成新的训练样本。这种方法可以提高模型的泛化能力。函数的参数包括:
x
:输入的图像张量。y
:对应的标签张量。alpha
:Beta分布的参数,用于控制混合的强度,默认为1.0。
函数的主要步骤包括:
- 随机生成一个混合比例
lam
。 - 随机打乱索引
rand_index
。 - 根据
lam
混合图像x
和x[rand_index, :]
。 - 返回混合后的图像、原始标签、混合标签和混合比例。
Model 类
这个类继承自 PyTorch Lightning 的 LightningModule
类,用于构建和训练深度学习模型。它的主要功能包括:
- 在
__init__
方法中,初始化配置、构建模型、定义损失函数和优化器,并保存超参数。 - 在
__build_model
方法中,使用create_model
函数创建预训练的模型,并添加一个全连接层以适应特定的输出维度。 - 在
forward
方法中,定义模型的前向传播过程。 - 在
training_step
和validation_step
方法中,定义训练和验证步骤,包括使用mixup
函数进行数据增强。 - 在
__share_step
方法中,共享训练和验证步骤的代码。 - 在
training_epoch_end
和validation_epoch_end
方法中,定义每个训练和验证周期结束时的操作,包括计算损失。 - 在
__share_epoch_end
方法中,共享训练和验证周期结束时的操作。 - 在
check_gradcam
方法中,使用 GradCAM++ 技术可视化模型的决策过程。 - 在
configure_optimizers
方法中,定义优化器和学习率调度器。
这个类的设计使得模型训练过程更加模块化和可配置,同时支持数据增强和可视化技术。
def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
assert alpha > 0, "alpha should be larger than 0"
assert x.size(0) > 1, "Mixup cannot be applied to a single instance."
lam = np.random.beta(alpha, alpha)
rand_index = torch.randperm(x.size()[0])
mixed_x = lam * x + (1 - lam) * x[rand_index, :]
target_a, target_b = y, y[rand_index]
return mixed_x, target_a, target_b, lam
class Model(pl.LightningModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.__build_model()
self._criterion = eval(self.cfg.loss)()
self.transform = get_default_transforms()
self.save_hyperparameters(cfg)
def __build_model(self):
self.backbone = create_model(
self.cfg.model.name, pretrained=True, num_classes=0, in_chans=3
)
num_features = self.backbone.num_features
self.fc = nn.Sequential(
nn.Dropout(0.5), nn.Linear(num_features, self.cfg.model.output_dim)
)
def forward(self, x):
f = self.backbone(x)
out = self.fc(f)
return out
def training_step(self, batch, batch_idx):
loss, pred, labels = self.__share_step(batch, 'train')
return {'loss': loss, 'pred': pred, 'labels': labels}
def validation_step(self, batch, batch_idx):
loss, pred, labels = self.__share_step(batch, 'val')
return {'pred': pred, 'labels': labels}
def __share_step(self, batch, mode):
images, labels = batch
labels = labels.float() / 100.0
images = self.transform[mode](images)
if torch.rand(1)[0] < 0.5 and mode == 'train':
mix_images, target_a, target_b, lam = mixup(images, labels, alpha=0.5)
logits = self.forward(mix_images).squeeze(1)
loss = self._criterion(logits, target_a) * lam + \
(1 - lam) * self._criterion(logits, target_b)
else:
logits = self.forward(images).squeeze(1)
loss = self._criterion(logits, labels)
pred = logits.sigmoid().detach().cpu() * 100.
labels = labels.detach().cpu() * 100.
return loss, pred, labels
def training_epoch_end(self, outputs):
self.__share_epoch_end(outputs, 'train')
def validation_epoch_end(self, outputs):
self.__share_epoch_end(outputs, 'val')
def __share_epoch_end(self, outputs, mode):
preds = []
labels = []
for out in outputs:
pred, label = out['pred'], out['labels']
preds.append(pred)
labels.append(label)
preds = torch.cat(preds)
labels = torch.cat(labels)
metrics = torch.sqrt(((labels - preds) ** 2).mean())
self.log(f'{mode}_loss', metrics)
def check_gradcam(self, dataloader, target_layer, target_category, reshape_transform=None):
cam = GradCAMPlusPlus(
model=self,
target_layer=target_layer,
use_cuda=self.cfg.trainer.gpus,
reshape_transform=reshape_transform)
org_images, labels = iter(dataloader).next()
cam.batch_size = len(org_images)
images = self.transform['val'](org_images)
images = images.to(self.device)
logits = self.forward(images).squeeze(1)
pred = logits.sigmoid().detach().cpu().numpy() * 100
labels = labels.cpu().numpy()
grayscale_cam = cam(input_tensor=images, target_category=target_category, eigen_smooth=True)
org_images = org_images.detach().cpu().numpy().transpose(0, 2, 3, 1) / 255.
return org_images, grayscale_cam, pred, labels
def configure_optimizers(self):
optimizer = eval(self.cfg.optimizer.name)(
self.parameters(), **self.cfg.optimizer.params
)
scheduler = eval(self.cfg.scheduler.name)(
optimizer,
**self.cfg.scheduler.params
)
return [optimizer], [scheduler]
7 模型训练
PyTorch Lightning 和 StratifiedKFold 进行分层 K 折交叉验证。
- 初始化 StratifiedKFold:使用
StratifiedKFold
创建一个分层 K 折交叉验证对象,其中n_splits
为分割数,shuffle
为是否打乱数据,random_state
为随机种子。 - 遍历 K 折:使用
enumerate
遍历每一折(fold
),并获取训练和验证索引(train_idx
和val_idx
)。 - 创建训练和验证 DataFrame:根据索引创建训练和验证 DataFrame,并重置索引。
- 初始化数据模块:使用
PetfinderDataModule
初始化数据模块,传入训练和验证 DataFrame 以及配置对象。 - 初始化模型:使用
Model
类初始化模型,并传入配置对象。 - 初始化回调函数:
-
EarlyStopping
:提前停止训练,如果验证损失在一定周期内没有改善,则停止训练。LearningRateMonitor
:监控学习率的变化。ModelCheckpoint
:保存验证损失最低的模型。
- 初始化 TensorBoardLogger:用于记录训练过程,方便在 TensorBoard 中查看。
- 初始化 Trainer:使用 PyTorch Lightning 的
Trainer
类初始化训练器,传入日志记录器、最大周期数、回调函数和其他配置参数。 - 训练模型:使用
trainer.fit
方法训练模型,传入模型和数据模块。
这段代码展示了如何使用 PyTorch Lightning 和 StratifiedKFold 进行分层 K 折交叉验证,以评估模型的性能。通过这种方式,可以更准确地评估模型的泛化能力。
skf = StratifiedKFold(
n_splits=config.n_splits, shuffle=True, random_state=config.seed
)
for fold, (train_idx, val_idx) in enumerate(skf.split(df["Id"], df["Pawpularity"])):
train_df = df.loc[train_idx].reset_index(drop=True)
val_df = df.loc[val_idx].reset_index(drop=True)
datamodule = PetfinderDataModule(train_df, val_df, config)
model = Model(config)
earystopping = EarlyStopping(monitor="val_loss")
lr_monitor = callbacks.LearningRateMonitor()
loss_checkpoint = callbacks.ModelCheckpoint(
filename="best_loss",
monitor="val_loss",
save_top_k=1,
mode="min",
save_last=False,
)
logger = TensorBoardLogger(config.model.name)
trainer = pl.Trainer(
logger=logger,
max_epochs=config.epoch,
callbacks=[lr_monitor, loss_checkpoint, earystopping],
**config.trainer,
)
trainer.fit(model, datamodule=datamodule)
8 类激活图
类激活图是一种可视化卷积神经网络(CNN)决策过程的技术。它可以显示模型在训练过程中,权重或重心在何处、如何转移,分类模型是根据哪一部分的特征进行判别的
- 使用 Grad-CAM 技术对 Vision Transformer (ViT) 模型进行可视化。:
- 定义 reshape_transform 函数:这个函数用于将 Grad-CAM 生成的热图重塑为与原始图像相同的形状。由于 ViT 输出的特征图是一个序列,因此需要将其重塑为二维空间分布,以便与原始图像对齐。
- 加载模型:使用
Model
类创建模型实例,并加载之前保存的最佳模型权重。这里假设模型权重保存在best_loss.ckpt
文件中。 - 将模型转移到 GPU:使用
cuda()
方法将模型转移到 GPU 上,并设置为评估模式(eval()
)。 - 设置验证数据加载器:更新配置中的验证数据加载器的批次大小,并初始化
PetfinderDataModule
。 - 执行 Grad-CAM:调用模型的
check_gradcam
方法,传入验证数据加载器、目标层(这里选择模型的最后一个块的归一化层)、目标类别(这里设置为None
,表示对所有类别进行可视化)和reshape_transform
函数。 - 可视化 Grad-CAM 结果:使用
matplotlib
库创建一个子图网格,显示 Grad-CAM 生成的热图。每个子图显示原始图像、热图和预测标签与真实标签。
# gradcam reshape_transform for vit
def reshape_transform(tensor, height=7, width=7):
result = tensor.reshape(tensor.size(0),
height, width, tensor.size(2))
# like in CNNs.
result = result.permute(0, 3, 1, 2)
return result
model = Model(config)
model.load_state_dict(torch.load(f'{config.model.name}/default/version_0/checkpoints/best_loss.ckpt')['state_dict'])
model = model.cuda().eval()
config.val_loader.batch_size = 16
datamodule = PetfinderDataModule(train_df, val_df, config)
images, grayscale_cams, preds, labels = model.check_gradcam(
datamodule.val_dataloader(),
target_layer=model.backbone.layers[-1].blocks[-1].norm1,
target_category=None,
reshape_transform=reshape_transform)
plt.figure(figsize=(12, 12))
for it, (image, grayscale_cam, pred, label) in enumerate(zip(images, grayscale_cams, preds, labels)):
plt.subplot(4, 4, it + 1)
visualization = show_cam_on_image(image, grayscale_cam)
plt.imshow(visualization)
plt.title(f'pred: {pred:.1f} label: {label}')
plt.axis('off')
9 可视化结果
从 TensorBoard 的事件文件中读取训练过程中记录的标量数据,并使用 seaborn
和 matplotlib
进行可视化。以下是代码的详细解释:
- 读取事件文件:使用
glob
函数获取 TensorBoard 事件文件的路径,并创建一个EventAccumulator
对象来读取事件文件。这里假设事件文件的命名模式为events*
,并且位于当前目录下。 - 加载事件数据:调用
Reload
方法加载事件文件中的数据,并初始化一个空字典scalars
来存储标量数据。 - 提取标量数据:遍历事件文件中的所有标量标签(如学习率、损失等),并使用
Scalars
方法提取每个标签对应的事件数据。将这些数据存储在scalars
字典中。 - 设置 seaborn:设置
seaborn
的样式,使其图表更加美观。 - 绘制学习率变化图:创建一个 16x6 英寸的图表,并在一个子图中绘制学习率随训练周期(epoch)的变化。这里假设学习率的标签为
'lr-AdamW'
。 - 绘制训练和验证损失图:在另一个子图中绘制训练损失和验证损失随训练周期的变化。这里假设损失的标签分别为
'train_loss'
和'val_loss'
。使用图例区分训练损失和验证损失。 - 显示图表:调用
plt.show()
显示图表。
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
path = glob(f'./{config.model.name}/default/version_0/events*')[0]
event_acc = EventAccumulator(path, size_guidance={'scalars': 0})
event_acc.Reload()
scalars = {}
for tag in event_acc.Tags()['scalars']:
events = event_acc.Scalars(tag)
scalars[tag] = [event.value for event in events]
import seaborn as sns
sns.set()
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plt.plot(range(len(scalars['lr-AdamW'])), scalars['lr-AdamW'])
plt.xlabel('epoch')
plt.ylabel('lr')
plt.title('adamw lr')
plt.subplot(1, 2, 2)
plt.plot(range(len(scalars['train_loss'])), scalars['train_loss'], label='train_loss')
plt.plot(range(len(scalars['val_loss'])), scalars['val_loss'], label='val_loss')
plt.legend()
plt.ylabel('rmse')
plt.xlabel('epoch')
plt.title('train/val rmse')
plt.show()
print('best_val_loss', min(scalars['val_loss']))