注意力汇聚:Nadaraya-Watson 核回归
2021-12-5
| 2023-8-6
0  |  阅读时长 0 分钟
type
status
date
slug
summary
tags
category
icon
password
Property
 
 
1964年提出的Nadaraya-Watson核回归模型来演示具有注意力机制的机器学习

生成数据集

考虑这个回归问题: 给定的成对的“输入-输出”数据集 , 如何学习来预测任意新输入的输出
根据下面的非线性函数生成一个人工数据集, 其中加入的噪声项为
服从均值为和标准差为的正态分布。
生成 个训练样本和 个测试样本,为了更好地可视化之后的注意力模式,将训练样本进行排序。
下面函数绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)
 

平均汇聚

先使用最简单的估计器来解决回归问题: 基于平均汇聚来计算所有训练样本输出值的平均值:
如下图所示,这个估计器确实不够聪明: 真实函数 (“Truth”)和预测函数(“Pred”)相差很大
notion image
 

非参数注意力汇聚

显然,平均汇聚忽略了输入 。 于是Nadaraya和Watson提出了一个更好的想法, 根据输入的位置对输出进行加权:
是核,公式描述的估计器被称为 Nadaraya-Watson核回归。 这里不深入讨论核函数的细节, 但受此启发, 可以从注意力机制框架的角度重写此公式, 成为一个更加通用的注意力汇聚(attention pooling)公式:
其中 是查询, 是键值对。 比较这两个公式, 注意力汇聚是的加权平均。 将查询 和键 之间的关系建模为注意力权重(attention weight) , 这个权重将被分配给每一个对应值。 对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布: 它们是非负的,并且总和为1。
 
考虑一个高斯核(Gaussian kernel),其定义为:
将高斯核代入可以得到:
如果一个键越是接近给定的查询,那么分配给这个键对应值的注意力权重就会越大,也就“获得了更多的注意力”。
Nadaraya-Watson核回归是一个非参数模型,上式是非参数的注意力汇聚模型。
将基于这个非参数的注意力汇聚模型来绘制预测结果。 你会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。
notion image
观察注意力的权重, 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。
notion image
 
 

带参数注意力汇聚

非参数的Nadaraya-Watson核回归具有一致性(consistency)的优点: 如果有足够的数据,此模型会收敛到最优结果。 尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。
下面的查询 和键 之间的距离乘以可学习参数

批量矩阵乘法

为了更有效地计算小批量数据的注意力, 我们可以利用深度学习开发框架中提供的批量矩阵乘法。
假设第一个小批量数据包含 个矩阵 , 形状为 , 第二个小批量包含 个矩阵 , 形状为 。 它们的批量矩阵乘法得到 个矩阵 , 形状为 。 因此,假定两个张量的形状分别是 , 它们的批量矩阵乘法输出的形状为
注意力机制的背景中,可以使用小批量矩阵乘法来计算小批量数据中的加权平均值

定义模型

基于带参数的注意力汇聚,使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:
 

训练

将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。
使用平方损失函数和随机梯度下降
notion image
训练完带参数的注意力汇聚模型后,发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑
notion image
为什么新的模型更不平滑了呢? 看一下输出结果的绘制图: 与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑
notion image
 

代码

  • PyTorch
  • 注意力提示注意力评分函数
    目录