7.3.5 高效注意力机制
- 理解普通注意力为什么会在长上下文下变贵
- 区分不同高效路线分别在优化什么瓶颈
- 通过一个可运行示例感受全局注意力和局部注意力的差别
- 建立训练期和推理期效率问题的第一层判断
一、普通注意力到底贵在哪里?
Section titled “一、普通注意力到底贵在哪里?”每个 token 都要和很多 token 比较
Section titled “每个 token 都要和很多 token 比较”假设序列长度是 n。
普通自注意力里,每个位置都要和其他位置做相似度计算。
于是比较次数大约是:
n * n
也就是:
O(n^2)
当 n = 512 时还不算夸张,
但当 n = 32768 时,情况就完全不同了。
长度翻倍,开销不是翻倍
Section titled “长度翻倍,开销不是翻倍”这正是很多新人最容易低估的地方。
序列长度如果从:
- 4k -> 8k
不是开销简单乘以 2, 而是很多部分接近乘以 4。
所以长上下文模型真正难的地方,不是“支持更多 token”这句话, 而是:
怎样在代价不爆炸的前提下支持更多 token。
训练和推理的痛点还不完全一样
Section titled “训练和推理的痛点还不完全一样”训练时更常见的压力是:
- 注意力矩阵太大
- 中间激活太多
推理时更常见的压力是:
- KV cache 越积越大
- 长会话越聊越慢、越占内存
所以高效注意力方法也分很多路线, 不是所有方法都在解决同一个问题。
二、先把几条主流路线分开
Section titled “二、先把几条主流路线分开”滑动窗口 / 局部注意力:减少“谁看谁”
Section titled “滑动窗口 / 局部注意力:减少“谁看谁””最直观的一条路线是:
- 不让每个 token 看全世界
- 只让它看附近一小段窗口
这相当于说:
- 远处信息不是完全不要
- 但不是每一层、每个位置都必须全量对齐
典型思路有:
- sliding window attention
- local attention
MQA / GQA:减少 KV cache 体积
Section titled “MQA / GQA:减少 KV cache 体积”另一条很重要的路线不是改 mask,
而是改多头注意力的 K / V 组织方式。
普通多头注意力里,不同 head 往往各自有一套 K/V。 这会让推理期 KV cache 体积非常大。
于是出现了:
- MQA:多个 查询 head 共享一组 K/V
- GQA:把 查询 head 分组共享 K/V
它们的核心收益更偏向:
- 推理内存更省
- 吞吐更好
FlashAttention:不是改公式,而是改算的方式
Section titled “FlashAttention:不是改公式,而是改算的方式”FlashAttention 很容易被误解成:
- 一种新的注意力定义
其实更准确的理解是:
注意力公式基本不变,但通过更高效的分块计算与内存读写方式,减少显存开销和访存浪费。
它优化的重点是:
- 训练和推理时的实现效率
而不是让模型突然能理解完全不同的关系。

