贝叶斯神经网络BNN原理推导及python实现

8,072 阅读6分钟

1. 简介

贝叶斯神经网络不同于一般的神经网络,其权重参数是随机变量,而非确定的值。如下图所示:

在这里插入图片描述

也就是说,和传统的神经网络用交叉熵,mse等损失函数去拟合标签值相反,贝叶斯神经网络拟合后验分布。

这样做的好处,就是降低过拟合。

2. BNN模型

BNN 不同于 DNN,可以对预测分布进行学习,不仅可以给出预测值,而且可以给出预测的不确定性。这对于很多问题来说非常关键,比如:机器学习中著名的 Exploration & Exploitation (EE)的问题,在强化学习问题中,agent 是需要利用现有知识来做决策还是尝试一些未知的东西;实验设计问题中,用贝叶斯优化来调超参数,选择下一个点是根据当前模型的最优值还是利用探索一些不确定性较高的空间。比如:异常样本检测,对抗样本检测等任务,由于 BNN 具有不确定性量化能力,所以具有非常强的鲁棒性

概率建模: 在这里插入图片描述 在这里,选择似然分布的共轭分布,这样后验可以分析计算。 比如,beta分布的先验和伯努利分布的似然,会得到服从beta分布的后验。 在这里插入图片描述

由于共轭分布,需要对先验分布进行约束。因此,我们尝试使用采用和变分推断来近似后验分布。


神经网络: 使用全连接网络来拟合数据,相当于使用多个全连接网络。 但是神经网络容易过拟合,泛化性差;并且对预测的结果无法给出置信度。

BNN: 把概率建模和神经网络结合起来,并能够给出预测结果的置信度。

先验用来描述关键参数,并作为神经网络的输入。神经网络的输出用来描述特定的概率分布的似然。通过采样或者变分推断来计算后验分布。 同时,和神经网络不同,权重 W 不再是一个确定的值,而是一个概率分布。


BNN建模如下:

假设 NN 的网络参数为 WWp(W)p(W) 是参数的先验分布,给定观测数据 D=X,YD={X,Y},这里 XX 是输入数据,YY 是标签数据。BNN 希望给出以下的分布:

也就是我们预测值为:

P(YX,D)=P(YX,W)P(WD)dW1P\left(Y^{\star} | X^{\star}, D\right)=\int P\left(Y^{\star} | X^{\star}, W\right) P(W | D) d W (1)

由于,WW是随机变量,因此,我们的预测值也是个随机变量。

其中:

P(WD)=P(W)P(DW)P(D)2P(W | D)=\frac{P(W) P(D | W)}{P(D)} (2)

这里 P(WD)P(W|D) 是后验分布,P(DW)P(D|W) 是似然函数,P(D)P(D) 是边缘似然。

从公式(1)中可以看出,用 BNN 对数据进行概率建模并预测的核心在于做高效近似后验推断,而 变分推断 VI 或者采样是一个非常合适的方法。

如果采样的话: 我们通过采样后验分布P(WD)P(W \vert \mathcal{D}) 来评估 P(WD)P(W \vert \mathcal{D}) , 每个样本计算 f(Xw)f(X \vert w), 其中 f 是我们的神经网络。

正是我们的输出是一个分布,而不是一个值,我们可以估计我们预测的不确定度。

3. 基于变分推断的BNN训练

如果直接采样后验概率 p(WD)p(W|D) 来评估 p(YX,D)p(Y|X, D)的话,存在后验分布多维的问题,而变分推断的思想是使用简单分布去近似后验分布。

表示θ=(μ,σ)\theta = (\mu, \sigma), 每个权重 wiw_i 从正态分布(μi,σi)(\mu_i, \sigma_i) 中采样。

希望 q(wθ)q(w \vert \theta)P(wD)P(w \vert \mathcal{D}) 相近,并使用 KL 散度来度量这两个分布的距离。 也就是优化:

θ=argminθ KL[q(wθ)P(wD)]  (3)\theta^* = \underset{\theta}{\mathrm{argmin}} \text{ KL}\left[q(w \vert \theta) \vert \vert P(w \vert \mathcal{D})\right] \; (3)

进一步推导:

