持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第17天,点击查看活动详情
Semantic-Spatial Aware GAN是2021年10月发表的语义空间感知GAN(SSA-GAN)框架,主要提出:
- 一种语义空间感知卷积网络(SSACN)模块,通过基于当前生成的图像特征预测掩码映射草图,这种掩码图不仅可以决定在何处添加文本信息,还起到了权重作用即决定要在某个部分上加强多少文本信息。
- 一种新的仿射参数计算方法,将掩码图添加到SCBN中作为空间条件,然后从编码的文本向量中学习仿射参数,对语义空间条件进行批量归一化。
文章精读报告:blog.csdn.net/air__Heaven…
本篇文章将使用九天深度学习平台复现SSA-GAN。
一、算力领取
九天深度学习平台是中国移动旗下的一款机器学习平台,提供CPU、V100、T4等高性能计算资源的调度管理,集成主流人工智能开源算法框架、Jupyter lab开发工具、主流的公开数据集,提供参考源代码和预训练模型等,为模型训练、服务部署和在线推理提供一站式服务。
二、复现SSA-GAN
2.1、创建实例
首先我们点击进入中国移动云-九天深度学习平台的控制台页面:
点击左侧notebook建模,然后点击新建实例,创建我们的ssagan模型。
因为我们已经有申请过算力试用,故不用担心费用的问题,直接选择vGPU或者V100套餐创建实例。
创建好的实例如上图所示,点击运行,进入熟悉的juypter界面:
2.2、下载代码和数据集
我们点击最下方的terminal,打开终端,然后使用git克隆代码:git clone https://github.com/wtliao/text2image.git 下载代码成功。
然后下载元数据包,使用命令行cd 进入text2image目录,下载为鸟类准备好的预处理元数据,元数据的谷歌链接打不开可以通过CSDN链接1或者链接2下载,然后将原数据并上传到data目录并使用unzip命令解压成文件夹:
最后下载数据集,我们下载鸟数据集,下载链接:
www.vision.caltech.edu/visipedia/C…
然后使用命令tar zxvf CUB_200_2011.tgz解压保存在data/birds/中:
另外还要使用
unzip text.zip终端命令解压text.zip文件:
2.3、下载预训练的 DAMSM 模型
下载预处理的DAMSM模型,打不开可以访问CSDN链接下载。
然后将其上传到 DAMSMencoders目录下:
同样我们使用
unzip bird.zip命令将其解压。
2.4、环境配置
至此我们需要的资源基本上就已经具备了,下一步我们安装所需的虚拟环境:
首先我们conda create -n ssagan 创建新的虚拟环境,环境名为ssagan(也可以任意取名)
然后可以通过nvcc命令看到cuda版本为10.1,
所以我们首先激活虚拟环境:conda activate ssagan
然后安装pytorch,首先我们使用conda search pytorch,找到可以安装的版本:
因为cuda版本是10.1的所以我们优先找到cuda101的:
然后终端输入:
conda install pytorch=1. 7.1=cuda101py36h42dc283_ 1安装pytorch
以同样的方式安装torchvision
然后根据提示安装其他环境: conda install tensorboardX conda install python-dateutil conda install tqdm conda install matplotlib pip install scikit-image pip install easydict pip install nltk pip install pandas pip install pyyaml
2.5、训练
将bird.yml中的B_VALIDATION改为 False
cd进入text2image目录,输入终端命名:python main.py开始训练
可能出现的报错1:load() missing 1 required positional argument: ‘Loader‘ 解决方案:这是因为.yaml文件在load()时缺少必填的loader参数,只需将 pyyaml 版本降级或者将config.py的
yaml_cfg = edict(yaml.load(f))改为safeload
可能出现的报错2:module 'torchvision.transforms' has no attribute 'Resize' 解决方案:
pip install --upgrade torchvision
可能出现的报错3:TypeError: init() got an unexpected keyword argument 'serialized_options' 解决方案:终端上的 protoc 版本 与python库内的protobuf版本不一样。我们只需要
pip install -U protobuf,如果还是报错建议卸载删除低版本protobuf,再重新安装
可能出现的报错4:urllib.error.HTTPError: HTTP Error 403: Forbidden 问题原因:网站设置了白名单,大部分网站不让访问,故Downloading: "download.pytorch.org/models/ince…" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth时被拒绝。 解决方案:打开download.pytorch.org/models/ince…
运行成功如下:
也可以下载已经训练好的SSA-GAN模型进行采用生成。
可以看到nf=64时,SSA-GAN的消耗为每轮epoch要14分钟左右,共600轮epoch
三、SSA-GAN原理
SSA-GAN的框架如下:
整体来看,和DF-GAN很像,也是单级主干结构,但是把UPBlocks改成了 SSACN Blocks。SSA-GAN包括一个文本编码器,一个生成器,一个鉴别器,首先由一个随机整体噪声输入,经过FC层和一次Reshape后,连接七个SSACN层,生成图片后输入鉴别器进行鉴别,需要注意的是,在SSA-GAN中,文本编码器不固定参数,其也是生成器的一部分。
论文提出了一种新的用于T2I生成的语义空间感知GAN(SSA-GAN)框架,主要是在生成器上做的工作,创新如下:
- 一种语义空间感知卷积网络(SSACN)模块,通过基于当前生成的图像特征预测掩码映射草图,这种掩码图不仅可以决定在何处添加文本信息,还起到了权重作用即决定要在某个部分上加强多少文本信息。
- 一种新的仿射参数计算方法,将掩码图添加到SCBN中作为空间条件,然后从编码的文本向量中学习仿射参数,对语义空间条件进行批量归一化。
最后
💖 个人简介:人工智能领域研究生,目前主攻文本生成图像(text to image)方向
🔥 限时免费订阅:文本生成图像T2I专栏
🎉 支持我:点赞👍+收藏⭐️+留言📝
如果这篇文章帮助到你很多,希望能点击下方打赏我一杯可乐!多加冰哦