PyTorch对抗生成网络模型及Android端的实现

223 阅读8分钟

前言

学习和了解使用 PyTorch 实现生成对抗网络(GAN),训练一个简单的图片生成模型,并尝试在 Android 端运行该模型,在手机上直接推理生成图片。

本文不深入具体原理,纯碎从应用角度出发,重点在于端侧实现 GAN 的推理

Pytorch 实现 GAN

什么是 GAN

GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习模型,由Ian Goodfellow等人于2014年提出。其核心思想是通过两个神经网络——生成器(Generator)和判别器(Discriminator)的对抗训练,生成逼真的数据样本。

核心原理

  • 生成器(Generator)​​:接收随机噪声输入,尝试生成与真实数据分布相似的样本(如图像、文本等),目标是“欺骗”判别器。
  • 判别器(Discriminator)​​:判断输入数据是来自真实数据集还是生成器,目标是准确区分真假。
  • 对抗训练​​:两者通过博弈优化(极小极大博弈),最终生成器能产生以假乱真的样本,判别器则难以区分(理想状态下判别器判断概率为0.5)

众所周知,神经网络的学习或者说训练,就是从训练数据中自动获取最优权重参数的过程。而为了获得这个最优参数,损失函数的选择和实现尤为关键。

GAN 的实现

生成器