线性注意力:尝试从公式层面降复杂度
Section titled “线性注意力:尝试从公式层面降复杂度”还有一类方法更激进, 它会直接改写注意力计算形式,希望把复杂度从平方级降下来。
这类方法通常会在:
- 理论复杂度
- 表达能力
- 实际效果
之间做权衡。
三、先跑一个真正说明问题的示例
Section titled “三、先跑一个真正说明问题的示例”下面这个例子会比较两件事:
- 全局注意力:每个位置都能看所有位置
- 局部注意力:每个位置只能看附近窗口
我们不仅会比较“能看谁”, 还会比较需要处理的 pair 数量。
from math import exp
values = [0.2, 0.1, 0.0, 0.8, 0.9, 0.7, 0.1, 0.0]
def softmax(scores): m = max(scores) exps = [exp(x - m) for x in scores] total = sum(exps) return [x / total for x in exps]
def attention_outputs(sequence, window=None): outputs = [] pairs = 0 neighborhoods = []
for i in range(len(sequence)): if window is None: neighbors = list(range(len(sequence))) else: left = max(0, i - window) right = min(len(sequence), i + window + 1) neighbors = list(range(left, right))
neighborhoods.append(neighbors) pairs += len(neighbors)
scores = [sequence[i] * sequence[j] for j in neighbors] weights = softmax(scores) output = sum(w * sequence[j] for w, j in zip(weights, neighbors)) outputs.append(output)
return outputs, pairs, neighborhoods
full_outputs, full_pairs, full_neighbors = attention_outputs(values, window=None)local_outputs, local_pairs, local_neighbors = attention_outputs(values, window=2)
print("full pairs :", full_pairs)print("local pairs:", local_pairs)print("token 4 full neighbors :", full_neighbors[4])print("token 4 local neighbors:", local_neighbors[4])print("full outputs :", [round(x, 3) for x in full_outputs])print("local outputs:", [round(x, 3) for x in local_outputs])预期输出:
full pairs : 64local pairs: 34token 4 full neighbors : [0, 1, 2, 3, 4, 5, 6, 7]token 4 local neighbors: [2, 3, 4, 5, 6]full outputs : [0.376, 0.363, 0.35, 0.457, 0.47, 0.443, 0.363, 0.35]local outputs: [0.101, 0.285, 0.4, 0.604, 0.615, 0.592, 0.44, 0.267]
这段代码到底对应了什么直觉?
Section titled “这段代码到底对应了什么直觉?”它告诉你两件特别关键的事:
- 如果限制每个位置只看局部,pair 数量会明显下降
- 但输出也会变,因为模型失去了远处信息
这正是高效注意力最核心的现实:
你不是在免费提速,而是在效率和可见范围之间做权衡。
为什么 full pairs 和 local pairs 差很多?
Section titled “为什么 full pairs 和 local pairs 差很多?”因为全局注意力里每个位置都看全部位置。 局部注意力里,每个位置只看窗口附近。
当序列长度很长时,这种差距会迅速放大。
为什么局部注意力不一定就更差?
Section titled “为什么局部注意力不一定就更差?”因为很多信息本来就具有局部性。 例如语言里:
- 最近几个 token 往往最相关
- 远程依赖虽然重要,但不一定每一层都要全量建模
所以很多长上下文模型会采用:
- 部分层全局
- 部分层局部
- 或者带稀疏模式的混合方案
四、推理期另一个大头:KV cache
Section titled “四、推理期另一个大头:KV cache”为什么聊天越长,推理越吃内存?
Section titled “为什么聊天越长,推理越吃内存?”因为 decoder-only 模型在生成时,
前面每一步的 K / V 都会缓存下来,供后续 token 重用。
这就是:
- KV cache
它能显著减少重复计算, 但代价是:
- 会话越长,缓存越大
MQA / GQA 到底在省什么?
Section titled “MQA / GQA 到底在省什么?”它们省的不是注意力矩阵本身, 而是每层每步要保存的 K/V 体积。
简单理解:
- 普通 MHA:每个 head 都有自己的 K/V
- MQA:很多 查询 head 共用一组 K/V
- GQA:一组 查询 head 共用一组 K/V
所以它们尤其适合:
- 大模型推理
- 长对话
- 高吞吐服务
一个简单的“谁更省”的估算
Section titled “一个简单的“谁更省”的估算”def kv_units(num_query_heads, num_kv_heads, head_dim, seq_len): return num_kv_heads * head_dim * seq_len * 2
seq_len = 8192head_dim = 128
print("MHA units =", kv_units(32, 32, head_dim, seq_len))print("GQA units =", kv_units(32, 8, head_dim, seq_len))print("MQA units =", kv_units(32, 1, head_dim, seq_len))预期输出:
MHA units = 67108864GQA units = 16777216MQA units = 2097152这里的数字不是完整显存公式, 但足够建立第一层直觉:
num_kv_heads越少- KV cache 越小

