VAE Code Implementation Using the MNIST Dataset


For the theoretical derivation, see the companion post, VAE Theoretical Derivation.

1. Data Preparation

We use the MNIST dataset, wrapped in a simple Dataset class:

from torch.utils.data import Dataset

class MNISTDataset(Dataset):
    """Wraps raw MNIST tensors so a transform can be applied per sample."""
    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label[idx]
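
The post does not show where the raw tensors come from. A minimal sketch of obtaining them and building a DataLoader, assuming torchvision as the source (the batch size here is also an assumption):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

raw = datasets.MNIST(root="./data", train=True, download=True)
to_tensor = transforms.Compose([
    transforms.ToPILImage(),   # raw.data entries are (28, 28) uint8 tensors
    transforms.ToTensor(),     # -> (1, 28, 28) float in [0, 1]
])
train_set = MNISTDataset(raw.data, raw.targets, transform=to_tensor)
dataloader = DataLoader(train_set, batch_size=128, shuffle=True)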

2. Model Implementation

The VAE has two parts, an encoder and a decoder. The encoder uses convolutions and an MLP to encode the input image into a latent variable, i.e., the mean and (log-)variance of an n-dimensional Gaussian; the decoder uses transposed convolutions to turn a latent sample back into an image.

import torch
import torch.nn as nn

class ConvVAE(nn.Module):
    def __init__(self, latent_dim=100):
        super(ConvVAE, self).__init__()

        # Encoder: convolutions to extract features (input is B, 1, 32, 32)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),   # (B, 64, 16, 16)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # (B, 128, 8, 8)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # (B, 256, 4, 4)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(256)
        )

        # Latent space: one linear head for the mean, one for the log-variance
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

        # Decoder: transposed convolutions for upsampling
        self.decoder_input = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # (B, 128, 8, 8)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # (B, 64, 16, 16)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # (B, 1, 32, 32)
            nn.Sigmoid()  # Output pixel values between 0 and 1
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten to (B, 256 * 4 * 4)
        mu, logvar = self.fc_mu(x), self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 256, 4, 4)  # Reshape to feature maps
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
    
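A quick sanity check of the tensor shapes (a sketch; as the training loop below shows, the 28x28 MNIST images are padded to 32x32 before entering the network):

model = ConvVAE(latent_dim=100)
x = torch.randn(8, 1, 32, 32)  # dummy batch of padded MNIST-sized images
recon_x, mu, logvar = model(x)
print(recon_x.shape)           # torch.Size([8, 1, 32, 32])
print(mu.shape, logvar.shape)  # torch.Size([8, 100]) torch.Size([8, 100])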

3. Training

The VAE loss consists of two terms: an MSE reconstruction loss and a KL divergence. A coefficient on the KL term balances the contribution of the two.
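
For a diagonal Gaussian posterior and a standard normal prior, the KL term has a closed form, which is exactly what KLD computes in the code below (with logvar standing for $\log \sigma^2$):

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big) = -\frac{1}{2}\sum_{j=1}^{n}\left(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2\right)$$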

def loss_function(recon_x, x, mu, logvar, lamda=1):
    # Reconstruction term: summed squared error over all pixels
    rec_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # rec_loss = nn.BCELoss(reduction='sum')(recon_x, x)  # alternative since outputs are in [0, 1]
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # lamda weights the KL term against the reconstruction term
    return rec_loss + lamda * KLD, rec_loss.item(), KLD.item()

Training first uses the encoder to turn an image into the mean and variance of a Gaussian, then uses the decoder to reconstruct the image, and finally computes the loss between the reconstruction and the original, a bit like CycleGAN's cycle consistency.
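
The loop below assumes the model, optimizer, device, and loss-tracking lists already exist. A minimal sketch of that setup (the learning rate and other hyperparameters here are assumptions, not values from the post):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvVAE(latent_dim=100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed learning rate
epoch_loss, rec_epoch_loss, kl_epoch_loss = [], [], []     # per-batch loss records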

for i, (data, _) in enumerate(dataloader):
    data = data.to(device)
    optimizer.zero_grad()
    # Pad 28x28 MNIST images to 32x32 so the strided convolutions divide evenly
    data = nn.functional.pad(data, (2, 2, 2, 2), value=0)
    recon_data, mu, logvar = model(data)
    loss, rec_loss, kl_loss = loss_function(recon_data, data, mu, logvar, 0.1)
    loss.backward()
    optimizer.step()
    epoch_loss.append(loss.item())
    rec_epoch_loss.append(rec_loss)
    kl_epoch_loss.append(kl_loss)

Below is the loss curve; it is fairly smooth.

(figure: training loss curves)
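
The curve can be plotted from the lists collected during training; a minimal sketch using matplotlib:

import matplotlib.pyplot as plt

plt.plot(epoch_loss, label='total')
plt.plot(rec_epoch_loss, label='reconstruction')
plt.plot(kl_epoch_loss, label='KL')
plt.xlabel('batch')
plt.ylabel('loss')
plt.legend()
plt.show()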

4. Generation

Sampling, i.e., generation, draws random Gaussian samples in the latent space and turns them into images with the decoder; the encoder plays no part in generation.


# Draw a batch of 16 random latent vectors from the standard normal prior
test_data = torch.randn(16, latent_dim).to(device)

with torch.no_grad():
    model.eval()
    output = model.decode(test_data)
    show_images(output)
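
show_images is a helper not shown in the post; a plausible minimal version built on torchvision's make_grid might look like this:

import matplotlib.pyplot as plt
import torchvision.utils as vutils

def show_images(images, nrow=4):
    # Arrange a (B, 1, H, W) batch into a grid and display it
    grid = vutils.make_grid(images.cpu(), nrow=nrow, padding=2, normalize=True)
    plt.figure(figsize=(6, 6))
    plt.imshow(grid.permute(1, 2, 0), cmap='gray')  # CHW -> HWC
    plt.axis('off')
    plt.show()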

Sampling results:

(figure: generated digit samples)

Some digits come out well, while others look odd.

5. Latent Space Interpolation

In principle the latent variable and the sample are related: every image has a corresponding latent vector. So if we interpolate between the latent vectors of two images and decode the results, can we watch the transition from one image to the other?

import matplotlib.pyplot as plt
import torchvision.utils as vutils

alphas = torch.linspace(0, 1, 8).to(device)
with torch.no_grad():
    model.eval()
    for data, _ in test_dataloader:
        data = data.to(device)
        data = nn.functional.pad(data, (2, 2, 2, 2), value=0)
        _, mu, logvar = model(data)
        latent_var = model.reparameterize(mu, logvar)
        # Interpolate between adjacent pairs (1, 2), (3, 4), ...;
        # assumes the batch holds at least 9 samples
        for j in range(1, 8, 2):
            interpolated_z = torch.stack([(1 - a) * latent_var[j] + a * latent_var[j + 1] for a in alphas])
            output = model.decode(interpolated_z)
            grid = vutils.make_grid(output, nrow=8, padding=2, normalize=True)
            plt.figure(figsize=(16, 16))
            plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')  # CHW -> HWC for matplotlib
            plt.axis('off')
            plt.show()
            plt.close()
        break

The figure below shows the transition from a 5 into a 4:

(figure: interpolation from 5 to 4)

And this is the transition from a 1 into a 3:

(figure: interpolation from 1 to 3)

Quite interesting: you can really see one digit gradually morphing into the other.

The full code is available here.