近期在看vm-unet的代码,在环境配置上就花了很多时间,下面将主要步骤和遇到的问题做一总结。另外,mamba2已经出现,其实现采用了python环境下的triton包(用于cuda调用,但只能在linux使用,windows out),并且vmamba(鹏程实验室-王耀伟 华为-谢凌曦 中国科学院大学-叶齐祥)代码中已经有应用mamba2的更新,感叹人家大厂做事就是快,有章法。
首先看一下整体代码结构。engine_synapse.py和engine.py分别是train_synapse.py和train.py调用的用于模型训练的函数,这里不作过多介绍。utils.py中包含诸如loss函数、优化器、数据增强、指标计算等函数。configs中采用构建python类的方式保存模型配置参数。笔者原来都采用yaml格式保存配置参数,但这种方式无法增加if等判断语句,保存信息是死的。而采用python类保存可以用灵活的方式保存配置参数,这不失为一种好办法。data文件夹保存训练数据。dataset文件夹包含dataset.py文件,该文件包含pytorch的Dataset类实现,用于训练过程中的数据读取和预处理。在models/vmunet文件夹下包含 vmamba.py和vmunet.py两个文件。其中vmamba.py来自于vmamba模型的代码实现(github.com/MzeroMiko/V…
需要说明的是,由于vm-unet出现较早,那时mamba2还未问世,vmamba自然也没有更新其提出的SS2D算法的mamba2版本。因此,vm-unet采用以c++调用cuda绑定python实现的mamba,其配置需要综合考虑电脑系统、python版本、cuda版本、cudnn版本等问题,并且由于本人采用实验室服务器做实验,不敢直接改服务器的cuda版本,因此需要找非root用户cuda配置方法,这又无疑增加了巨大的试错成本,配置体验极差。
下面具体说一下环境配置过程。首先我们看一下vm-unet的README.md中对这一部分的描述。
## 0. Main Environments
```bash
conda create -n vmunet python=3.8
conda activate vmunet
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging
pip install timm==0.4.12
pip install pytest chardet yacs termcolor
pip install submitit tensorboardX
pip install triton==2.0.0
pip install causal_conv1d==1.0.0 # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm==1.0.1 # mmamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs
The .whl files of causal_conv1d and mamba_ssm could be found here. {Baidu}
其中主要步骤包括:1.创建并激活conda环境,python版本3.8。2.安装torch等必要库。3.安装triton(没有显式调用),证据如下图,全文没有显式调用triton相关函数。3.安装causal\_conv1d。5.安装mamba\_ssm。

注意,根据后来捋顺的关系,triton是mamba\_ssm的依赖。causal\_conv1d又是mamba\_ssm的依赖。因此顺序必须是3-4-5,所以说,作者虽然啥也没说,但胜在严谨。
一开始我打算用已有的conda环境安装,毕竟pytorch啥的都不用重装,但最终失败。原因是多方面的,例如python版本与其他包冲突、cuda版本不对等等问题。就遇到了cuda版本与mamba\_ssm不匹配导致无法编译的问题。解决办法是严格对应版本并采用非root方式安装cuda等。具体步骤如下:
1\. 确定电脑系统 型号
输入 uname -a 和 cat /proc/version 来查看系统信息。在如下网址查找
GPU与CUDA对应关系
2. cuda安装
cuda安装可以采用conda直接安装,本人配置选择pytorch 1.13.0 cuda 11.6,也可以直接在官网下载包本地安装。
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
注意,这个命令来自于pytorch官网,但要去历史版本中找安装命令,网址如下:
接下来非root安装cuda和cudnn的步骤我找到一个较完整的教程,网址如下,但强烈建议**看完本文再去尝试,否则出错概率较大**:
zhuanlan.zhihu.com/p/696863541…
cuda与cudnn下载网址:
cuda下载 developer.nvidia.com/cuda-downlo… cuDNN deb 版本下载 developer.nvidia.com/cudnn-downl… cuDNN tar版本下载(本文操作必须用这个!) developer.download.nvidia.cn/compute/cud…

上图是cuDNN页面中 tar包 的位置。
在上述非root安装教程中,需要注意一个点,在 ~/.bashrc 中增加下面语句一定要注意,千万不要把语法搞错了,否则可能会导致你的 ll、ls、vim等命令全部消失不见(因为PATH被覆盖了),别问我怎么知道的。解决办法是用vscode打开 ~/.bashrc 进行修改。为什么不直接用vim改回来?因为它找不到啊。
下面有两种写法均可
写法1:
export CUDA_HOME=/home/你的文件路径/cuda-11.6 export PATH=LD_LIBRARY_PATH:/home/你的文件路径/cuda-11.6/lib64
写法2:
export CUDA_HOME=/home/你的文件路径/cuda-12.1 export PATH=PATH export LD_LIBRARY_PATH=home/你的文件路径/cuda-12.1/lib64:$LD_LIBRARY_PATH
下面总结:
1. 如果不是具有多个值的环境变量,可以直接采用 = 进行赋值。
2. 如果是有多个值的环境变量,在其前/后追加值,如PATH,需要用 $PATH 表示已有的值,用 : 表示要进行添加操作。A:b 表示在A后面追加b,b:A 则表示在A前面追加b。
走到这一步,再按照步骤安装triton、causal\_conv1d、mamba\_ssm即可。
接下来说一下,mamba2(Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality)已经出现,官方代码地址:
偷偷说一句,vmamba(https://github.com/MzeroMiko/VMamba)中更新了mamba2版本,其实就是在VMamba/classification/models/mamba2 目录下,将上述mamba\_ssm/ops/tritotriton文件夹的所有代码一股脑拷贝过来,并且在此基础上构建高级模块。
mamba感觉争议还是很大的,目前还没有重量级的结果见刊,都是挂在arxiv上面。和身边人讨论,也是唱衰的多,有的说慢,有的说结果不好,还是让子弹再飞一会吧。
> 本文使用 [文章同步助手](https://juejin.cn/post/6940875049587097631) 同步