θ=argminθ KL[q(wθ)P(wD)]=argminθ Eq(wθ)[log[q(wθ)P(wD)]](definition of KL divegence)=argminθ Eq(wθ)[log[q(wθ)P(D)P(Dw)P(w)]](Bayes Theorem)=argminθ Eq(wθ)[log[q(wθ)P(Dw)P(w)]](Drop P(D) because it doesn’t depend on θ)  4\begin{array}{l} \theta^* &= \underset{\theta}{\mathrm{argmin}} \text{ KL}\left[q(w \vert \theta) \vert \vert P(w \vert \mathcal{D})\right] & \\\\ &= \underset{\theta}{\mathrm{argmin}} \text{ }\mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( w \vert \mathcal{D})}\right]\right] & \text{(definition of KL divegence)} \\\\ &= \underset{\theta}{\mathrm{argmin}} \text{ }\mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta)P(\mathcal{D}) }{P( \mathcal{D} \vert w)P(w)}\right]\right] & \text{(Bayes Theorem)} \\\\ &= \underset{\theta}{\mathrm{argmin}} \text{ }\mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] & \text{(Drop }P(\mathcal{D})\text{ because it doesn't depend on } \theta) \end{array} \;(4)

公式中, q(wθ)q(w|\theta) 表示给定正态分布的参数后,权重参数的分布; P(Dw)P(D|w) 表示给定网络参数后,观测数据的似然; P(w)P(w) 表示权重的先验,这部分可以作为模型的正则化。

并且使用

L=Eq(wθ)[log[q(wθ)P(Dw)P(w)]]  (5)\mathcal{L} = - \mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] \;(5)

来表示变分下界ELBO, 也就是公式(4)等价于最大化ELBO:

L=ilogq(wiθi)ilogP(wi)jlogP(yjw,xj)  (6) \mathcal{L} = \sum_i \log q(w_i \vert \theta_i) - \sum_i \log P(w_i) - \sum_j \log P(y_j \vert w, x_j) \;(6)

其中,D={(x,y)}D =\{ (x, y)\}

我们需要对公式(4)中的期望进行求导,但是,这里,我们使用对权重进行重参数的技巧:

wi=μi+σi×ϵi  (7)w_i = \mu_i + \sigma_i \times \epsilon_i \; (7)

其中, ϵiN(0,1)\epsilon_i \sim \mathcal{N}(0,1).

于是,用 ϵ\epsilon代 替 ww 后有:

θEq(ϵ)[log[q(wθ)P(Dw)P(w)]]=Eq(ϵ)[θlog[q(wθ)P(Dw)P(w)]]  (8)\frac{\partial}{\partial \theta}\mathbb{E}_{q(\epsilon)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] =\mathbb{E}_{q(\epsilon)}\left[ \frac{\partial}{\partial \theta}\log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] \; (8)

也就是说,我们可以通过 多个不同的 ϵN(0,1)\epsilon \sim \mathcal{N}(0,1) ,求取θlog[q(wθ)P(Dw)P(w)]\frac{\partial}{\partial \theta}\log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right] 的平均值,来近似 KL 散度对 θ\theta 的求导。

此外,除了对 ww 进行重采样之外,为了保证 θ\theta 参数取值范围包含这个实轴,对 σ\sigma 进行重采样,可以令,

σ=log(1+eρ)      (9)\sigma = \log (1 + e^{\rho}) \;\;\; (9)

然后,θ=(μ,ρ)\theta = (\mu, \rho),这里的 θ\theta 已经和原来定义的θ=(μ,σ)\theta = (\mu, \sigma) 不一样了。

4. BNN实践

算法:

  1. N(μ,log(1+eρ))N(\mu, log(1+e^\rho)) 中采样,获得 ww
  2. 分别计算 logq(wθ)\log q(w|\theta)logp(w)\log p(w)logp(yw,x)\log p(y|w,x). 其中,计算 logp(yw,x)\log p(y|w,x) 实际计算 logp(yypred)\log p(y|y_{pred}), ypred=wxy_{pred} = w*x. 也就可以得到 L=ilogq(wiθi)ilogP(wi)jlogP(yjw,xj)\mathcal{L} = \sum_i \log q(w_i \vert \theta_i) - \sum_i \log P(w_i) - \sum_j \log P(y_j \vert w, x_j)
  3. 重复更新参数θ=θαθL\theta’ = \theta -\alpha \nabla_\theta \mathcal{L}.

Pytorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

