注意力机制中的Nadaraya-Watson核回归

343 阅读8分钟

1. 引言

在机器学习中,注意力机制是一种强大的工具,能够帮助模型在处理数据时更加关注重要的部分。本文将介绍一种经典的注意力机制——Nadaraya-Watson核回归。通过这篇文章,你将了解到什么是注意力机制,以及如何通过Nadaraya-Watson核回归来实现它。

2. 什么是注意力机制?

注意力机制的核心思想是:在处理数据时,模型应该更加关注那些对当前任务更重要的部分。举个例子,假设你在阅读一篇文章时,你会更加关注那些与文章主题相关的句子,而忽略那些不相关的部分。注意力机制就是让机器学习模型也能做到这一点。

在数学上,注意力机制可以表示为对输入数据的加权平均。具体来说,给定一组输入数据,模型会根据每个数据的重要性分配一个权重,然后根据这些权重对数据进行加权平均,得到最终的输出。

3. Nadaraya-Watson核回归

3.1 生成数据集

为了更好地理解Nadaraya-Watson核回归,我们首先需要生成一个简单的数据集。假设我们有一个非线性函数:

y=2sin(x)+x0.8+ϵy = 2 \sin(x) + x^{0.8} + \epsilon

我们在这个函数的基础上加入一些噪声,生成一组训练数据和测试数据。

# -*- coding: utf-8 -*-
import torch
import d2l

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)  # 排序后的训练样本


def f(x):
    return 2 * torch.sin(x) + x ** 0.8


y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出

x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数 (标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):
    """
    绘制 Nadaraya-Watson 核回归的预测结果与真实值的对比图。

    参数:
    y_hat (Tensor 或 ndarray): 预测值,与测试输入 x_test 对应。

    说明:
    - 该函数首先绘制测试数据点 (x_test) 的真实值 (y_truth) 与模型预测值 (y_hat)。
    - 然后,在图中以散点的形式标注训练数据点 (x_train, y_train)。
    - `d2l.plot` 用于绘制曲线,`d2l.plt.plot` 用于绘制散点。
    """
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y',
             legend=['Truth', 'Pred'], xlim=[0, 5], ylim=[-1, 5])
    # 'o' 是 matplotlib.pyplot.plot() 的 marker 参数,用于指定数据点的标记样式。
    # 具体来说,'o' 表示用圆形(circle)标记数据点。
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)

3.2 平均汇聚

在介绍Nadaraya-Watson核回归之前,我们先来看一个简单的估计器——平均汇聚。平均汇聚的思想很简单:对所有训练样本的输出值取平均,作为预测值。

y^=1ni=1nyi\hat{y} = \frac{1}{n} \sum_{i=1}^n y_i
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
  • torch.repeat_interleave(input, repeats, dim=None) 是 PyTorch 中的一个函数,用于重复张量中的元素。它的作用是将张量的每个元素重复指定的次数。
    • input:输入的张量。
    • repeats:每个元素需要重复的次数。可以是一个整数(所有元素重复相同次数)或一个张量(每个元素重复不同的次数)。
    • dim:沿着哪个维度进行重复。如果为 None,则会将输入张量展平后重复。

10_2_nwmean.png

从图中可以看出,平均汇聚的预测结果与真实函数相差较大,这说明平均汇聚忽略了输入数据的分布。

3.3 非参数注意力汇聚

为了改进平均汇聚的不足,Nadaraya和Watson提出了一种基于核函数的回归方法,称为Nadaraya-Watson核回归。

3.3.1 什么是非参数注意力汇聚?

非参数注意力汇聚是一种基于加权平均的预测方法。它的核心思想是:对于一个新的输入(查询),我们根据它与训练数据(键)的相似性,给每个训练数据的输出(值)分配一个权重,然后用这些权重对输出进行加权平均,得到最终的预测结果。

简单来说,就是越相似的输入,对应的输出对预测结果的贡献越大

3.3.2 非参数注意力汇聚的公式
y^(x)=i=1nK(xxi)yij=1nK(xxj)\hat{y}(x) = \frac{\sum_{i=1}^n K(x - x_i) y_i}{\sum_{j=1}^n K(x - x_j)}

其中:

  • xx 是新的输入(查询)。
  • xix_i 是训练数据中的输入(键)。
  • yiy_i 是训练数据中的输出(值)。
  • K(xxi)K(x - x_i) 是核函数,用来衡量 xxxix_i 的相似性。
3.3.3 公式的通俗解释
(1)核函数 K(xxi)K(x−x_i)

核函数的作用是计算输入 xx 和训练数据 xix_i 的相似性。常用的核函数是高斯核:

K(u)=12πexp(u22)K(u) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{u^2}{2}\right)
  • xxxix_i 越接近时,K(xxi)K(x − x_i) 的值越大(相似性越高)。
  • xxxix_i 越远时,K(xxi)K(x − x_i) 的值越小(相似性越低)。

将高斯核代入Nadaraya-Watson核回归公式,可以得到:

y^(x)=i=1nexp(12(xxi)2)yij=1nexp(12(xxj)2)=i=1nexp(12(xxi)2)j=1nexp(12(xxj)2)yi=i=1nsoftmax(12(xxi)2)yi\begin{align} \hat{y}(x) &= \frac{\sum_{i=1}^n \exp\left(-\frac{1}{2} (x - x_i)^2\right) y_i}{\sum_{j=1}^n \exp\left(-\frac{1}{2} (x - x_j)^2\right)} \\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2} (x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2} (x - x_j)^2\right)} y_i \\ &= \sum_{i=1}^n \text{softmax}(-\frac{1}{2} (x - x_i)^2) y_i \end{align}
(2)加权平均
  • 分子部分 i=1nK(xxi)yi\sum_{i=1}^n K(x - x_i) y_i:对每个训练数据的输出 yiy_i 进行加权,权重是 K(xxi)K(x − x_i)
  • 分母部分 j=1nK(xxj)\sum_{j=1}^n K(x - x_j):对所有权重进行归一化,确保权重之和为 1。

最终的结果 y^(x)\hat{y}(x) 就是所有训练数据输出的加权平均,权重由输入 xx 和训练数据 xix_i 的相似性决定。

3.3.4 举个例子

假设我们有以下训练数据:

  • 输入 xix_i[1,2,3,4][1,2,3,4]
  • 输出 yiy_i[2,4,6,8][2,4,6,8]

现在有一个新的输入 x=2.5x=2.5,我们想预测它的输出 y^(x)\hat{y}(x)

步骤 1:计算相似性(核函数)

使用高斯核计算 xx 和每个 xix_i 的相似性:

  • K(2.51)=exp(12(2.51)2)0.3247K(2.5−1)=\exp\left(-\frac{1}{2} (2.5 - 1)^2\right) \approx 0.3247
  • K(2.52)=exp(12(2.52)2)0.8825K(2.5−2)=\exp\left(-\frac{1}{2} (2.5 - 2)^2\right) \approx 0.8825
  • K(2.53)=exp(12(2.53)2)0.8825K(2.5−3)=\exp\left(-\frac{1}{2} (2.5 - 3)^2\right) \approx 0.8825
  • K(2.54)=exp(12(2.54)2)0.3247K(2.5−4)=\exp\left(-\frac{1}{2} (2.5 - 4)^2\right) \approx 0.3247
步骤 2:计算加权平均
  • 分子:0.3247×2+0.8825×4+0.8825×6+0.3247×812.07200.3247×2+0.8825×4+0.8825×6+0.3247×8≈12.0720
  • 分母:0.3247+0.8825+0.8825+0.32472.41440.3247+0.8825+0.8825+0.3247≈2.4144
  • 预测值:y^(2.5)=12.07202.41445.00\hat{y}(2.5)=\frac{12.0720}{2.4144} \approx 5.00
3.3.5 为什么叫“非参数”?

“非参数”意味着模型没有可学习的参数。核函数是固定的(比如高斯核),模型直接根据输入和训练数据的相似性进行计算,而不需要通过训练来调整参数。

from torch import nn

# X_repeat的形状:(n_test,n_train), 每一行都包含着相同的测试输入
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
print(X_repeat.shape)  # torch.Size([50, 50])
# attention_weights的形状:(n_test,n_train), 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train) ** 2 / 2, dim=1)
print(attention_weights.shape)  # torch.Size([50, 50])
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_truth)
plot_kernel_reg(y_hat)

10_2_非参数注意力汇聚.png

从图中可以看出,Nadaraya-Watson核回归的预测结果比平均汇聚更加接近真实函数。

3.4 带参数注意力汇聚

在非参数注意力汇聚中,核函数是固定的(比如高斯核),模型没有可学习的参数。为了增强模型的表达能力,我们可以引入可学习的参数,让模型能够自动调整核函数的形状。这就是带参数注意力汇聚的核心思想。

3.4.1 公式

带参数注意力汇聚的公式如下:

y^(x)=i=1nexp(12((xxi)w)2)j=1nexp(12((xxj)w)2)yi=i=1nsoftmax(12((xxi)w)2)yi\begin{align} \hat{y}(x) &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2} ((x - x_i) w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2} ((x - x_j) w)^2\right)} y_i \\ &= \sum_{i=1}^n \text{softmax}(-\frac{1}{2} ((x - x_i) w)^2) y_i \end{align}

其中:

  • xx 是新的输入(查询)。
  • xix_i 是训练数据中的输入(键)。
  • yiy_i 是训练数据中的输出(值)。
  • ww 是一个可学习的参数,用于调整核函数的形状。
3.4.2 代码实现

以下是带参数注意力汇聚的代码实现,并附上详细注释:

# 定义带参数的注意力汇聚模型
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # 定义一个可学习的参数 w,初始值为随机数
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        """
        前向传播函数
        - queries: 查询(新输入)
        - keys: 键(训练数据的输入)
        - values: 值(训练数据的输出)
        """
        # 将 queries 重复,使其形状与 keys 匹配
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        # 计算注意力权重,使用带参数的高斯核
        attention_weights = nn.functional.softmax(-1 / 2 * ((queries - keys) * self.w) ** 2, dim=1)
        # 对 values 进行加权平均,得到预测值
        return torch.bmm(attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)

