PyTorch实战图形风格迁移

1,336 阅读5分钟

前言

什么是图像风格的迁移?其实现在很多的APP应用中已经普遍存在了,比如让我们选择一张自己的大头照,然后选择一种风格的图片,确认后我们的大头照变成了所选图片类似的风格。

图像风格迁移重点就是找出一张图片的特征,然后将其融合到需要改变的图片中去,如下图所展示的就是一种典型的风格迁移。

所以图像风格迁移实现的难点就在于如何提取一张图片的特征,这里说的特征也就是图像的风格。论文《A Neural Algorithm of Artistic Style》使用了CNN(卷积神经网络)来对图像的风格进行提取。因为我们都知道CNN本来就可以对特征图像进行提取,然后通过特征来实现图像的分类。当我们有了图像风格的提取方法后,只需要将新提取到的风格融入到新的图片中去,就实现了图像风格的迁移。

1、PyTorch核心代码实现

其实代码的核心思想并不复杂,就是利用CNN提取内容图片的内容和风格图片的风格,然后输入一张新的图像。对输入的图像提取出内容和风格与CNN提取的内容和风格进行Loss计算,Loss的度量可以使用MSE,然后逐步对Loss进行优化,使Loss值达到最理想,将被优化的参数进行输出,这样输出的图片就达到了风格迁移的目的。

(一)、计算内容损失

为什么使用卷积提取内容?

下图是我通过一个卷积提取到的其中一个特征映射,说明使用卷积作为内容提取的方法是完全可行的。

计算内容损失的代码如下:

class Content_loss(torch.nn.Module):
    def __init__(self, weight, target):
        super(Content_loss, self).__init__()
        self.weight = weight
        self.target = target.detach()*weight
        self.loss_fn = torch.nn.MSELoss()
        
    def forward(self, input):
        self.loss = self.loss_fn(input*self.weight, self.target)
        self.output = input
        return self.output
        
    def backward(self):
        self.loss.backward(retain_graph = True)
        return self.loss
        

这里的target就是CNN对内容图像提取得到的内容,weight是用来控制内容和风格对input图像的影响程度,这里的input就是我们输入图像,还有定义的backward主要目的其实是为了调用方向传播方法和返回我们计算得到的Loss。Loss计算使用的是MSE来度量。

(二)、计算风格损失

计算风格损失的代码如下:

class Style_loss(torch.nn.Module):
    def __init__(self, weight, target):
        super(Style_loss, self).__init__()
        self.weight = weight
        self.target = target.detach()*weight
        self.loss_fn = torch.nn.MSELoss()
        self.gram = gram_matrix()
        
    def forward(self, input):
        self.output = input.clone()
        self.G = self.gram(input)
        self.G.mul_(self.weight)
        self.loss = self.loss_fn(self.G, self.target)
        return self.output
    def backward(self):
        self.loss.backward(retain_graph = True)
        return self.loss

这里的target、weight、input、backward、Loss使用的意义和之前的内容计算类似,唯一不同的地方是引入了Gram矩阵,通过对CNN提取后的内容进行Gram矩阵运算来定义图像的风格。

为什么Gram矩阵能够定义图像的风格了?

因为CNN卷积过后提取了图像的特征图,每个数字就是原图像的特性大小,而Gram矩阵是矩阵的内积运算,运算过后特征图中越大的数字会变得更大,这就相当于对图像的特性进行了缩放,使得特征突出了,也就相当于提取到了图片的风格。

Gram矩阵的代码如下:

class gram_matrix(torch.nn.Module):
    def forward(self, input):
        a,b,c,d = input.size()
        feature = input.view(a*b, c*d)
        gram = torch.mm(feature, feature.t())
        return gram.div(a*b*c*d)

(三)、构建训练CNN

构建新的训练模型代码:

content_layer = ["Conv_5","Conv_6"]

style_layer = ["Conv_1", "Conv_2", "Conv_3", "Conv_4", "Conv_5"]



content_losses = []
style_losses = []

conten_weight = 1
style_weight = 1000

new_model = torch.nn.Sequential()

model = copy.deepcopy(cnn)

gram = gram_matrix()

if use_gpu:
    new_model = new_model.cuda()
    gram = gram.cuda()

index = 1
for layer in list(model):
    if isinstance(layer, torch.nn.Conv2d):
        name = "Conv_"+str(index)
        new_model.add_module(name, layer)
        if name in content_layer:
            target = new_model(content_img).clone()
            content_loss = Content_loss(conten_weight, target)
            new_model.add_module("content_loss_"+str(index), content_loss)
            content_losses.append(content_loss)
            
        if name in style_layer:
            target = new_model(style_img).clone()
            target = gram(target)
            style_loss = Style_loss(style_weight, target)
            new_model.add_module("style_loss_"+str(index), style_loss)
            style_losses.append(style_loss)
            
    if isinstance(layer, torch.nn.ReLU):
        name = "Relu_"+str(index)
        new_model.add_module(name, layer)
        index = index+1
            
    if isinstance(layer, torch.nn.MaxPool2d):
        name = "MaxPool_"+str(index)
        new_model.add_module(name, layer)

要完成风格迁移,我们还需要构建自己的CNN网络。首先迁移了vgg16的模型,剔除了全连接部分,之后就是根据vgg16模型架构重构训练模型,加入了内容和风格Loss的计算部分。这里内容的提取只是选择了5、6层卷积,风格的提取只选择了1、2、3、4、5层卷积。

(四)、定义优化

优化定义代码:

input_img = content_img.clone()

parameter = torch.nn.Parameter(input_img.data)
optimizer = torch.optim.LBFGS([parameter])

这里为什么使用LBFGS来进行优化?

原因是我们要优化的Loss其实是多个,而不是像处理分类问题中只是需要优化一个Loss值,LBFGS能够获得更好的效果。

(五)、训练新定义的CNN

训练代码如下:

n_epoch = 1000

run = [0]
while run[0] <= n_epoch:

    def closure():
        optimizer.zero_grad()
        style_score = 0
        content_score = 0
        parameter.data.clamp_(0,1)
        new_model(parameter)
        for sl in style_losses:
            style_score += sl.backward()
        
        for cl in content_losses:
            content_score += cl.backward()
        
        run[0] += 1 
        if run[0] % 50 == 0:
            print('{} Style Loss : {:4f} Content Loss: {:4f}'.format(run[0],
                 style_score.data[0], content_score.data[0])) 

        return style_score+content_score
    


    optimizer.step(closure)

n_epoch定义了训练次数为1000次,使用sl.backward()和cl.backward()实现了反向传播,对参数进行优化。

2、改进

本文的图像风格迁移的方法没次实现都要进行一轮训练,而且风格调节的方式需要通过weight权重来控制,在实际应用中并不理想,现实中我们需要更加高效智能的实现方式。改进方法已经出现,先放出两篇论文

Fast Patch-based Style Transfer of Arbitrary Style

Visual Attribute Transfer through Deep Image Analogy

代码还在实现中......

参考资料:1、Welcome to PyTorch Tutorials

2、图像风格迁移(Neural Style)简史

完整代码:JaimeTang/PyTorch-and-Neural-style-transfer

如果觉得还行,请点个赞哦......