承接上文CycleGAN
构建缓存区域存储fake数据
输入A数据,经过G AB网络,生成假的B',将这个假的B'保存到缓存区域,等后面使用的时候(将假的B'输入G BA网络进行还原),再从缓存区域拿出来。
定义损失函数
不同的项目选择不同的损失函数,
损失函数就是计算一下,输入数据和经过GAB合成、GBA还原之后输出数据的相似性
定义优化器
接下来就是训练的过程了
从datasets中一个batch一个batch的取数据,需要把所有数据全部都读进来,如果显存比较低的情况下,执行这一步就error memory了。
计算当前迭代了多少次
指定数据集
把实际的数据指定好,
这一步是真正做损失计算的过程了,
forword 前向传播
前向传播就是把数据输入到生成网络中,生成一个虚假的结果,再把虚假的结果进行还原的过程。
- 第一个是将实际的A传入GAB网络得到假的B
- 第二个是将假的B传入GBA网络得到还原的A(真实A和还原A做loss计算相似性)
- 第三个是将真实的B传入GBA网络得到假的A
- 第四个是将假的A传入GAB网络得到还原的B(真实B和还原B做loss计算相似性)
当训练生成器的时候,需要做一个限制:只训练生成器,判决器不工作,把DA和DB的梯度设置为false即现在不需要判决器更新了,只更新生成器的结果。
接下来是计算反向传播
加了一些权重参数,主要是让数值展示起来更直接,把数值进行放大,相当于把实际的损失值放大了10倍。
将真实的B传入GAB网络生成假的A,计算真实B和假的A之间的损失值,两者的差异越小越好。希望GAB一方面能生成假的B,另一方面也能识别出真实的B。
把生成器生成的假的B传入判决器让它去瞒过判决器
判决器得到N x N的矩阵,那得到的标签也是N x N的
预测结果和真实标签计算损失值
计算真实值和还原值之间的差异
逐个去算它们之间的差异有多大,然后进行累加的到一个损失值
这是把判决器的梯度值设置为true,计算判决器的损失,
对于判决器来说把真实值预测成真实的,
把生成数据预测成假的
计算真实损失值和假的损失值的平均值
安装visdom
修改server.py
安装visdom需要从外网下载资源包,这里注释掉下载的过程
随着迭代的进行,不断的画图