class Linear_BBB(nn.Module):
    """
        Layer of our BNN.
    """
    def __init__(self, input_features, output_features, prior_var=1.):
        """
            Initialization of our layer : our prior is a normal distribution
            centered in 0 and of variance 20.
        """
        # initialize layers
        super().__init__()
        # set input and output dimensions
        self.input_features = input_features
        self.output_features = output_features

        # initialize mu and rho parameters for the weights of the layer
        self.w_mu = nn.Parameter(torch.zeros(output_features, input_features))
        self.w_rho = nn.Parameter(torch.zeros(output_features, input_features))

        #initialize mu and rho parameters for the layer's bias
        self.b_mu =  nn.Parameter(torch.zeros(output_features))
        self.b_rho = nn.Parameter(torch.zeros(output_features))        

        #initialize weight samples (these will be calculated whenever the layer makes a prediction)
        self.w = None
        self.b = None

        # initialize prior distribution for all of the weights and biases
        self.prior = torch.distributions.Normal(0,prior_var)

    def forward(self, input):
        """
          Optimization process
        """
        # sample weights
        w_epsilon = Normal(0,1).sample(self.w_mu.shape)
        self.w = self.w_mu + torch.log(1+torch.exp(self.w_rho)) * w_epsilon

        # sample bias
        b_epsilon = Normal(0,1).sample(self.b_mu.shape)
        self.b = self.b_mu + torch.log(1+torch.exp(self.b_rho)) * b_epsilon

        # record log prior by evaluating log pdf of prior at sampled weight and bias
        w_log_prior = self.prior.log_prob(self.w)
        b_log_prior = self.prior.log_prob(self.b)
        self.log_prior = torch.sum(w_log_prior) + torch.sum(b_log_prior)

        # record log variational posterior by evaluating log pdf of normal distribution defined by parameters with respect at the sampled values
        self.w_post = Normal(self.w_mu.data, torch.log(1+torch.exp(self.w_rho)))
        self.b_post = Normal(self.b_mu.data, torch.log(1+torch.exp(self.b_rho)))
        self.log_post = self.w_post.log_prob(self.w).sum() + self.b_post.log_prob(self.b).sum()

        return F.linear(input, self.w, self.b)

class MLP_BBB(nn.Module):
    def __init__(self, hidden_units, noise_tol=.1,  prior_var=1.):

        # initialize the network like you would with a standard multilayer perceptron, but using the BBB layer
        super().__init__()
        self.hidden = Linear_BBB(1,hidden_units, prior_var=prior_var)
        self.out = Linear_BBB(hidden_units, 1, prior_var=prior_var)
        self.noise_tol = noise_tol # we will use the noise tolerance to calculate our likelihood

    def forward(self, x):
        # again, this is equivalent to a standard multilayer perceptron
        x = torch.sigmoid(self.hidden(x))
        x = self.out(x)
        return x

    def log_prior(self):
        # calculate the log prior over all the layers
        return self.hidden.log_prior + self.out.log_prior

    def log_post(self):
        # calculate the log posterior over all the layers
        return self.hidden.log_post + self.out.log_post

    def sample_elbo(self, input, target, samples):
        # we calculate the negative elbo, which will be our loss function
        #initialize tensors
        outputs = torch.zeros(samples, target.shape[0])
        log_priors = torch.zeros(samples)
        log_posts = torch.zeros(samples)
        log_likes = torch.zeros(samples)
        # make predictions and calculate prior, posterior, and likelihood for a given number of samples
        for i in range(samples):
            outputs[i] = self(input).reshape(-1) # make predictions
            log_priors[i] = self.log_prior() # get log prior
            log_posts[i] = self.log_post() # get log variational posterior
            log_likes[i] = Normal(outputs[i], self.noise_tol).log_prob(target.reshape(-1)).sum() # calculate the log likelihood
        # calculate monte carlo estimate of prior posterior and likelihood
        log_prior = log_priors.mean()
        log_post = log_posts.mean()
        log_like = log_likes.mean()
        # calculate the negative elbo (which is our loss function)
        loss = log_post - log_prior - log_like
        return loss

def toy_function(x):
    return -x**4 + 3*x**2 + 1

# toy dataset we can start with
x = torch.tensor([-2, -1.8, -1, 1, 1.8, 2]).reshape(-1,1)
y = toy_function(x)

net = MLP_BBB(32, prior_var=10)
optimizer = optim.Adam(net.parameters(), lr=.1)
epochs = 2000
for epoch in range(epochs):  # loop over the dataset multiple times
    optimizer.zero_grad()
    # forward + backward + optimize
    loss = net.sample_elbo(x, y, 1)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print('epoch: {}/{}'.format(epoch+1,epochs))
        print('Loss:', loss.item())
print('Finished Training')


# samples is the number of "predictions" we make for 1 x-value.
samples = 100
x_tmp = torch.linspace(-5,5,100).reshape(-1,1)
y_samp = np.zeros((samples,100))
for s in range(samples):
    y_tmp = net(x_tmp).detach().numpy()
    y_samp[s] = y_tmp.reshape(-1)
plt.plot(x_tmp.numpy(), np.mean(y_samp, axis = 0), label='Mean Posterior Predictive')
plt.fill_between(x_tmp.numpy().reshape(-1), np.percentile(y_samp, 2.5, axis = 0), np.percentile(y_samp, 97.5, axis = 0), alpha = 0.25, label='95% Confidence')
plt.legend()
plt.scatter(x, toy_function(x))
plt.title('Posterior Predictive')
plt.show()

这里是重复计算100次的平均值和100次平均值的97.5%大和2.5%小的区域线图(即置信度95%)。 在这里插入图片描述


参考:

  1. 变分推断;
  2. Weight Uncertainty in Neural Networks Tutorial;
  3. Bayesian Neural Networks;