我正在参加「掘金·启航计划」
前言
其实这是一个拖了很久很久的坑,不知道多少人看过我之前的一篇博客关于torch.fx的使用,在这里面我用torch.fx
实现了一些很有趣的功能比如模型可视化.所以当时就有一个想法,把代码封装一下写成一个属于自己的三方库,正好今天有点时间就把这个坑给填上.
这个工具的主要功能很简单,直接指定某个py文件工具会自动寻找文件中所有的nn.Module
并进行解析可视化.
开始
关于模型的trace以及算子的解析在之前的博客中已经写的比较清楚了,这里就不过多赘述.今天的主要内容是封装+Poetry上传.
1 封装
将上次的代码直接拿来用,然后写个函数调用一下
def draw(model: torch.nn.Module, inputs: torch.Tensor, save_dir: str = './Save', save_name: str = 'model'):
graph = model_graph(model, inputs)
graph.render(outfile=save_dir + '/' + save_name, view=False)
然后解析py文件中所有的nn.Module
def parse_py(py_path: str) -> list[nn.Module]:
spec = importlib.util.spec_from_file_location('module_name', py_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module_classes = [m for _, m in module.__dict__.items() if
isinstance(m, (nn.Module, type)) or issubclass(type(m),nn.Module)]
if len(module_classes) == 0:
raise ValueError('No nn.Module class found in the file.')
return module_classes
for index, model_name in enumerate(model_list):
try:
if not isinstance(model_name, nn.Module):
model = model_name()
else:
model = model_name
args.name=type(model).__name__+'.svg'
draw(model, inputs, save_dir=args.dir, save_name=args.name)
except:
print(f"{model_name} draw failed")
pass
这里为了防止某些子类算子不能实现所有输入需求,其实也是目前设计的不太灵活,所以直接用try简化这部分处理.
2 测试一下
输入的测试文件如下test.py
import torch
from torch import nn
from torchvision.models import resnet18,regnet_x_8gf
class TestConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(TestConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = TestConv2d(3, 32, kernel_size=3)
self.conv2 = TestConv2d(32, 64, kernel_size=3)
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.dropout(x)
return x
model= resnet18()
model2=regnet_x_8gf()
貌似还可以,接下来就是Poetry打包上传了
3 打包&&上传
首先下载Poetry,根据官网推荐利用curl -sSL https://install.python-poetry.org | python3 -
进行安装,命令细节根据自己情况更改
使用poetry new xxx
新建一个项目,这里项目的结构已经生成好了,只需要把之前的文件复制到对应位置.接下来就是比较重要的一步了,修改.toml文件
这里对这本地环境pip list
把包的相关依赖加进去,关于^与~
的区别可以自行google一下版本命名规则相关的.接下来就去pypi上注册一个帐号,得到用户名和密码
最后进行上传
poetry publish --build -u username -p password
这里我没有去细讲关于poetry lock/poetry build/poetry show --tree
等相关的指令,大家有兴趣可以去看相关内容.
经过一系列操作终于在pypi上看到了我们自己的库,仅仅写了简单的readme没有去写license.
4 下载测试
我们从pypi上把这个库给install下来,并且在本地重新写个测试脚本调用库进行测试
from pytorch_show import main
if __name__ == '__main__':
main()
残暴如此,直接调用main就好,再来写个model.py
from torchvision.models import vgg16
model= vgg16()
直接调用测试python test.py -f model.py
直筒形VGG一切正常
总结
今天也算是小小填了一下去年留下的坑,不过还是有很多遗憾的地方.本来这个工具的初衷除了能解析py文件之外还能对onnx进行绘制,但是今天实践下来有很多坑.一开始的想法是将onnx反转回pytorch模型,但是torch.fx的trace并不能支持动态流,因此很多算子包含if或者for loop的地方都会报错.然后尝试了直接利用onnx.tools.net_drawer
进行绘制,成功得到了图片但是太过于复杂,很多多余的输入都被展示出来严重干扰了核心算子的展示,同时必须借助运行时才能得到每一步的shape,所以这部分想优化还是要好好想想办法,这也算是为下一阶段改进再留一个坑吧.
另外留个彩蛋,下一期会讲讲最近看到Rust与DL相关的内容.