Digital Human Explained: Implementing a Digital Human from Scratch (Part 2)


"This article was first published as a signed article for the Juejin (稀土掘金) technical community. Reprinting is prohibited within 30 days, and prohibited without authorization thereafter; infringement will be pursued."

1. Introduction

In the previous article, we implemented face detection, audio feature extraction, data preprocessing, and the dataset class for our digital human. If you haven't read it yet, start with "Digital Human Explained: Implementing a Digital Human from Scratch (Part 1)".

Let's briefly review what the dataset class returns.

The dataset returns three items: x, mel_window, and gt_window. x is the combination of masked_image and reference_image: it contains the upper half of the ground-truth face plus the whole reference face, with shape (batch_size, sync_t, 6, 96, 96). mel_window is the audio feature sequence aligned with gt_window, with shape (batch_size, sync_t, 1, 80, 16). Finally, gt_window is the sequence of target face images, with shape (batch_size, sync_t, 3, 96, 96).
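
To make these shapes concrete, here is a quick sketch that pulls one batch from the dataset and prints the shapes. It assumes the Wav2LipDataset class and the processed data directory from Part 1:

from torch.utils.data import DataLoader

from wav2lip.datasets import Wav2LipDataset  # the dataset class implemented in Part 1

loader = DataLoader(Wav2LipDataset('processed', 'train'), batch_size=4)
x, mel_window, gt_window = next(iter(loader))
print(x.shape)           # (batch_size, sync_t, 6, 96, 96)
print(mel_window.shape)  # (batch_size, sync_t, 1, 80, 16)
print(gt_window.shape)   # (batch_size, sync_t, 3, 96, 96)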

Based on the above, let's now implement our model.

2. Model Structure

2.1 The UNet structure

We adopt the same network structure as Wav2Lip. The network consists of three parts: face_encoder, audio_encoder, and face_decoder, and overall it is a typical UNet.

Let's first look at the original UNet structure, as shown in the figure:

[Figure: the original UNet encoder-decoder architecture]

The network in the figure consists of an encoder and a decoder. The encoder is split into 3 blocks (each block performs several consecutive convolutions), and as the encoder runs, the output of every block is saved into a list called feats.

The encoder ultimately encodes the image into a 1024-dimensional vector and hands it to the decoder. The decoder contains the same number of blocks as the encoder; each decoder block merges the previous output with the corresponding entry in feats, and the decoder finally outputs an image.
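
To make the skip-connection pattern concrete, here is a minimal toy sketch (not the Wav2Lip network, just the UNet idea): the encoder stores each block's output in feats, and the decoder upsamples and concatenates the matching entry as it goes.

import torch
from torch import nn


class TinyUNet(nn.Module):
    """Toy example of the UNet feats / skip-connection pattern only."""

    def __init__(self):
        super().__init__()
        self.encoder_blocks = nn.ModuleList([
            nn.Conv2d(3, 16, 3, stride=2, padding=1),   # 96x96 -> 48x48
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 48x48 -> 24x24
        ])
        self.decoder_blocks = nn.ModuleList([
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),      # 24x24 -> 48x48
            nn.ConvTranspose2d(16 + 16, 3, 4, stride=2, padding=1),  # 48x48 -> 96x96
        ])

    def forward(self, x):
        feats = []
        for block in self.encoder_blocks:    # encoder: remember every block output
            x = block(x)
            feats.append(x)
        x = feats.pop()                      # deepest feature starts the decoder
        for block in self.decoder_blocks:    # decoder: upsample, then merge the matching skip
            x = block(x)
            if feats:
                x = torch.cat((x, feats.pop()), dim=1)
        return x


out = TinyUNet()(torch.randn(1, 3, 96, 96))
print(out.shape)  # torch.Size([1, 3, 96, 96])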

2.2 The Wav2Lip structure

Wav2Lip makes a few modifications, specifically:

  1. The number of blocks is increased.
  2. The first decoder block does not take the encoder's output; instead, it takes the output of the audio_encoder.

Everything else is essentially the same as UNet. Let's now implement the Wav2Lip network.

(1) face_encoder

The face_encoder is simply a stack of consecutive convolutions. In the dataset, x has shape (batch_size, sync_t, 6, 96, 96); when the data is fed to the model, we reshape it into (batch_size*sync_t, 6, 96, 96), and the encoder finally outputs a feature of shape (batch_size*sync_t, 512, 1, 1). We design the face_encoder according to this requirement:

import torch
from torch import nn


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)


self.face_encoder_blocks = nn.ModuleList([
    nn.Sequential(
        Conv2d(6, 16, kernel_size=7, stride=1, padding=3)
    ),  # 96,96
    nn.Sequential(
        Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 48,48
        Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)
    ),
    nn.Sequential(
        Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 24,24
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
    ),
    nn.Sequential(
        Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 12,12
        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)
    ),
    nn.Sequential(
        Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 6,6
        Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)
    ),
    nn.Sequential(
        Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 3,3
        Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)
    ),
    nn.Sequential(
        Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
        nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0)
    )
])

Here we define a Conv2d layer that wraps nn.Conv2d with BatchNorm and an optional residual connection (note that the residual path only works when cin equals cout and the spatial size is unchanged, which is how it is used here). The face_encoder then consists of 7 blocks, whose outputs have the following shapes:

torch.Size([1, 16, 96, 96])
torch.Size([1, 32, 48, 48])
torch.Size([1, 64, 24, 24])
torch.Size([1, 128, 12, 12])
torch.Size([1, 256, 6, 6])
torch.Size([1, 512, 3, 3])
torch.Size([1, 512, 1, 1])

We will append each of these outputs to a feats list for use during decoding.
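
As a quick sanity check, the sketch below runs a dummy input through these blocks and collects the outputs into feats. It assumes face_encoder_blocks is available as a plain nn.ModuleList; inside the model class it is self.face_encoder_blocks:

feats = []
x = torch.randn(1, 6, 96, 96)      # one masked face + reference face pair
for block in face_encoder_blocks:  # self.face_encoder_blocks inside the model
    x = block(x)
    feats.append(x)                # keep every block output for the decoder
    print(x.shape)                 # prints the seven shapes listed above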

(2) audio_encoder

The audio_encoder takes mel_window as input and outputs a feature vector with the same shape as the face_encoder's final output. Its structure is as follows:


self.audio_encoder = nn.Sequential(
    Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
    Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
    Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

    Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

    Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

    Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
    Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

    nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
    nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0)
)

The audio feature vector is only needed as the input of the very first face_decoder block, so the audio_encoder is a single nn.Sequential that produces it in one pass. Its output shape is:

torch.Size([1, 512, 1, 1])
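
The spatial size shrinks from 80x16 down to 1x1 exactly because of the strides chosen above. A short sketch (assuming audio_encoder is available as a plain nn.Sequential) verifies this:

mel = torch.randn(1, 1, 80, 16)   # one mel window
# 80x16 --stride (3,1)--> 27x16 --stride 3--> 9x6 --stride (3,2)--> 3x3 --3x3 conv, pad 0--> 1x1
print(audio_encoder(mel).shape)   # torch.Size([1, 512, 1, 1])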

(3) face_decoder_blocks

The first block of face_decoder_blocks takes the audio feature vector as input and produces an output o1; o1 is then concatenated with feats[-1] and fed into the second block, and so on down the list.

The decoder progressively upsamples the feature maps, so we use transposed convolutions. As before, we write a Conv2dTranspose wrapper for later use. The code is as follows:

class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)
        
self.face_decoder_blocks = nn.ModuleList([
    nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0)
    ),
    nn.Sequential(
        Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0),  # 3,3
        Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)
    ),
    nn.Sequential(
        Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
        Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)
    ),  # 6, 6
    nn.Sequential(
        Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
        Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True)
    ),  # 12, 12
    nn.Sequential(
        Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
        Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)
    ),  # 24, 24
    nn.Sequential(
        Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)
    ),  # 48, 48
    nn.Sequential(
        Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
        Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
    )  # 96,96
])

The decoder's final output has shape (1, 64, 96, 96), but we then concatenate it with the output of the face_encoder's first block (shape (1, 16, 96, 96)), giving a (1, 80, 96, 96) feature map. Finally, we use an output block to convert it into the shape of an image. The code is as follows:

self.output_block = nn.Sequential(
    Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
    nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
    nn.Sigmoid()
)

Since the input images are normalized to values between 0 and 1, we use sigmoid as the activation of the final output.
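
For reference, a transposed convolution produces out = (in - 1) * stride - 2 * padding + kernel_size + output_padding along each spatial dimension (with dilation 1). A small sketch checks the decoder's upsampling path 1 -> 3 -> 6 -> 12 -> 24 -> 48 -> 96 against this formula:

def tconv_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    # Spatial output size of nn.ConvTranspose2d (dilation = 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding

print(tconv_out(1, stride=1, padding=0, output_padding=0))  # 1 -> 3
size = 3
for _ in range(5):                                          # 3 -> 6 -> 12 -> 24 -> 48 -> 96
    size = tconv_out(size)
    print(size)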

(4) Forward pass

Finally, we implement the forward pass:

def forward(self, audio_sequences, face_sequences):
    # The dataset returns (batch_size, sync_t, c, h, w) tensors, so flatten the first two dims
    if face_sequences.ndim > 4:
        face_sequences = face_sequences.view(-1, 6, 96, 96)
    if audio_sequences.ndim > 4:
        audio_sequences = audio_sequences.view(-1, 1, 80, 16)

    # Audio feature vector
    audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1

    # Collect the feature maps from every face_encoder block
    feats = []
    x = face_sequences
    for f in self.face_encoder_blocks:
        x = f(x)
        feats.append(x)

    # Decode the face from the audio embedding, merging in feats along the way
    x = audio_embedding
    for f in self.face_decoder_blocks:
        x = f(x)
        x = torch.cat((x, feats[-1]), dim=1)
        feats.pop()
    # Final output image
    x = self.output_block(x)
    return x

Let's test the network:

model = Wav2Lip()
f = torch.randn(1, 6, 96, 96)
a = torch.randn(1, 1, 80, 16)

o = model(a, f)
print(o.shape)

The output is:

torch.Size([1, 3, 96, 96])
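
As a side note, we can also check the model size with a one-liner:

n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.1f}M parameters')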

3. Training the Model

Now that we have both the dataset and the model, we can start training. First, we set up the model, data loaders, optimizer, and loss function:

from pathlib import Path
from os.path import join

import cv2
import torch
import numpy as np
from tqdm import tqdm
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

# The model and dataset implemented earlier
from wav2lip.models import Wav2Lip
from wav2lip.datasets import Wav2LipDataset

# Global step counter shared by train() and save_checkpoint() below
global_step = 0


def main():
    global global_step
    # Paths
    base_dir = Path(__file__).parent.parent
    dataset_dir = base_dir / 'processed'
    checkpoint_dir = base_dir / 'outputs/wav2lip_checkpoints'
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = DataLoader(Wav2LipDataset(dataset_dir, 'train'), batch_size=16)
    test_loader = DataLoader(Wav2LipDataset(dataset_dir, 'val'), batch_size=16)
    model = Wav2Lip().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    train(
        model, train_loader, test_loader, optimizer, n_epochs=10, loss_fn=nn.L1Loss(),
        checkpoint_dir=checkpoint_dir, checkpoint_interval=500, device=device
    )


if __name__ == '__main__':
    main()

There is nothing special above. Next, let's look at the training code:

def train(
        model, train_loader, test_loader, optimizer, loss_fn,
        checkpoint_dir=None, checkpoint_interval=None, n_epochs=20, device=None
):
    global global_step
    for epoch in range(n_epochs):
        losses = []
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for step, (x, mel, y) in pbar:
            # The dataset returns (batch_size, sync_t, c, h, w) tensors, so merge the first two dims
            x = x.view(-1, 6, 96, 96)
            mel = mel.view(-1, 1, 80, 16)
            y = y.view(-1, 3, 96, 96)
            x, mel, y = x.to(device), mel.to(device), y.to(device)
            # Switch to training mode
            model.train()
            # Clear gradients
            optimizer.zero_grad()

            # Forward pass
            preds = model(mel, x)
            loss = loss_fn(preds, y)
            # Backward pass
            loss.backward()
            # Update parameters
            optimizer.step()

            global_step += 1
            losses.append(loss.item())

            # Every checkpoint_interval steps, evaluate the model and save a checkpoint
            if global_step == 1 or global_step % checkpoint_interval == 0:
                print(f"Train loss: {sum(losses) / len(losses)}")
                save_checkpoint(
                    model, optimizer, checkpoint_dir, epoch
                )
                save_sample_images(x, preds, y, global_step, checkpoint_dir)
                eval_model(model, test_loader, loss_fn, device=device)

After fetching a batch, we reshape the tensors, effectively treating the sync_t dimension as part of the batch. The rest is standard training code. The evaluation branch calls three helper functions; let's look at each of them. First is the checkpoint-saving code, which stores the model weights, the optimizer state, and the current step and epoch:

def save_checkpoint(model, optimizer, checkpoint_dir, epoch):
    global global_step
    checkpoint_path = join(
        checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
    optimizer_state = optimizer.state_dict()
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": global_step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

Next is the evaluation function. It is largely the same as the training loop; just remember to switch the model to evaluation mode:

@torch.no_grad()
def eval_model(model, test_loader, loss_fn, eval_steps=3000, device=None):
    # Switch the model to evaluation mode
    model.eval()
    losses = []
    for step, (x, mel, y) in enumerate(test_loader):
        x = x.view(-1, 6, 96, 96)
        mel = mel.view(-1, 1, 80, 16)
        y = y.view(-1, 3, 96, 96)
        x, mel, y = x.to(device), mel.to(device), y.to(device)
        preds = model(mel, x)
        loss = loss_fn(preds, y)
        losses.append(loss.item())
        if step >= eval_steps:
            break
    averaged_loss = sum(losses) / len(losses)
    model.train()
    print(f"Evaluate loss: {averaged_loss}")
    return averaged_loss

This mirrors the training loop. Finally, let's look at the function that saves sample images:

def save_sample_images(x, g, gt, step, checkpoint_dir):

    # Convert a batch of tensors into a single grid image (uint8, HWC)
    def to_image(pt, nrow=5):
        pt = make_grid(pt, nrow=nrow)
        return (pt.detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)

    folder = Path(checkpoint_dir, f'samples_step{step:09d}')
    folder.mkdir(exist_ok=True)

    # First 3 channels of x are the masked face, last 3 are the reference face
    refs, inps = x[:, 3:, :, :], x[:, :3, :, :]
    n_syncs = x.size(0) // 5  # 5 consecutive frames per window (sync_t)
    for i in range(n_syncs):
        ref = refs[i * 5: (i + 1) * 5, ...]
        inp = inps[i * 5: (i + 1) * 5, ...]
        gi = g[i * 5: (i + 1) * 5, ...]
        gti = gt[i * 5: (i + 1) * 5, ...]
        image = to_image(torch.concat([ref, inp, gi, gti]), nrow=5)
        cv2.imwrite(str(folder / f'{i:02d}.jpg'), image)

x is our input: its first three channels are the masked_image and the last three are the reference_image; gi is the generated image and gti is the ground-truth image. Here is a sample:

[Sample grid: reference faces, masked inputs, predictions, and ground-truth faces, one row each]

The third row is the prediction. At this point the face is already reconstructed and the lip shapes match the audio fairly well; training for more steps gives better results. The sharpness, however, is somewhat reduced; the original work also provides a GAN version of Wav2Lip that produces sharper results, which we won't cover here.

With that, we have finished training the Wav2Lip model. However, inference is still fairly slow on machines with limited GPU memory. In the upcoming parts we will optimize this and write complete inference code.