为什么是多头自注意力机制

多头自注意力是 Transformer 模型的核心创新技术。相比于循环神经网络(Recurrent Neural Network, RNN)和卷积神经网络(Convolutional Neural Network, CNN)等传统神经网络,多头自注意力机制能够直接建模任意距离的词元之间的交互关系。
- 循环神经网络迭代地利用前一个时刻的状态更新当前时刻的状态,因此在处理较长序列的时候,常常会出现梯度爆炸或者梯度消失的问题。
- 卷积神经网络:,只有位于同一个卷积核的窗口中的词元可以直接进行交互,通过堆叠层数来实现远距离词元间信息的交换。
多头自注意力机制通常由多个自注意力模块组成。在每个自注意力模块中,对于输入的词元序列,将其映射为相应的查询(Query, 𝑸)、键(Key, 𝑲)和值(Value,𝑽)三个矩阵。然后,对于每个查询,将和所有没有被掩盖的键之间计算点积。这些点积值进一步除以 √𝐷 进行缩放(𝐷 是键对应的向量维度),被传入到 softmax函数中用于权重的计算。进一步,这些权重将作用于与键相关联的值,通过加权和的形式计算得到最终的输出。在数学上,上述过程可以表示为:
𝑸 = 𝑿𝑾𝑄,𝑲 = 𝑿𝑾𝐾,𝑽 = 𝑿𝑾𝑉
Attention(𝑸, 𝑲,𝑽) = softmax(𝑸𝑲⊺√𝐷)𝑽.
与单头注意力相比,多头注意力机制的主要区别在于它使用了 𝐻 组结构相同但映射参数不同的自注意力模块。输入序列首先通过不同的权重矩阵被映射为一组查询、键和值。每组查询、键和值的映射构成一个“头”,并独立地计算自注意力的输出。最后,不同头的输出被拼接在一起,并通过一个权重矩阵 𝑾𝑂 ∈ R 𝐻×𝐻进行映射,产生最终的输出。如下面的公式所示:
MHA = Concat(head1, . . . , headN)𝑾𝑂,
head𝑛 = Attention(𝑿𝑾𝑄𝑛 , 𝑿𝑾𝑛𝐾, 𝑿𝑾𝑉𝑛).
由上述内容可见,自注意力机制能够直接建模序列中任意两个位置之间的关系,进而有效捕获长程依赖关系,具有更强的序列建模能力。另一个主要的优势是,自注意力的计算过程对于基于硬件的并行优化(如 GPU、TPU 等)非常友好,因此能够支持大规模参数的高效优化。