Pytorch实现量化感知训练QAT(一)
本文由林大佬原创,转载请注明出处,我们有来自腾讯、阿里等一线AI算法工程师组成微信交流群, 如果你想交流欢迎添加微信: jintianandmerry 拉群, 添加请备注 "交流群"
很久没更新文章了,小伙伴是不是有点想我呢? 小林最近在搞训练感知量化的东西, 这玩意比那些后处理量化方案靠谱多了, 比你把float32的模型转到tensorrt,再在tensorrt上做量化靠谱. 有人可能会问了, 你说的靠谱到底体现在哪儿? 简单来说有这么几点:
- 众所周知, 英伟达tensorrt默认的量化方式like a shit, 在一些小的检测模型上几乎没法用,而且受限于你的校准数据分布,很难获得一个全面的量化结果,换句话说,在你的校准数据上校准之后的模型,即使在evalset上AP差不多,在陌生数据上呢可能依旧GG. 而训练感知量化可以充分的学习你的训练集的分布;
- 还有一个好处是, 你拿到的int8的pytorch模型, 可以无缝的部署到任何支持的框架上, 而不需要再其他框架上再进行量化.
- 最后就是量化的精度问题. 很多人说量化精度取决于量化算法, 我不否认,但更重要的还是看你的数据呀, 这就好比你拿100张图片训练的resnet50去pk拿imagenet训练的resnet18, 不能相提并论. 而这一切的核心, 其实就在于你能否让同样的长度去表征你的所有数据分布, 而这只有可能在训练的时候去考虑, 才有可能做得到.
那既然这么多好处, 我们为什么还要用tensort那自带切不开源很难hack的量化算法呢, 其实原因是 .... 懒, 不想做.
话已至此, 我也没啥好说的了, 我直接给你一个实例代码, 你超过去用可以吧? 先展示一下QAT训练感知量化的一个大概指标 (res18的模型, cifar):
FP32 evaluation accuracy: 0.869
INT8 evaluation accuracy: 0.869
FP32 CPU Inference Latency: 5.12 ms / sample
FP32 CUDA Inference Latency: 3.01 ms / sample
INT8 CPU Inference Latency: 1.30 ms / sample
INT8 JIT CPU Inference Latency: 0.51 ms / sample
简单来说, 就是不掉点, 速度提高了...... 5倍!.
这么好的技能点, 怎么能不get呢? 快点点赞收藏本文, 有时间我们一起来学习学习! 说明, 这个实验你用笔记本也可以做.
什么是Quantization Aware Training
上面我说了post quantization 不是终极的道理, 因为信息损失在哪一个阶段已经无法避免,无论你的优化函数多么厉害,最终都会造成不可避免的信息丢失. 而QAT在训练的时候就给你实现了一套low precision的inference module, 而这个module在同步的和float32进行推理校验, 你可以类比为tesla的 影子模式. 在这种模式下,每一个layer的信息差都可以直接被捕捉到. 如同tensorFlow在blog中提到的一样:
This introduces the quantization error as noise during the training and as part of the overall loss, which the optimization algorithm tries to minimize. Hence, the model learns parameters that are more robust to quantization.
这样操作下来, 把量化的误差当做事了一种训练的噪声, 当你在QAT的同时,你已经让模型在学习适应这种噪声了. 换句话说, 你在训练量化模型的同时,float32模型的参数同时也在学习如何去minize这个误差. 这张图可以很好的展示 QAT和post-training quantization 的差别:
[FBI WARNING] 有读者问为什么代码跑起来和文章不一致, 这里忘了说了resnet是经过修改的, torchvision里面有一些不兼容quantization的地方,这也是一个坑. 后面我会说为什么需要修改, 修改哪些地方.
Pytorch实现QAT
接下来教大家如何实现在pytorch下进行量化感知的训练. 事实上现在这个feature已经变得很简单, 只不过当你在应用一些复杂的模型的时候,过程可能会比较繁琐,不管那么多,我们先从最简单的开始吧.
整个过程的步骤大概是:
- 训练一个float32的模型;
- 测试float32分别在CPU和GPU上的时间;
- 假如bn层融合;
- 测试融合前和融合后精度和结果比对 (看看融合对于结果的影响);
- 加入torch的量化感知API;
- 重新训练一个量化感知的模型;
- 把int8的模型保存, 并测试速度;
- 测试int8模型精度;
- 对比float32与int8模型的精度;
- 测试融合bn下和不融合bn的int8模型的结果;
- 保存int8模型.
ok, 这些步骤基本上包含了我们要探寻的所有内容. 闲话不多说, 直接上代码:
首先我们假设你有了一个model:
# 修改修改之后的resnet
model = resnet18(num_classes=num_classes, pretrained=False)
这个你可以从torchvision 直接获取, 接下来我们要做的事情就是把这个模型加上量化的版本,其实也非常的简单, 我们做一个wrapper:
class QuantizedResNet18(nn.Module):
def __init__(self, model_fp32):
super(QuantizedResNet18, self).__init__()
# QuantStub converts tensors from floating point to quantized.
# This will only be used for inputs.
self.quant = torch.quantization.QuantStub()
# DeQuantStub converts tensors from quantized to floating point.
# This will only be used for outputs.
self.dequant = torch.quantization.DeQuantStub()
# FP32 model
self.model_fp32 = model_fp32
def forward(self, x):
# manually specify where tensors will be converted from floating
# point to quantized in the quantized model
x = self.quant(x)
x = self.model_fp32(x)
# manually specify where tensors will be converted from quantized
# to floating point in the quantized model
x = self.dequant(x)
return x
我们对刚才的fp32的模型进行了一个简单的封装, 在这里面实现的步骤也很简单, 就是加一个quant和dequant, 分别在模型的最开始和最后的部分. 如果你要说这个步骤到底是在做什么事情, 我想大概是把刚才说的量化的误差加入到反向传播中, 让你的fakequant_model的输出误差,能够被正常的fp32的模型感知到, 从而你的fp32的模型实际上也在学习让quant的输出误差最小化.
我在后面会贴出全部代码,现在主要是讲解一下,希望大家跟着一步一步的理解. 现在我们的quant打包的model有了,接下来就需要开始训练一波了.
请注意, 接下来要训练了, 训练的时候,有一个标准的流程:
quantization_config = torch.quantization.get_default_qconfig("fbgemm")
# Custom quantization configurations
# quantization_config = torch.quantization.default_qconfig
# quantization_config = torch.quantization.QConfig(activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8), weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
quantized_model.qconfig = quantization_config
# Print quantization configurations
print(quantized_model.qconfig)
# https://pytorch.org/docs/stable/_modules/torch/quantization/quantize.html#prepare_qat
torch.quantization.prepare_qat(quantized_model, inplace=True)
简单来说注意两点:
- 你的模型要设置合理的quatization_config, 具体如何设置可以看看我上面铁的网址;
- 你需要手动的必须要设置model的模式为train(), 每次设置的时候都需要手动重置一下.
然后我们train一下:
def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200):
# The training configurations were not carefully selected.
criterion = nn.CrossEntropyLoss()
model.to(device)
# It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
momentum=0.9, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1)
# optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# Evaluation
model.eval()
eval_loss, eval_accuracy = evaluate_model(
model=model, test_loader=test_loader, device=device, criterion=criterion)
print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1,
eval_loss, eval_accuracy))
for epoch in range(num_epochs):
# Training
model.train()
running_loss = 0
running_corrects = 0
for inputs, labels in train_loader:
inputs = inputs.to(device)
labels = labels.to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
train_loss = running_loss / len(train_loader.dataset)
train_accuracy = running_corrects / len(train_loader.dataset)
# Evaluation
model.eval()
eval_loss, eval_accuracy = evaluate_model(
model=model, test_loader=test_loader, device=device, criterion=criterion)
# Set learning rate scheduler
scheduler.step()
print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))
return model
在这之前,你需要准备一个dataloader:
def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
transforms.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
transforms.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
])
train_set = torchvision.datasets.CIFAR10(
root="data", train=True, download=True, transform=train_transform)
# We will use test set for validation and test in this project.
# Do not use test set for validation in practice!
test_set = torchvision.datasets.CIFAR10(
root="data", train=False, download=True, transform=test_transform)
train_sampler = torch.utils.data.RandomSampler(train_set)
test_sampler = torch.utils.data.SequentialSampler(test_set)
train_loader = torch.utils.data.DataLoader(
dataset=train_set, batch_size=train_batch_size,
sampler=train_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(
dataset=test_set, batch_size=eval_batch_size,
sampler=test_sampler, num_workers=num_workers)
return train_loader, test_loader
train_loader, test_loader = prepare_dataloader(
num_workers=8, train_batch_size=128, eval_batch_size=256)
可以看到, QAT 模型的训练实际上就是在fp32的结果上微调. 最终会得到一个收敛值, 这个过程的epoch大概只要10-20就足够了. 然后你就可以测试一下这个模型的精度.
插曲
由于pytorch quantization之后的JIT不支持直接的tensor相加, 需要使用 tensor.add()或者FloatFunctional, 因此在resnet跳线链接的地方需要修改这个操作:
self.skip_add = nn.quantized.FloatFunctional()
具体为什么这么说,是为了兼容JIT, 因为我们要把int8模型的trace出来就要通过JIT 才能得到最大的加速. 我猜测如果要导出到ONNX亦或者其他格式,也需要类似的操作.
未完待续
这篇文章大概就讲述了如何去训练量化模型, 我们后面的更新将会进一步阐述这么做的优势, 以及int8的模型加速效果. 当然最终我们也会在一些较大的模型比如Yolov5上,实现类似的操作. 尽请期待.
本篇文章所有代码:
代码持续更新中...
更多
如果你想学习人工智能,对前沿的AI技术比较感兴趣,可以加入我们的知识星球,获取第一时间资讯,前沿学术动态,业界新闻等等!你的支持将会鼓励我们更频繁的创作,我们也会帮助你开启更深入的深度学习之旅!
往期文章
zhuanlan.zhihu.com/p/165009477
zhuanlan.zhihu.com/p/149398749