For the theoretical derivation, see the companion post VAE理论推导 (VAE theory derivation).
1. Data Preparation
We use the MNIST dataset, wrapped in a simple Dataset class:
from torch.utils.data import Dataset

class MNISTDataset(Dataset):
    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        # Only apply the transform if one was provided.
        if self.transform is not None:
            x = self.transform(x)
        return x, self.label[idx]
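For concreteness, here is one way to feed MNIST into this class via torchvision; this is a minimal sketch, and the root path, batch size, and [0, 1] scaling are my assumptions rather than the original setup:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets

# Download MNIST and wrap its raw tensors in the dataset class above.
mnist = datasets.MNIST(root='./data', train=True, download=True)
data = mnist.data.unsqueeze(1).float() / 255.0  # (N, 1, 28, 28), scaled to [0, 1]
dataset = MNISTDataset(data, mnist.targets)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)  # batch size assumed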
2. Model Implementation
The VAE consists of an encoder and a decoder. The encoder uses convolutions followed by an MLP to map an input image to the parameters of an n-dimensional diagonal Gaussian over the latent variable (its mean and log-variance); the decoder uses transposed convolutions to map a latent sample back to an image.
import torch
import torch.nn as nn

class ConvVAE(nn.Module):
    def __init__(self, latent_dim=100):
        super(ConvVAE, self).__init__()
        # Encoder: convolutions to extract features
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # (B, 64, 16, 16)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # (B, 128, 8, 8)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # (B, 256, 4, 4)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(256)
        )
        # Latent space: project flattened features to mean and log-variance
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
        # Decoder: transposed convolutions for upsampling
        self.decoder_input = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # (B, 128, 8, 8)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # (B, 64, 16, 16)
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # (B, 1, 32, 32)
            nn.Sigmoid()  # output pixel values in [0, 1]
        )
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # flatten feature maps
        mu, logvar = self.fc_mu(x), self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps sampling differentiable with respect to mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = x.view(x.size(0), 256, 4, 4)  # reshape to feature maps
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
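As a quick sanity check, the shapes line up as follows (the dummy batch here is just for illustration):

model = ConvVAE(latent_dim=100)
x = torch.randn(4, 1, 32, 32)  # dummy batch of 32x32 single-channel images
recon_x, mu, logvar = model(x)
print(recon_x.shape)           # torch.Size([4, 1, 32, 32])
print(mu.shape, logvar.shape)  # torch.Size([4, 100]) torch.Size([4, 100])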
3. Training
The VAE loss has two terms: a reconstruction loss (MSE here) and the KL divergence, with the KL term scaled by a coefficient that balances the two contributions.
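For a diagonal Gaussian posterior against the standard normal prior, the KL term has the closed form implemented below (with logvar standing for $\log\sigma^2$):

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2 I)\,\|\,\mathcal{N}(0,I)\big) = -\frac{1}{2}\sum_{j=1}^{n}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$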
def loss_function(recon_x, x, mu, logvar, lamda=1):
    # Reconstruction term: pixel-wise MSE summed over the whole batch.
    rec_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # rec_loss = nn.BCELoss(reduction='sum')(recon_x, x)
    # KL divergence between N(mu, sigma^2) and the standard normal prior.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec_loss + lamda * KLD, rec_loss.item(), KLD.item()
Training works as follows: the encoder maps each image to the mean and log-variance of a Gaussian, a latent vector is sampled via the reparameterization trick, the decoder reconstructs the image from it, and the reconstruction is compared against the original. The round trip is somewhat reminiscent of CycleGAN's cycle-consistency loss.
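The loop below is the inner loop over batches; it assumes a setup along these lines (the Adam optimizer and learning rate are my guesses, as the original does not show them):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvVAE(latent_dim=100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # optimizer choice assumed
epoch_loss, rec_epoch_loss, kl_epoch_loss = [], [], []     # per-batch loss logs
model.train()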
for i, (data, _) in enumerate(dataloader):
    data = data.to(device)
    optimizer.zero_grad()
    # Zero-pad 28x28 MNIST images to 32x32 to match the conv architecture.
    data = nn.functional.pad(data, (2, 2, 2, 2), value=0)
    recon_data, mu, logvar = model(data)
    loss, rec_loss, kl_loss = loss_function(recon_data, data, mu, logvar, 0.1)
    loss.backward()
    optimizer.step()
    epoch_loss.append(loss.item())
    rec_epoch_loss.append(rec_loss)
    kl_epoch_loss.append(kl_loss)
Below is the loss curve; it is fairly smooth.
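The curves can be reproduced from the lists collected in the loop above; a minimal matplotlib sketch:

import matplotlib.pyplot as plt

plt.plot(epoch_loss, label='total')
plt.plot(rec_epoch_loss, label='reconstruction')
plt.plot(kl_epoch_loss, label='KL')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.legend()
plt.show()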
4. Generation
Sampling, i.e. generation, simply draws Gaussian samples in the latent space and feeds them through the decoder to produce images; the encoder plays no part in generation.
# Sample 16 latent vectors from the standard normal prior.
test_data = torch.randn(16, latent_dim).to(device)
with torch.no_grad():
    model.eval()
    output = model.decode(test_data)
    show_images(output)
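show_images is a small helper not defined in the original; a plausible version, consistent with the make_grid-based plotting used in the interpolation section below, might be:

import matplotlib.pyplot as plt
import torchvision.utils as vutils

def show_images(images, nrow=4):
    # Arrange the batch into a grid and display it in grayscale.
    grid = vutils.make_grid(images.cpu(), nrow=nrow, padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0), cmap='gray')
    plt.axis('off')
    plt.show()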
Sampling results:
Some digits come out well; others look rather strange.
5. Latent Space Interpolation
In principle, latent variables and samples are tied together: every image has a corresponding latent vector. So if we interpolate between the latent vectors of two images and decode the intermediate points, can we watch one image gradually turn into the other?
alphas = torch.linspace(0, 1, 8).to(device)
with torch.no_grad():
    model.eval()
    for i, (data, _) in enumerate(test_dataloader):
        data = data.to(device)
        data = nn.functional.pad(data, (2, 2, 2, 2), value=0)
        _, mu, logvar = model(data)
        latent_var = model.reparameterize(mu, logvar)
        # Interpolate between the latent codes of successive image pairs.
        for j in range(1, 8, 2):
            interpolated_z = torch.stack([(1 - a) * latent_var[j] + a * latent_var[j + 1] for a in alphas])
            output = model.decode(interpolated_z)
            grid = vutils.make_grid(output, nrow=8, padding=2, normalize=True)
            plt.figure(figsize=(16, 16))
            plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')  # CHW -> HWC for imshow
            plt.axis('off')
            plt.show()
            plt.close()
        break
The figure below shows a 5 morphing into a 4.
And this one shows a 1 morphing into a 3.
Quite fun to watch: there really is a gradual transition.