结合底部的测试用例(batch_size=2, seq_len=10, model_dim=64, num_heads=8),逐步跟踪数据在每个阶段的维度变化和计算逻辑。

零、初始化阶段(__init__

1
self.head_dim = model_dim // num_heads  # 64 // 8 = 8

每个头的维度 dhead = 8。然后创建四个线性投影层:

1
2
3
4
self.w_q = nn.Linear(64, 64)   # Wᵠ ∈ ℝ^(64×64)
self.w_k = nn.Linear(64, 64) # Wᴷ ∈ ℝ^(64×64)
self.w_v = nn.Linear(64, 64) # Wⱽ ∈ ℝ^(64×64)
self.w_o = nn.Linear(64, 64) # Wᴼ ∈ ℝ^(64×64)

注意这里每个投影层的维度是 (model_dim, model_dim) = (64, 64),而不是 (64, 8)。这是因为所有头的投影被合并在一个大矩阵里,后面通过 reshape 来拆分成多个头。这样做的好处是只需要一次矩阵乘法就能完成所有头的投影,效率远高于分别为每个头做投影。

一、线性投影

测试用例中调用方式是 mha(x, x),即 x_query = x_context = x,这是自注意力

1
2
3
q = self.w_q(x_query)   # [2, 10, 64] × [64, 64] → [2, 10, 64]
k = self.w_k(x_query) # [2, 10, 64] → [2, 10, 64]
v = self.w_v(x_query) # [2, 10, 64] → [2, 10, 64]

此时 Q、K、V 的形状都是 [2, 10, 64],包含了所有 8 个头的信息,还没有拆分。

如果是交叉注意力(比如 x_context 来自编码器输出,形状为 [2, 20, 64]),那么 K 和 V 的 seq_len 就是 20 而不是 10,这也完全兼容后续的计算。

二、分头处理(reshape + transpose)

这一步是多头注意力实现的关键技巧:

1
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

拆解来看分两步。第一步 view

1
2
[2, 10, 64] → [2, 10, 8, 8]
batch seq model_dim batch seq heads head_dim

这里把最后一个维度 64 拆成了 8 个头 × 每头 8 维。本质上就是把一个 64 维的"大向量"理解为 8 个 8 维的"小向量"。

第二步 transpose(1, 2)

1
2
[2, 10, 8, 8] → [2, 8, 10, 8]
batch seq heads head_dim batch heads seq head_dim

把 heads 维移到 seq 前面。这样做的目的是:让 (seq, head_dim) 成为最内层的两个维度,方便后续对每个头独立做矩阵运算。transpose 之后,可以把 [2, 8, 10, 8] 理解为“2 个样本,每个样本有 8 个头,每个头看到 10 个位置,每个位置用 8 维表示”。

K 和 V 也做完全相同的变换,最终:

1
2
3
q: [2, 8, 10, 8]
k: [2, 8, 10, 8]
v: [2, 8, 10, 8]

三、缩放点积注意力

3.1 计算注意力分数

1
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

先看 k.transpose(-2, -1),这是对最后两个维度做转置:

1
2
3
k: [2, 8, 10, 8] → k.T: [2, 8, 8, 10]
↑ ↑
head_dim seq

然后 q 和 k.T 做矩阵乘法。batch 和 heads 这两个维度作为“批次维度”不参与乘法,实际运算发生在最后两个维度上:

1
2
3
4
q:   [2, 8, 10, 8]
k.T: [2, 8, 8, 10]
↓ matmul
scores: [2, 8, 10, 10]

scores 中每个元素 scores[b][h][i][j] 表示:第 b 个样本中,第 h 个头里,位置 i 的 query 对位置 j 的 key 的点积相似度。

然后除以 √dhead = √8 ≈ 2.83。这里除的是每个头的维度 √8,而不是整个模型维度 √64。

3.2 应用掩码(可选)

1
2
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

在自回归语言模型(如 GPT、Qwen)中,会传入一个因果掩码(causal mask),形状为 [1, 1, 10, 10] 的下三角矩阵。mask 为 0 的位置(即未来位置)会被填充为 -1e9(一个极大的负数),经过 softmax 后这些位置的注意力权重就趋近于 0,从而阻止模型“看到未来的 token”。

本测试用例中 mask=None,所以跳过这一步。

3.3 Softmax 归一化

1
attn_weights = F.softmax(scores, dim=-1)

对 scores 的最后一个维度(seq_len_k)做 softmax:

1
2
3
scores:       [2, 8, 10, 10]
↑ 对这个维度 softmax
attn_weights: [2, 8, 10, 10] (每行之和 = 1)

attn_weights[b][h][i] 是一个长度为 10 的概率分布,表示位置 i 对所有 10 个位置的注意力权重分配。

然后应用 Dropout:

1
attn_weights = self.dropout(attn_weights)

训练时随机将部分注意力权重置零,起正则化作用。

3.4 加权求和

1
context = torch.matmul(attn_weights, v)

用注意力权重对 V 做加权求和:

1
2
3
4
attn_weights: [2, 8, 10, 10]
v: [2, 8, 10, 8]
↓ matmul
context: [2, 8, 10, 8]

context[b][h][i] 是一个 8 维向量,它是位置 i 根据注意力权重对所有位置的 value 向量做加权平均的结果。如果位置 i 对位置 j 的注意力权重很高,那么位置 j 的 value 向量就会对 context[b][h][i] 贡献更多。

四、合并多头

现在需要把 8 个头的结果拼接回去:

1
2
3
context = context.transpose(1, 2)       # [2, 8, 10, 8] → [2, 10, 8, 8]
context = context.contiguous() # 确保内存连续
output = context.view(batch_size, -1, self.model_dim) # [2, 10, 8, 8] → [2, 10, 64]

transpose(1, 2) 把 heads 和 seq 换回来。contiguous() 是因为 transpose 后张量在内存中可能不连续,而 view 要求内存连续,所以需要先调用 contiguous 重新整理内存布局。最后 view 把 (8 heads, 8 dim) 合并回 64 维。
最后通过输出投影:

1
output = self.w_o(output)  # [2, 10, 64] × [64, 64] → [2, 10, 64]

Wᴼ 让不同头学到的信息互相融合。最终输出形状 [2, 10, 64],和输入完全一致,可以无缝接入残差连接和后续层。

五、完整数据流总结

用测试用例(batch=2, seq=10, d_model=64, heads=8, d_head=8):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
输入 x:          [2, 10, 64]
↓ w_q / w_k / w_v 线性投影
Q, K, V: [2, 10, 64]
↓ view + transpose 分头
Q, K, V: [2, 8, 10, 8] (8个头各自独立)
↓ QKᵀ / √8
scores: [2, 8, 10, 10] (每个头的注意力分数)
↓ mask(可选)+ softmax + dropout
attn_weights: [2, 8, 10, 10] (归一化的注意力权重)
↓ × V 加权求和
context: [2, 8, 10, 8] (每个头的输出)
↓ transpose + view 合并多头
merged: [2, 10, 64] (拼接回完整维度)
↓ w_o 输出投影
output: [2, 10, 64] (最终输出)

整个过程的核心洞察是:通过 view 和 transpose 这两个零计算开销的操作,巧妙地将"多头并行计算"转化为了标准的批量矩阵乘法,让 GPU 能高效并行执行。所有头的计算在同一次 matmul 中完成,完全没有显式的循环。

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
多头注意力(Multi-Head Attention)

Transformer 的核心组件,通过并行运行多个注意力头来捕捉不同子空间的特征。
每个头独立计算注意力,最后将结果拼接并通过线性层融合。
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
"""
多头注意力模块

支持自注意力(Self-Attention)和交叉注意力(Cross-Attention):
- 自注意力:Q = K = V = x_query
- 交叉注意力:Q = x_query, K = V = x_context

Args:
model_dim: 模型隐藏维度
num_heads: 注意力头数
dropout_p: Dropout 概率,默认 0.0
"""

def __init__(self, model_dim, num_heads, dropout_p=0.0):
super().__init__()

assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"

self.model_dim = model_dim
self.num_heads = num_heads
self.head_dim = model_dim // num_heads # 每个头的维度

# Q, K, V 投影层
self.w_q = nn.Linear(model_dim, model_dim)
self.w_k = nn.Linear(model_dim, model_dim)
self.w_v = nn.Linear(model_dim, model_dim)

# 输出投影层
self.w_o = nn.Linear(model_dim, model_dim)

self.dropout = nn.Dropout(dropout_p)

def forward(self, x_query, x_context, mask=None):
"""
前向传播

Args:
x_query: 查询输入 [batch_size, seq_len_q, model_dim]
x_context: 上下文输入(用于生成 K 和 V)[batch_size, seq_len_k, model_dim]
如果为 None,则使用 x_query(自注意力)
mask: 注意力掩码 [batch_size, 1, seq_len_q, seq_len_k] 或 [1, 1, seq_len_q, seq_len_k]

Returns:
output: 注意力输出 [batch_size, seq_len_q, model_dim]
"""
batch_size = x_query.size(0)

# ========== 线性投影 ==========
# 自注意力: q = k = v = x_query
# 交叉注意力: q = x_query, k = v = x_context
q = self.w_q(x_query)

if x_context is not None:
k = self.w_k(x_context)
v = self.w_v(x_context)
else:
k = self.w_k(x_query)
v = self.w_v(x_query)

# ========== 分头处理 ==========
# [batch_size, seq_len, model_dim] -> [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

# ========== 缩放点积注意力 ==========
# scores: [batch_size, num_heads, seq_len_q, seq_len_k]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

# 应用注意力掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

# attn_weights: [batch_size, num_heads, seq_len_q, seq_len_k]
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)

# context: [batch_size, num_heads, seq_len_q, head_dim]
context = torch.matmul(attn_weights, v)

# ========== 合并多头 ==========
# [batch_size, num_heads, seq_len_q, head_dim] -> [batch_size, seq_len_q, num_heads, head_dim] -> [batch_size, seq_len_q, model_dim]
context = context.transpose(1, 2)
context = context.contiguous()
output = context.view(batch_size, -1, self.model_dim)

# 输出投影
output = self.w_o(output)

return output


if __name__ == "__main__":
# batch_size=2, seq_len=10, model_dim=64, num_heads=8
x = torch.randn(2, 10, 64)
mha = MultiHeadAttention(model_dim=64, num_heads=8)
out = mha(x, x) # Self-Attention: x_query=x, x_context=x
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}") # 应该还是 (2, 10, 64)