In Sections 3.2.4 and 3.2.5 we considered a Gaussian p(x) in which we parti
tioned the vector x into two subvectors x = (xa, xb) and then found expressions
for the conditional distribution p(xa|xb) and the marginal distribution p(xa). We
noted that the mean of the conditional distribution p(xa|xb) was a linear function of
xb. Here we will suppose that we are given a Gaussian marginal distribution p(x)
and a Gaussian conditional distribution p(y|x) in which p(y|x) has a mean that is a
linear function of x and a covariance that is independent of x. This is an example
of a linear-Gaussian model (Roweis and Ghahramani, 1999). We wish to find the
marginal distribution p(y) and the conditional distribution p(x|y). This is a struc
ture that arises in several types of generative model and it will prove convenient to
derive the general results here.
在3.2.4和3.2.5节中,我们考虑了一个高斯分布 ,其中我们将向量 划分为两个子向量 ,然后找到了条件分布 和边际分布 的表达式。我们注意到条件分布 的均值是 的线性函数。在这里,我们假设给定一个高斯边际分布 和一个高斯条件分布 ,其中 的均值是 的线性函数,协方差与 无关。这是一个线性高斯模型的例子(Roweis 和 Ghahramani, 1999)。我们希望找到边际分布 和条件分布 。
我们将边际分布和条件分布表示为:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
# 设置均值和协方差矩阵
mean = [0.5, 0.5]
cov = [[0.01, 0.007], [0.007, 0.01]]
# 创建网格
x, y = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))
pos = np.dstack((x, y))
# 创建多元高斯分布
rv = multivariate_normal(mean, cov)
# 绘制高斯分布轮廓
plt.figure(figsize=(10, 5))
# 子图 (a): 高斯分布轮廓
plt.subplot(1, 2, 1)
plt.contourf(x, y, rv.pdf(pos), levels=10, cmap='viridis')
plt.colorbar()
plt.title('(a) Contours of a Gaussian distribution p(xa, xb)')
plt.xlabel('xa')
plt.ylabel('xb')
# 子图 (b): 边际分布和条件分布
plt.subplot(1, 2, 2)
# 边际分布 p(xa)
xa = np.linspace(0, 1, 100)
marginal_p_xa = multivariate_normal(mean[0], cov[0][0]).pdf(xa)
plt.plot(xa, marginal_p_xa, 'b-', label='p(xa)')
# 条件分布 p(xa | xb = 0.7)
xb = 0.7
conditional_mean = mean[0] + cov[0][1] / cov[1][1] * (xb - mean[1])
conditional_cov = cov[0][0] - cov[0][1] ** 2 / cov[1][1]
conditional_p_xa_given_xb = multivariate_normal(conditional_mean, conditional_cov).pdf(xa)
plt.plot(xa, conditional_p_xa_given_xb, 'r-', label='p(xa | xb = 0.7)')
plt.title('(b) Marginal and Conditional Distributions')
plt.xlabel('xa')
plt.ylabel('Density')
plt.legend()
plt.tight_layout()
plt.show()
书上的图
上述代码结果 colab 运行