Transformer中的多头自注意力机制

在自然语言处理和其他序列数据建模任务中,Transformer模型凭借其优异的性能成为了研究的热点。其中,多头自注意力(Multi-Head Self-Attention)是Transformer的核心组件之一,接下来将深入了解它的运作原理和实现细节。

什么是自注意力

自注意力机制允许模型在编码序列时,对于每个位置的表示,都可以获取其他位置的信息,并基于这些信息进行更新。不同于RNN/LSTM等循环神经网络捕获序列依赖的方式,自注意力通过注意力分数直接建模不同位置间的依赖关系。这种长程依赖的直接建模方式大大提高了模型的表现力。

多头注意力的计算过程

多头自注意力将注意力机制进一步推广,使用多个独立的"注意力头"来从不同的表示子空间捕获序列信息。具体计算过程如下:

  1. 线性投影

    输入序列\(X\)的维度为\((N, T, d)\),其中\(N\)为批次大小, \(T\)为序列长度, \(d\)为特征维度。我们首先通过一个线性变换\(W_{QKV}\)\(X\)投影到查询(Query)、键(Key)和值(Value)空间,得到\((N, T, 3d)\)的表示,然后将其沿最后一个维度拆分成\(Q\)\(K\)\(V\),每个形状为\((N, T, d)\)

  2. 分头(Multi-Head)

    \(Q\)\(K\)\(V\)在最后一个维度上分割成\(h\)个head,每个head的维度为\(d/h\)。这样\(Q\)\(K\)\(V\)的形状变为\((N, T, h, d/h)\)。接着,我们交换\(Q\)\(K\)\(V\)的第1和第2个维度,形状变为\((N, h, T, d/h)\),方便后续计算。

  3. 缩放点积注意力

    对于每个头\(i\),我们计算\(Q_i\)\(K_i\)的点积,除以\(\sqrt{d/h}\)进行缩放,加上掩码矩阵\(M\)处理填充项,然后对最后两个维度应用Softmax函数获得注意力分数矩阵:

    \[\text{Attention}_i = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d/h}} + M\right)\]

    分数矩阵的形状为\((N, h, T, T)\)

  4. 加权求和

    将注意力分数与\(V_i\)相乘,在最后两个维度上执行矩阵乘法,获得每个头的加权表示:

    \[\text{Head}_i = \text{Attention}_i V_i\]

    \((N, h, T, d/h)\)的张量。

  5. 合并头(Concat Heads)

    \(h\)个头的加权表示在最后一个维度上拼接,形状变为\((N, h, T, d)\)

  6. 线性变换

    通过一个额外的线性层\(W_o\)将合并后的表示映射回\(d\)维空间,得到最终的多头自注意力输出\((N, T, d)\)

上述过程中牵涉到的几个关键矩阵运算,可以用形状来解释其合理性:

  • \(Q_i\)\(K_i^T\)相乘: \((N, h, T, d/h) \times (N, h, d/h, T) \rightarrow (N, h, T, T)\)
  • \(\text{Attention}_i\)\(V_i\)相乘: \((N, h, T, T) \times (N, h, T, d/h) \rightarrow (N, h, T, d/h)\)
  • 多头合并: \((N, h, T, d/h) \rightarrow (N, T, h, d/h) \rightarrow (N, T, d)\)

通过分解成多个并行子空间,多头注意力能够从不同的表示子空间获取序列信息,并将这些信息融合,从而提高了模型的表现力。另一方面,通过矩阵运算的高效实现,多头自注意力也可以实现对序列的并行编码,这种并行性是循环神经网络所无法企及的。

实现细节

在上一部分,我们从理论层面介绍了多头自注意力的计算过程。现在让我们看一下使用NumPy实现多头自注意力的代码,并与PyTorch官方实现进行对比验证。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np
import torch
import torch.nn as nn

def softmax(Z):
Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
return Z / Z.sum(axis=-1, keepdims=True)

def MSA(X, mask, heads, W_QKV, W_out):
N, T, d = X.shape
K, Q, V = np.split(X@W_QKV, 3, axis=-1)
K, Q, V = [a.reshape(N, T, heads, d//heads).swapaxes(1,2) for a in (K, Q, V)]

attn = softmax(K@Q.swapaxes(-1,-2)/np.sqrt(d//heads)+mask)
return (attn@V).swapaxes(1,2).reshape(N, T, d)@W_out, attn

这段代码实现了前文描述的多头自注意力计算过程。值得注意的几点是:

  1. 线性投影得到Q、K、V后,使用列表解析重新reshape并交换维度,这样处理比较简洁高效。
  2. 计算注意力分数时,先进行Q与K的点积运算,再除以\(\sqrt{d/h}\)进行缩放,最后加上掩码矩阵进行软化约束。
  3. 加权求和时,先在注意力分数与V相乘,再与\(W_o\)相乘,保证了形状的匹配。
  4. 计算 Z 沿着最后一个维度(axis=-1)的最大值,并保留这个维度(keepdims=True),得到一个新的具有相同形状但最后一个维度上所有值为原数组对应最大值的新数组。然后将原数组 Z 中的相应元素减去这个最大值数组,这样可以确保所有元素的值都不大于1。

为了验证这一实现的正确性,我们可以与PyTorch官方实现进行对比:

1
2
3
4
5
6
7
8
9
10
11
12
N = 10 
T = 100
d = 64
heads = 4
X = torch.randn(N,T,d)
M = torch.triu(-float("inf")*torch.ones(T,T),1)
attn = nn.MultiheadAttention(d, heads, bias=False, batch_first=True)
Y_, A_ = attn(X,X,X, attn_mask=M)
Y, A = MSA(X.numpy(), M.numpy(), heads, attn.in_proj_weight.detach().numpy().T, attn.out_proj.weight.detach().numpy().T)

print(np.linalg.norm(Y - Y_.detach().numpy())) # 输出接近0,说明两者结果很接近
print(np.linalg.norm(A.mean(1) - A_.detach().numpy())) # 输出也接近0

这段代码构造了一个简单的多头自注意力计算示例,并与PyTorch内置的nn.MultiheadAttention模块进行结果对比。可以看到,两者的输出之间的差异接近于0,说明我们的NumPy实现是正确的。

通过上述代码实现和结果验证,相信你现在对多头自注意力机制有了更加具体和透彻的理解。这种自注意力机制是Transformer和其它顺序模型的核心,掌握它有助于我们进一步探索和优化这些模型。


Transformer中的多头自注意力机制
http://jingmengzhiyue.top/2024/03/28/MSA/
作者
Jingmengzhiyue
发布于
2024年3月28日
许可协议