SVD解决词分布式表示稀疏性

1,352 阅读3分钟

我正在参加「掘金·启航计划」


这篇文章是接着一文拿捏点互信息(PMI)解决词分布式表示稀疏性问题写的。解决分布式表示稀疏性问题另一个方法是使用奇异值分解(Singular Value Decomposition,SVD)

我把例子搬过来了。还是原来的三个句子及其共现矩阵M

  • 我 喜欢 自然 语言 处理。
  • 我 爱 深度 学习。
  • 我 喜欢 机器 学习。
 我  喜欢  自然  语言  处理  爱  深度  学习  机器  我 0211111213 喜欢 2011100112 自然 1101100001 语言 1110100001 处理 1111000001 爱 1000001101 深度 1000010101 学习 2100011011 机器 1100000101 。 3211111210\begin{array}{ccccccccccc} \hline & \text { 我 } & \text { 喜欢 } & \text { 自然 } & \text { 语言 } & \text { 处理 } & \text { 爱 } & \text { 深度 } & \text { 学习 } & \text { 机器 } & \circ \\ \hline \text { 我 } & 0 & 2 & 1 & 1 & 1 & 1 & 1 & 2 & 1 & 3 \\ \text { 喜欢 } & 2 & 0 & 1 & 1 & 1 & 0 & 0 & 1 & 1 & 2 \\ \text { 自然 } & 1 & 1 & 0 & 1 & 1 & 0 & 0 & 0 & 0 & 1 \\ \text { 语言 } & 1 & 1 & 1 & 0 & 1 & 0 & 0 & 0 & 0 & 1 \\ \text { 处理 } & 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 & 1 \\ \text { 爱 } & 1 & 0 & 0 & 0 & 0 & 0 & 1 & 1 & 0 & 1 \\ \text { 深度 } & 1 & 0 & 0 & 0 & 0 & 1 & 0 & 1 & 0 & 1 \\ \text { 学习 } & 2 & 1 & 0 & 0 & 0 & 1 & 1 & 0 & 1 & 1 \\ \text { 机器 } & 1 & 1 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 1 \\ \text { 。 } & 3 & 2 & 1 & 1 & 1 & 1 & 1 & 2 & 1 & 0 \\ \hline \end{array}

SVD奇异值分解

从矩阵角度来看

image.png

公式为:

M=UΣVTM = U\Sigma V^T

怎么分解的涉及到数学知识,我等就不必深究了。总之简单来讲就是将MM矩阵分解成一个UU,一个Σ\Sigma,一个VTV^T三个矩阵相乘。

  • VVUU都是正交矩阵

    • 正交矩阵: 如果A×AT=E(单位矩阵)A\times A^T = E(单位矩阵),那A就是正交矩阵。
    • 正交矩阵都是方阵。
  • Σ\Sigma是一个半正定m×nm×n阶对角矩阵,其对角线上的值就是MM矩阵分解的奇异值。

  • MM矩阵的形状是m×nm×n,那它的特征值最多为min(m,n)\min(m,n)个。

也就是说奇异值分解最终得到的奇异值只有这个小方阵里的对角线元素。

image.png

若在 Σ{\Sigma} 中仅保留 dd(d<min(m,n))(d<\min(m,n)) 最大的奇异值(UUV{V} 也只保留相应的维度),则被保留的奇异值组成的对角矩阵被称为截断奇异值分解 (Truncated Singular Value Decomposition,TSVD)

从向量角度看

M=σ1u1v1T+σ2u2v2T++σrurvrT其中r=min(m,n)M=\sigma_1 u_1 v_1^{\mathrm{T}}+\sigma_2 u_2 v_2^{\mathrm{T}}+\ldots+\sigma_r u_r v_r^{\mathrm{T}} \quad 其中r = \min(m,n)

其中等式右边每一项前的系数 σ\sigma 就是奇异值, uuvv 分别表示列向量,每一项 uvTu v^{T} 都是秩为 1 的矩阵。奇异值满足 σ1σ2σr>0\sigma_1 \geq \sigma_2 \geq \ldots \geq \sigma_r>0

这样就可以和前边的截断奇异值对上了。前边我们提到,我们可以选择保留多少奇异值。一个矩阵MM分解后最多有min(m,n)\min(m,n)个奇异值。

看一下下图,是借用知乎上的图,从左到右依次是原图、奇异值选1、5、50时候的样子。

当截断奇异值矩阵选择r=1r = 1时,

M=σ1u1v1TM' = \sigma_1 u_1 v_1^{\mathrm{T}}