五、FlashAttention 为什么这么常被提?
Section titled “五、FlashAttention 为什么这么常被提?”因为很多瓶颈不在“算不出来”,而在“搬数据太贵”
Section titled “因为很多瓶颈不在“算不出来”,而在“搬数据太贵””注意力实现里,一个常见问题是:
- 中间矩阵太大
- GPU 显存读写频繁
FlashAttention 的关键思路是:
- 把计算分块
- 尽量减少中间结果落回高开销内存
所以它常常能带来:
- 更高吞吐
- 更低显存占用
它和滑动窗口不是同一类东西
Section titled “它和滑动窗口不是同一类东西”这一点非常重要。
- 滑动窗口是在改“看谁”
- FlashAttention 是在改“怎么算”
所以它们甚至可以组合使用, 并不是互斥关系。
六、什么时候该优先想哪条路线?
Section titled “六、什么时候该优先想哪条路线?”如果你主要卡在长上下文训练显存
Section titled “如果你主要卡在长上下文训练显存”优先会想到:
- FlashAttention
- activation checkpointing
- 序列并行
如果你主要卡在推理时 KV cache 太大
Section titled “如果你主要卡在推理时 KV cache 太大”优先会想到:
- MQA
- GQA
- KV cache 量化
如果你主要卡在超长上下文的平方复杂度
Section titled “如果你主要卡在超长上下文的平方复杂度”优先会想到:
- 滑动窗口
- 稀疏注意力
- 分块或混合注意力
- 线性注意力类方法
也就是说:
高效注意力不是一把锤子,而是一组针对不同瓶颈的工具。
七、常见误区
Section titled “七、常见误区”误区一:高效注意力 = 更快而且一定更好
Section titled “误区一:高效注意力 = 更快而且一定更好”很多方法本质上是在交换:
- 速度
- 内存
- 感受野
- 实现复杂度
不可能所有指标都白赚。
误区二:只要支持长上下文,模型就一定“会用长上下文”
Section titled “误区二:只要支持长上下文,模型就一定“会用长上下文””支持 128k 上下文,不等于模型真的能稳定利用 128k 里的关键信息。
这是两件不同的事:
- 工程支持长度
- 模型有效利用长度
误区三:FlashAttention 是一种新模型架构
Section titled “误区三:FlashAttention 是一种新模型架构”不是。 它更像一种高效实现技术。
学完这一页,至少保留这张证据卡:
- 成本来源
- 普通注意力会存储或计算 seq_len x seq_len 交互
- 方法
- 根据瓶颈选择 sparse、linear、FlashAttention 或 KV cache
- KV 缓存
- 加速解码但消耗内存
- 硬件说明
- 算法收益取决于运行时/kernel 支持
- 决策
- 在改变架构前先测量延迟/内存
这节最重要的不是记住一串方法名, 而是先分清问题:
你到底是在被平方复杂度卡住、被 KV cache 卡住,还是被显存读写效率卡住。
只有先把瓶颈分清楚,你才知道该看:
- 滑动窗口
- GQA / MQA
- FlashAttention
中的哪一类方案。
- 把示例中的
window=2改成window=1或window=3,观察 pair 数量怎么变化。 - 用自己的话解释:为什么说滑动窗口是在改“看谁”,FlashAttention 是在改“怎么算”?
- 如果你做的是长对话推理服务,为什么 GQA / MQA 往往比滑动窗口更先进入视野?
- 想一想:支持很长上下文,和真正能有效利用长上下文,为什么不是一回事?
参考实现与讲解
- 窗口越小,可见 pair 越少;窗口越大,可见 pair 越多。具体数量会随着每个 token 被允许看到的局部上下文范围而增长。
- 滑动窗口改变的是 attention pattern,也就是限制每个 token 能看哪些邻居。FlashAttention 不改变数学上的 attention 结果,而是用更省显存的 kernel 来计算。
- 长对话推理服务往往先撞到 KV cache 显存瓶颈。GQA / MQA 能减少 key-value cache 占用,所以即使保持 full attention,也能先改善服务容量。
- 模型能接收很多 token,不代表它一定能检索、连接并优先使用正确证据。上下文长度是容量上限,长上下文利用质量是需要单独评估的行为能力。