多头注意力
2021-12-5
| 2023-8-6
0  |  阅读时长 0 分钟
type
status
date
slug
summary
tags
category
icon
password
Property
 
 
在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。
为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 组不同的线性投影(linear projections)来变换查询、键和值。 然后,这组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention)。 对于 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)
notion image

模型

给定查询 、 键 和 值 , 每个注意力头 )的计算方法为:
其中,可学习的参数包括 , 以及代表注意力汇聚的函数可以是加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着 个头连结后的结果,因此其可学习参数是
基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。
 

实现

在实现过程中,选择缩放点积注意力作为每一个注意力头。 为了避免计算代价和参数代价的大幅增长, 设定 。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 , 则可以并行计算 个头。 在下面的实现中, 是通过参数num_hiddens指定的。
为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作
 
使用键和值相同的小例子来测试的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens
 
 

代码

  • PyTorch
  • Bahdanau 注意力自注意力和位置编码
    目录