代码注释:

  1. self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    • 定义一个可学习的参数 ww,初始值为随机数。
    • requires_grad=True 表示这个参数会在训练过程中被优化。
  2. queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))

    • 将查询(queries)重复,使其形状与键(keys)匹配。
    • 例如,如果 queries 的形状是 (n_test,)keys 的形状是 (n_test, n_train),那么 queries 会被重复 n_train 次。
  3. nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)

    • 计算注意力权重,使用带参数的高斯核。
    • (queries - keys) 计算查询和键之间的距离。
    • self.w 是可学习的参数,用于调整核函数的形状。
    • softmax 函数将权重归一化,使其和为 1。
  4. torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)

    • 对值(values)进行加权平均,得到预测值。
    • torch.bmm 是批量矩阵乘法,用于高效计算加权平均。
    • unsqueeze(dim) 的主要作用是为张量增加一个维度,通常用于调整张量的形状,以满足某些操作的要求。
      • 在矩阵乘法中,需要对齐维度。
      • 在卷积操作中,需要增加通道维度。
      • 在注意力机制中,需要调整权重的形状。
3.4.3 训练模型

生成训练数据的代码:

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:(n_train,n_train - 1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:(n_train,n_train - 1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

代码详解:

  • X_tile = x_train.repeat((n_train, 1))

    • x_train.repeat((n_train, 1))x_train 沿着第一维(行)复制 n_train 次,使 X_tile 具有形状 n_train, n_train)
    • 这样 X_tile 的每一行都包含 完整的训练输入 x_train
  • Y_tile = y_train.repeat((n_train, 1))

    • y_train.repeat((n_train, 1)) 作用同上,将 y_train 复制 n_train 次,使 Y_tile 形状为 (n_train, n_train)
    • 这样 Y_tile 的每一行都包含 完整的训练输出 y_train
  • keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

    • torch.eye(n_train) 生成 单位矩阵,形状为 (n_train, n_train),对角线全是 1,其余元素是 0。
    • 1 - torch.eye(n_train) 生成 反对角矩阵,对角线为 0,其余元素是 1。
    • .type(torch.bool) 将其转换为 布尔索引,对角线上的 0 变为 False,其余 1 变为 True
    • X_tile[...] 使用布尔索引删除每行的对角线元素,保留其他元素。
    • .reshape((n_train, -1)) 使 keys 变成 (n_train, n_train - 1),即:
      • 每行都是 x_train,但 去除了当前行对应的 x_train 值。
      • keys[i] 表示 去掉 x_train[i] 后,剩下的 x_train 值。
  • values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

    • 作用和 keys 类似,只是这次操作的是 Y_tile
    • values[i] 表示 去掉 y_train[i] 后,剩下的 y_train 值。
    • 形状也是 (n_train, n_train - 1),即:
      • 每行存储 除去当前样本 y_train[i] 之外的所有 y_train 值。
  • 整体思路:

    • 这段代码的目标是 构造不包含自身的键值对,用于 带参数注意力汇聚 计算:
      1. 构造重复x_trainy_train,用于便捷索引 (X_tileY_tile)。
      2. 去除自身样本,即 keys[i]values[i] 不包含 x_train[i]y_train[i]
        • 这样,每个 x_train[i] 只与 其他 n_train - 1 个样本 计算注意力权重,而不会使用自身的值。
      3. 用于注意力计算keysvalues 作为输入,与 x_train 计算注意力权重,完成带参数注意力汇聚。

这段代码用于 构造不包含自身的数据集,在注意力机制中,这样的 keysvalues 能确保模型在计算注意力时不会直接依赖自身的信息,避免信息泄漏

训练带参数注意力汇聚模型的代码:

# 初始化模型
net = NWKernelRegression()
# 定义损失函数(均方误差)
loss = nn.MSELoss(reduction='none')
# 定义优化器(随机梯度下降)
trainer = torch.optim.SGD(net.parameters(), lr=0.05)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
# 训练模型
for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))
animator.show()
epoch 1, loss 59.251663
epoch 2, loss 18.949461
epoch 3, loss 18.708895
epoch 4, loss 18.533611
epoch 5, loss 18.400131

10_2_带参数注意力汇聚loss.png

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
d2l.plt.show()
  • x_train.repeat((n_test, 1)) 使 x_train 沿着第一个维度(行)复制 n_test 次,得到形状为 (n_test, n_train) 的矩阵。

10_2_带参数注意力汇聚loss.png

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
d2l.plt.show()

10_2_带参数注意力汇聚_热图.png

4. 小结

Nadaraya-Watson核回归是一种经典的注意力机制,它通过对输入数据进行加权平均来实现预测。通过引入核函数和可学习的参数,我们可以进一步增强模型的表达能力。希望这篇文章能帮助你更好地理解注意力机制及其在机器学习中的应用。