pytorch像ts的model.summary()一样打印模型结构

599 阅读1分钟

pytorch的API风格和tensorflow有很大不同,因此原生没有提供像model.summary这样的打印模型结构的功能。在pytorch中,打印模型结构大概有四种方式:print、torchinfo、torchsummary、torchsummaryx

print

直接通过print打印模型对象,按顺序打印了在__init__中定义的所有层,但是没有按模型的执行顺序打印,复用的模块也只会打印一次。

torchinfo

最近更新时间为2023-05-15

当仅传入model参数时,比print方式打印出了训练的参数量,但是也是按模型组件的定义顺序打印,复用的模块只会打印一次。

当传入input_size等参数,打印就会按照实际的执行顺序了,并且给出了估算的参数数量和内存占用,速度是最慢的一个。

torchsummary

最近更新时间为2018-09-26

torchsummary打印模型参数时,模型的输入形状是必传的,因此打印是按照实际的执行顺序打印,也给出了估算的参数数量和内存占用(但是只能传单条数据的shape,无法估计整个batch的内存占用),速度比torchinfo带Input_size快很多,但没有不带的快。

torchsummaryX

最近更新时间为2019-07-07

似乎使用了旧版的pandas,已经不兼容最新版本的pandas了

总结

通过测试发现,简单的打印可以使用torchsummary,要想内存估计比较真实,并且对时间不敏感可以使用torchinfo,但一定要传入input_size参数。不传参数和原生的print因为打印出的和实际模型执行顺序不同,基本没有意义。