Mamba相关环境配置

1,247 阅读5分钟

近期在看vm-unet的代码,在环境配置上就花了很多时间,下面将主要步骤和遇到的问题做一总结。另外,mamba2已经出现,其实现采用了python环境下的triton包(用于cuda调用,但只能在linux使用,windows out),并且vmamba(鹏程实验室-王耀伟  华为-谢凌曦 中国科学院大学-叶齐祥)代码中已经有应用mamba2的更新,感叹人家大厂做事就是快,有章法。

首先看一下整体代码结构。engine_synapse.py和engine.py分别是train_synapse.py和train.py调用的用于模型训练的函数,这里不作过多介绍。utils.py中包含诸如loss函数、优化器、数据增强、指标计算等函数。configs中采用构建python类的方式保存模型配置参数。笔者原来都采用yaml格式保存配置参数,但这种方式无法增加if等判断语句,保存信息是死的。而采用python类保存可以用灵活的方式保存配置参数,这不失为一种好办法。data文件夹保存训练数据。dataset文件夹包含dataset.py文件,该文件包含pytorch的Dataset类实现,用于训练过程中的数据读取和预处理。在models/vmunet文件夹下包含 vmamba.py和vmunet.py两个文件。其中vmamba.py来自于vmamba模型的代码实现(github.com/MzeroMiko/V…

需要说明的是,由于vm-unet出现较早,那时mamba2还未问世,vmamba自然也没有更新其提出的SS2D算法的mamba2版本。因此,vm-unet采用以c++调用cuda绑定python实现的mamba,其配置需要综合考虑电脑系统、python版本、cuda版本、cudnn版本等问题,并且由于本人采用实验室服务器做实验,不敢直接改服务器的cuda版本,因此需要找非root用户cuda配置方法,这又无疑增加了巨大的试错成本,配置体验极差。

图片图片

下面具体说一下环境配置过程。首先我们看一下vm-unet的README.md中对这一部分的描述。

## 0. Main Environments
```bash
conda create -n vmunet python=3.8
conda activate vmunet
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging
pip install timm==0.4.12
pip install pytest chardet yacs termcolor
pip install submitit tensorboardX
pip install triton==2.0.0
pip install causal_conv1d==1.0.0  # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm==1.0.1  # mmamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs

The .whl files of causal_conv1d and mamba_ssm could be found here. {Baidu}


其中主要步骤包括:1.创建并激活conda环境,python版本3.8。2.安装torch等必要库。3.安装triton(没有显式调用),证据如下图,全文没有显式调用triton相关函数。3.安装causal\_conv1d。5.安装mamba\_ssm。

![图片](https://mmbiz.qpic.cn/mmbiz_png/hCeKovh1lGWE2PDFLeCVydpHY1kI60uxvFhJKI6ZoT403zeFWibo3tvIUr7NW9y0uKibJXyRYTq7kArgnbmcwDnA/640?wx_fmt=png&from=appmsg)

注意,根据后来捋顺的关系,triton是mamba\_ssm的依赖。causal\_conv1d又是mamba\_ssm的依赖。因此顺序必须是3-4-5,所以说,作者虽然啥也没说,但胜在严谨![图片](https://res.wx.qq.com/t/wx_fed/we-emoji/res/v1.3.10/assets/Expression/Expression_1@2x.png)。

一开始我打算用已有的conda环境安装,毕竟pytorch啥的都不用重装,但最终失败。原因是多方面的,例如python版本与其他包冲突、cuda版本不对等等问题。就遇到了cuda版本与mamba\_ssm不匹配导致无法编译的问题。解决办法是严格对应版本并采用非root方式安装cuda等。具体步骤如下:

1\. 确定电脑系统 型号

输入  uname -a  和   cat /proc/version 来查看系统信息。在如下网址查找

GPU与CUDA对应关系

docs.nvidia.com/cuda/cuda-t…


2. cuda安装

cuda安装可以采用conda直接安装,本人配置选择pytorch 1.13.0 cuda 11.6,也可以直接在官网下载包本地安装。

conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia


注意,这个命令来自于pytorch官网,但要去历史版本中找安装命令,网址如下:  

pytorch.org/get-started…


接下来非root安装cuda和cudnn的步骤我找到一个较完整的教程,网址如下,但强烈建议**看完本文再去尝试,否则出错概率较大**

zhuanlan.zhihu.com/p/696863541…


  

cuda与cudnn下载网址:

cuda下载 developer.nvidia.com/cuda-downlo… cuDNN deb 版本下载 developer.nvidia.com/cudnn-downl… cuDNN tar版本下载(本文操作必须用这个!) developer.download.nvidia.cn/compute/cud…


![图片](https://mmbiz.qpic.cn/mmbiz_png/hCeKovh1lGWE2PDFLeCVydpHY1kI60uxv0biaEXxRXjmpFMdBYzmFBVxPTBthBiciayxOuHibB4RaYk90fOJtVT6gw/640?wx_fmt=png&from=appmsg)

上图是cuDNN页面中  tar包 的位置。

在上述非root安装教程中,需要注意一个点,在 ~/.bashrc 中增加下面语句一定要注意,千万不要把语法搞错了,否则可能会导致你的 ll、ls、vim等命令全部消失不见(因为PATH被覆盖了),别问我怎么知道的。解决办法是用vscode打开 ~/.bashrc 进行修改。为什么不直接用vim改回来?因为它找不到啊。

下面有两种写法均可  
写法1:

export CUDA_HOME=/home/你的文件路径/cuda-11.6 export PATH=PATH:/home/你的文件路径/cuda11.6/binexport LDLIBRARYPATH=PATH:/home/你的文件路径/cuda-11.6/bin export LD_LIBRARY_PATH=LD_LIBRARY_PATH:/home/你的文件路径/cuda-11.6/lib64


写法2:  

export CUDA_HOME=/home/你的文件路径/cuda-12.1 export PATH=CUDAHOME/bin:CUDA_HOME/bin:PATH export LD_LIBRARY_PATH=home/你的文件路径/cuda-12.1/lib64:$LD_LIBRARY_PATH


下面总结:

1.   如果不是具有多个值的环境变量,可以直接采用 = 进行赋值。  
    
2.  如果是有多个值的环境变量,在其前/后追加值,如PATH,需要用 $PATH 表示已有的值,用 : 表示要进行添加操作。A:b 表示在A后面追加bb:A 则表示在A前面追加b。
    

走到这一步,再按照步骤安装tritoncausal\_conv1dmamba\_ssm即可。

接下来说一下,mamba2(Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality)已经出现,官方代码地址:

github.com/state-space…


偷偷说一句,vmamba(https://github.com/MzeroMiko/VMamba)中更新了mamba2版本,其实就是在VMamba/classification/models/mamba2 目录下,将上述mamba\_ssm/ops/tritotriton文件夹的所有代码一股脑拷贝过来,并且在此基础上构建高级模块。

mamba感觉争议还是很大的,目前还没有重量级的结果见刊,都是挂在arxiv上面。和身边人讨论,也是唱衰的多,有的说慢,有的说结果不好,还是让子弹再飞一会吧。  

> 本文使用 [文章同步助手](https://juejin.cn/post/6940875049587097631) 同步