pytorch计算模型的参数量以及FLOPs

782 阅读1分钟

对于FlOPs,先解释一下其概念:

  • FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
  • FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。
  • 注意,模型的参数量少不代表FLOPs低,理论的FLOPs低也不代表实际的推理速度快
model = MyModel()







# 使用 thop 库 
from thop import profile
input = torch.randn(1,3,400,400)
flops, params = profile(model, inputs=(input, ))
print("参数量params: %.2fM           计算量flops: %.2fG"    % (params / (1000 ** 2), flops / (1000 ** 3)))        







# 使用 fvcore 库
from fvcore.nn import FlopCountAnalysis, parameter_count_table
tensor = (torch.rand(1, 3, 400, 400),)
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())      # 分析FLOPs
print(parameter_count_table(model))  # 分析parameters






# 使用 torchstat 库
from torchstat import stat
stat(model,(3,400,400)) 






# 使用 torchsummary 库
from torchsummary import summary
input = torch.randn(3,800,800)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
summary(model.to(device), input_size=(3, 400, 400), batch_size=1)







# 手动计算 parameters
total = sum([param.nelement() for param in model.parameters()])
print(' Number of params: %.2fM' % (total / 1e6))







# 测试推理时间(1000次取平均)
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input = torch.randn(1,3,400,400).to(device)    # 如果模型和数据都不加cuda,时间会慢很多
model = model.to(device)
sum=0.0
for i in range(1001):
    _,_, _ = model(input)
    start = time.time()
    _,_, _ = model(input)
    end = time.time()
    if i!=0:  # 第一次加载模型时间会很长,避开这一次
        sum+=end-start
print(sum/1000)