当截断奇异值矩阵选择r=5r = 5时,

M=σ1u1v1T+σ2u2v2T+σ3u3v3T+σ4u4v4T+σ5u5v5TM' = \sigma_1 u_1 v_1^{\mathrm{T}} + \sigma_2 u_2 v_2^{\mathrm{T}} + \sigma_3 u_3 v_3^{\mathrm{T}} +\sigma_4 u_4 v_4^{\mathrm{T}} +\sigma_5 u_5 v_5^{\mathrm{T}}

当截断奇异值矩阵选择r=50r = 50时,

M=σ1u1v1T+σ2u2v2T+...+σ50u50v50TM' = \sigma_1 u_1 v_1^{\mathrm{T}}+ \sigma_2 u_2 v_2^{\mathrm{T}} + ... + \sigma_{50} u_{50} v_{50}^{\mathrm{T}}

随着项数逐渐增大,MM'逐渐还原MM,就像泰勒展开式一样,项数越多越接近原图。

image.png

截断奇异值分解实际上是对矩阵 MM 的低秩近似。通过截断奇异值分解所得到的矩阵UU中的每一行,则为相应词的dd维向量表示, 该向量一般认为其具有连续、低维和稠密的性质。由于UU的各列相互正交,因此可以认为词表示的每一维表达了该词的一种独立的“潜在语义”,所以这种方法也被称作潜在语义分析(Latent Semantic Analysis,LSA)。另外,ΣVTΣV^T的每一列也可以作为相应上下文的向量表示。

注意UUΣVT\Sigma V^T是不相等的,相当于两套表示,我们在这选择UU作为MM的稠密表示。

代码

不管是NumPy还是PyTorch 中都自带了SVD分解。

直接使用.linalg.svd()方法即可。

import torch

M = torch.Tensor([[0, 2, 1, 1, 1, 1, 1, 2, 1, 3],
                  [2, 0, 1, 1, 1, 0, 0, 1, 1, 2],
                  [1, 1, 0, 1, 1, 0, 0, 0, 0, 1],
                  [1, 1, 1, 0, 1, 0, 0, 0, 0, 1],
                  [1, 1, 1, 1, 0, 0, 0, 0, 0, 1],
                  [1, 0, 0, 0, 0, 0, 1, 1, 0, 1],
                  [1, 0, 0, 0, 0, 1, 0, 1, 0, 1],
                  [2, 1, 0, 0, 0, 1, 1, 0, 1, 1],
                  [1, 1, 0, 0, 0, 0, 0, 1, 0, 1],
                  [3, 2, 1, 1, 1, 1, 1, 2, 1, 0]])

u, s, v = torch.linalg.svd(M)

print((u @ torch.diag(s) @ v).int()) # 乘起来

torch.set_printoptions(precision=3, sci_mode=False)
print(u)  # M 的 稠密表示

结果:

乘起来可以看到SVD之后的结果还能再拼回去,不是在骗你。 image.png

tensor([[ -0.500, 0.724, 0.351, 0.253, -0.025, 0.193,  0.000, -0.000, -0.000, 0.017],  

\quad\quad\quad[ -0.384, 0.052, -0.463, -0.519, 0.394, 0.363,  0.000, -0.000, -0.000, 0.282],  

\quad\quad\quad[ -0.218, 0.036, -0.398, 0.156, -0.168, -0.182,  0.072, -0.138, -0.802, -0.200],  

\quad\quad\quad[ -0.218, 0.036, -0.398, 0.156, -0.168, -0.182,  -0.501, 0.586, 0.270, -0.200],  

\quad\quad\quad[ -0.218, 0.036, -0.398, 0.156, -0.168, -0.182,  0.429, -0.448, 0.531, -0.200],  

\quad\quad\quad[ -0.183, -0.010, 0.228, -0.419, 0.070, -0.258,  0.529, 0.468, -0.033, -0.409],  

\quad\quad\quad[ -0.183, -0.010, 0.228, -0.419, 0.070, -0.258,  -0.529, -0.468, 0.033, -0.409],  

\quad\quad\quad[ -0.293, -0.300, 0.152, -0.252, -0.789, 0.302,  0.000, -0.000, 0.000, 0.155],  

\quad\quad\quad[ -0.208, 0.015, 0.087, -0.079, -0.027, -0.708,  0.000, 0.000, 0.000, 0.664],  

\quad\quad\quad[ -0.515, -0.615, 0.226, 0.416, 0.356, 0.070,  -0.000, 0.000, -0.000, -0.038]])