class Generator(nn.Module):
    def __init__(self, nc=3, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 100 x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 特征维度 (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 特征维度 (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 特征维度 (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 特征维度 (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # 特征维度. (nc) x 64 x 64
        )
        self.apply(weights_init)

    def forward(self, input_):
        b, dim_z = input_.shape
        input_ = input_.view(b, dim_z, 1, 1)
        return self.main(input_)

生成器接收一个(1x100)的随机噪声输入,然后通过转置卷积层不断上采样,最终实现输出结果为 3x64x64 大小的图像。

判别器

class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入维度 (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 特征维度 (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        self.apply(weights_init)

    def forward(self, input):
        return self.main(input)

判别器则恰好相反,输入为 3x64x64 的数据,通过卷积层下采样,输出结果为图像是否为真实图像的概率。

GAN 的训练

找到合适的训练集之后,我们就可以进行训练了。对于 64x64 这样一个小规模的图像生成模型,用消费级的显卡训练大半天的时间其实就可以输出一个相对不错的模型来。

BATCH_SIZE = 256
WORKER = 1
LR = 0.0002
NZ = 100
num_epochs = 500

dataset = AnimeDataset(dataset_path=DATA_DIR, image_size=IMAGE_SIZE)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=WORKER)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))

g_writer = LossWriter(save_path=LOG_G_PATH)
d_writer = LossWriter(save_path=LOG_D_PATH)

img_list = []
G_losses = []
D_losses = []
log = Logger(name="my_logger")


def train():
    log.info(f"开始训练 {dataset.__len__()} {module_name}")

    iters = 0
    for epoch in range(num_epochs):
        for data in data_loader:
            #################################################
            # 1. 更新判别器D: 最大化 log(D(x)) + log(1 - D(G(z)))
            # 等同于最小化 - log(D(x)) - log(1 - D(G(z)))
            #################################################
            netD.zero_grad()
            # 1.1 来自数据集的样本
            real_imgs = data.to(device)
            b_size = real_imgs.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # 使用鉴别器对数据集样本做判断
            output = netD(real_imgs).view(-1)
            # 计算交叉熵损失 -log(D(x))
            errD_real = criterion(output, label)
            # 对判别器进行梯度回传
            errD_real.backward()
            D_x = output.mean().item()

            # 1.2 生成随机向量
            noise = torch.randn(b_size, NZ, device=device)
            # 来自生成器生成的样本
            fake = netG(noise)
            label.fill_(fake_label)
            # 使用鉴别器对生成器生成样本做判断
            output = netD(fake.detach()).view(-1)
            # 计算交叉熵损失 -log(1 - D(G(z)))
            errD_fake = criterion(output, label)
            # 对判别器进行梯度回传
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            # 对判别器计算总梯度,-log(D(x))-log(1 - D(G(z)))
            errD = errD_real + errD_fake
            # 更新判别器
            optimizerD.step()

            #################################################
            # 2. 更新判别器G: 最小化 log(D(x)) + log(1 - D(G(z))),
            # 等同于最小化log(1 - D(G(z))),即最小化-log(D(G(z)))
            # 也就等同于最小化-(log(D(G(z)))*1+log(1-D(G(z)))*0)
            # 令生成器样本标签值为1,上式就满足了交叉熵的定义
            #################################################
            netG.zero_grad()
            # 对于生成器训练,令生成器生成的样本为真,
            label.fill_(real_label)
            # 输入生成器的生成的假样本
            output = netD(fake).view(-1)
            # 对生成器计算损失
            errG = criterion(output, label)
            # 对生成器进行梯度回传
            errG.backward()
            D_G_z2 = output.mean().item()
            # 更新生成器
            optimizerG.step()

            # 输出损失状态
            if iters % 100 == 0:
                log.info('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, iters, len(data_loader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                d_writer.add(loss=errD.item(), i=iters)
                g_writer.add(loss=errG.item(), i=iters)
                save_image(fake.data[:25], f"build/{epoch}_{iters}.png",
                           nrow=5, normalize=True)

            iters += 1
        if epoch % 50 == 0:
            save_filename = 'net_%s.pth' % epoch
            save_path = f"./build/{save_filename}"
            torch.save(netG.state_dict(), save_path)
    torch.save(netG.state_dict(), MODEL_G_PATH)

训练完成之后,可以简单测试一下训练的效果

def gen_multi():
    net = Generator().eval()
    x = torch.randn(25, 100)
    data = net(x).detach()
    images = recover_image(data)
    full_image = np.full((5 * 64, 5 * 64, 3), 0, dtype="uint8")
    for i in range(25):
        row = i // 5
        col = i % 5
        full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]

    plt.imshow(full_image)
    plt.show()

train_result.png

可以看到训练生成的结果已经有点那个意思了,有些图片其实还是能看的。Pytroch 训练生成的模型无法直接在移动端使用,还需要做一下转换。

转换为移动端可用的模型

import torch
from model1 import Generator
from torch.utils.mobile_optimizer import optimize_for_mobile

MODEL_PATH = "./build/Net_G.pth"
net = Generator().eval()
net.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

x = torch.randn(1, 100)
traced_script_module = torch.jit.trace(func=net, example_inputs=x)
traced_script_module = optimize_for_mobile(traced_script_module)
traced_script_module._save_for_lite_interpreter("./build/dcgan64.pt")、

GAN 在移动端推理实现

PyTorch 有移动端的 SDK 可以用来加载模型并进行推理。因此,只要有了模型文件,我们唯一需要做的事情就是把确保推理模型获取正确的输入,对推理的结果可以正确的进行处理。

模型初始化和加载

首先我们添加 PyTorch Android 端的依赖

    implementation 'org.pytorch:pytorch_android_lite:2.1.0'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:2.1.0'

一般情况下,将模型文件放在 assets 目录下即可,通过 PyTorch 提供的接口加载推理模型

    private fun initModel() {
        module = LiteModuleLoader.load(AndroidAssetsFileUtil.assetFilePath(this, "dcgan.pt"))
    }

输入数据进行推理

    private fun genImage(): Bitmap {
        val zDim = intArrayOf(1, 100)
        val outDims = intArrayOf(64, 64, 3)
        val z = FloatArray(zDim[0] * zDim[1])
        val rand = Random()
        // 生成高斯随机数
        for (c in 0 until zDim[0] * zDim[1]) {
            z[c] = rand.nextGaussian().toFloat()
        }

        val shape = longArrayOf(1, 100)
        val tensor = Tensor.fromBlob(z, shape)
        // 用模型进行推理
        val resultT = module.forward(IValue.from(tensor)).toTensor()
        val resultArray = resultT.dataAsFloatArray
        val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
      
        // 根据输出的一维数组,解析生成的卡通图像
        for (j in 0 until outDims[2]) {
            for (k in 0 until outDims[0]) {
                for (m in 0 until outDims[1]) {
                    resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
                }
            }
        }
        val bitmap = Utils.getBitmap(resultImg, outDims)
        return bitmap
    }

这里我们按照模型训练时的约束,现定义好输入参数和输出参数的格式。输入是一个 1x100 的随机数组,输出是 64x64x3 的数组。然后通过调用 module.forward(IValue.from(tensor)).toTensor() 进行推理,同时将返回的结果转换为 Tensor 类型。这里对输出结果的处理需要注意两点

  1. 首先输出结果 dataAsFloatArray 是 [-1,1] 之间的小数,而 Bitmap 需要的数据是 0~255 之间的整数值,因此需要将模型推理的结果进行归一化处理,转换为 Bitmap 可以接收的数据。
  2. 其次,返回结果 dataAsFloatArray 是一维数组,同时在 PyTorch 中处理图像时,通常使用 CHW(Channel-Height-Width) 格式的 Tensor,而 Android 的 Bitmap 使用 HWC(Height-Width-Channel) 格式 。因此,数据需要进行一次转换。
    public static Bitmap getBitmap(float[][][] image_array, int[] dim_info) {
        int count = 0;
        int[] color_info = new int[dim_info[0] * dim_info[1]];
        // 遍历图像,获取颜色信息
        for (int i = 0; i < dim_info[0]; i++) {
            for (int j = 0; j < dim_info[1]; j++) {
                float[] arr = image_array[i][j];
                int alpha = 255;
                int red = (int) arr[0];
                int green = (int) arr[1];
                int blue = (int) arr[2];
                int tempARGB = (alpha << 24) | (red << 16) | (green << 8) | blue;
                color_info[count++] = tempARGB;
            }
        }
        // 创建bitmap对象
        return Bitmap.createBitmap(color_info, dim_info[0], dim_info[1], Bitmap.Config.ARGB_8888);
    }

我们批量生成一些图片看看实际效果如何

            AsyncExecutor.fromIO().execute {
                bitmapList.clear()
                for (i in 0 until 120) {
                    val bitmap = genImage()
                    bitmapList.add(bitmap)
                }
                Log.i(TAG, "done")
                runOnUiThread {
                    Log.i(TAG, "notify")
                    adapter.notifyDataSetChanged()
                }
            }

output.png

在手机上,生成 120 张这种规模的图片耗时不到 1 秒,可以说是非常快。同时可以看到生成图片有那么几张还是很不错的。

小结

通过生成对抗网络进行简单的搭建和训练,其实也能训练出效果相对来说比较稳定的图像生成模型。同时这类模型可以用于在手机上进行推理生图,比如生成表情包、emoji 的场景,完全可以通过数据集训练一个特定的模型来实现。这样即便在弱网或者离线场景,也能使用模型推理功能,而且数据完全都在本地,对用户隐私来说也更加的全安。