研究报告:《Do Transformers Need Three Projections? Systematic Study of QKV Variants》
1. 作者及机构
本文作者为Ali Kayyam, Anusha Madan Gopal, M Anthony Lewis。三位作者均隶属于位于美国加利福尼亚州拉古纳山的BrainChip公司。
2. 这项工作如何融入更广泛的研究领域
Transformer 架构已成为从自然语言处理到计算机视觉等各种人工智能任务的基础架构。然而,其广泛应用同时也凸显了其效率方面的重大挑战,特别是其自注意力机制的二次计算和内存开销,随着上下文窗口的增大和对实时推理需求的增加,这些开销会变得更加突出。
现有的研究领域已通过多种方法解决了这些效率问题。这些方法包括开发线性复杂度的注意力模型(例如 Performer、Linformer)、新型注意力机制(例如环形注意力机制和分块方法),以及在推理过程中减少键值(KV)缓存大小的技术(例如分组查询注意力机制 (GQA) 和多查询注意力机制 (MQA))。虽然这些创新有助于缓解自注意力机制的二次方瓶颈,但一个基本的结构性问题仍未得到充分探讨:三方(查询、键、值)投影的必要性。与其他神经网络架构(例如卷积神经网络 (CNN) 或状态空间模型 (SSM))通常使用更统一的内部表示不同,标准的 Transformer 模型维护着三个不同的投影矩阵,分别用于查询、键和值。
本文旨在对现有的效率提升方案进行补充。作者并非完全替换自注意力机制或仅仅关注头部共享策略,而是研究能否在不牺牲注意力机制核心功能和性能的前提下,统一或共享三个投影矩阵。这构成了对注意力机制内部权重绑定的系统性探索,旨在减少固有冗余,并充分利用参数数量、计算开销以及至关重要的推理内存占用方面的潜在优势。本文着重探讨了这种投影共享方式如何与现有的头部共享技术相互作用并加以结合,从而揭示了Transformer效率提升中一个尚未充分探索的维度。
3. 主要目标和动机
本研究的主要目标是系统地评估Transformer模型中三种不同的查询(Q)、键(K)和值(V)投影的必要性,并量化共享或统一这些投影的影响。这项研究的驱动力主要来自以下几个方面:
首先,一个重要的动机源于这样的观察:尽管存在许多提高Transformer效率的方法,但三个独立的QKV投影的基本结构却鲜少受到质疑。作者试图确定这种架构冗余是否可以减少,从而构建更简洁高效的模型。这涉及到评估此类简化是否能在对下游任务性能影响最小的情况下,降低参数数量和计算开销。
其次,一个重要的实际动机是解决日益严重的推理内存瓶颈问题,特别是键值(KV)缓存的大小,这在大型语言模型(LLM)的自回归生成过程中尤为显著。随着上下文窗口的扩展以及对LLM设备端或边缘部署需求的增加,KV缓存消耗的内存成为影响服务成本和吞吐量的主要因素。本研究旨在开发和评估能够直接减少KV缓存占用空间的投影共享策略,从而实现:
- 在给定内存预算的情况下,延长上下文窗口。
- 更高的吞吐量(每个 GPU 可同时服务更多用户)。
- 降低内存密集型部署的服务成本。
- 促进了LLM在资源受限的边缘设备和移动平台上的实际部署。
第三,本研究的动机在于,将投影共享与现有的头部共享机制(例如分组查询注意力 (GQA) 和多查询注意力 (MQA))相结合,有望实现协同效应。这些头部共享技术通过跨层共享头部来减少键值缓存,而投影共享则直接针对投影矩阵本身。作者旨在证明这两种方法是正交的,并且可以相乘地结合,从而在内存效率方面实现复合提升。
最后,研究人员试图更深入地理解查询、键和值各自的角色和表征空间。通过系统地约束这些投射,研究人员旨在深入了解哪些投射对于维持模型质量更为关键,以及某些共享方案成功或失败的原因。例如,该研究旨在探究为何统一键和值可以保持模型质量,而统一查询和键却可能损害注意力方向性,尤其是在序列任务中。这涉及到将投射共享描述为注意力机制中权重绑定的一个具体实例。
4. 方法论和途径
本研究系统地探讨了自注意力机制中的投影共享约束,提出了三种主要变体及其增强形式。研究方法还包括计算和内存成本的比较分析、与头部共享技术相结合的探索,以及对实际部署的考量。
提出的投影共享注意力变体: 作者评估了三种主要变体,逐步减少学习到的投影矩阵的数量:
-
Q=KV(统一查询和键;分离值): 在此变体中,查询 (Q) 投影矩阵与键 (K) 投影矩阵相同($Q=K$)。然后,注意力机制计算 $A = ext{Softmax}(alpha KK^T)V$。此公式生成一个对称的注意力矩阵($KK^T$)。
- (Q=KV)+: 为了缓解序列任务中对称注意力的局限性,引入了二维位置编码。将固定的正弦位置编码 $P (位于 mathbb{R}^{n imes n imes m}$) 添加到注意力图 $A' = A + P$ 中,然后进行 $1 imes 1$ 卷积,将其映射回二维注意力矩阵。此举旨在引入不对称性和方向性偏差。(X)+ 变体专门应用于非因果场景(视觉、合成任务),在这些场景中,对称注意力是主要限制因素,因为因果语言建模已经通过掩蔽强制实现了不对称性。
-
QK=V(独立查询;统一键值): 此处,键 (K) 和值 (V) 的投影矩阵被统一($V=K$),而查询 (Q) 保持独立。注意力机制计算为 $A = ext{Softmax}(alpha QK^T)K$。由于 Q 和 K 仍然独立,因此该变体保留了非对称注意力图。键值统一被视为一种权重绑定形式。
-
Q=K=V(三个矩阵使用单一投影): 这是最彻底的简化,其中三个投影矩阵被统一起来($Q=K=V$)。注意力机制变为 $A = ext{Softmax}(alpha KK^T)K$。这结合了 Q=KV 的对称注意力机制和 K=V 的表征瓶颈。
- (Q=K=V)+: 与(Q=KV)+类似,添加了 2D 位置编码,以解决对称注意力可能引起的问题。
将投影共享与头部共享相结合: 作者强调,他们的投影共享方法与现有的头部共享方法(例如分组查询注意力机制 (GQA) 和多查询注意力机制 (MQA))正交。GQA 和 MQA 通过多个查询头部共享较少数量的键值缓存来减小键值缓存的大小。该研究提出将这些策略结合起来:
-
Q-GQA-g:
在每个 GQA 组内应用 K=V 约束
g。 - Q-MQA: 将 K=V 约束应用于 MQA 中使用的单个 KV 磁头。这些组合有望带来内存效率的复合式乘法提升。
计算和内存分析: 本文分析了每种变体的计算复杂度(具体而言,投影操作)和参数数量,并与标准 QKV Transformer 进行了比较(表 1)。
- Q=KV 和 QK=V 将投影操作和参数减少了 33%(从 $3nd^2$ 操作减少到 $2nd^2$ 操作,从 $3d^2$ 参数减少到 $2d^2$ 参数,其中 $n$ 是序列长度,$d$ 是嵌入维度)。
- Q=K=V 实现了 66% 的运算量减少(降至 $nd^2$ 次运算和 $d^2$ 个参数)。关键在于,QK=V 及其变体在自回归推理期间可将 KV 缓存内存减少 50%,因为只需存储 K 张量,V 张量可从 K 张量中复用。这一优势体现在能够实现更长的上下文窗口、更高的吞吐量和更低的服务器成本。本文还探讨了这些优化如何与其他效率提升技术(例如量化、稀疏注意力机制和 Flash Attention)相辅相成。
设计考虑因素:
- 对称注意力中的对角线优势: $KK^T$ 公式可能会导致强烈的自我注意力,而 QK=V 的 $QK^T$ 自然地避免了这一点。
- 扩展到编码器-解码器架构: 该方法适用于编码器-解码器模型中的自注意力层,类似于 MQA 的选择性使用方式。
- 与其他效率技术的协同作用: 投影共享可以与量化(例如 INT8/INT4)、稀疏注意力、替代激活以及硬件高效的实现(如 Flash Attention)相结合。
- 根据任务而定的应用: 最优变体的选择取决于任务(例如,因果任务采用 QK=V,非因果任务采用 Q=KV 或 Q=K=V,极端资源限制情况下采用组合方法)。
实验设置: 评估涵盖三个领域中的一系列不同任务:
- 综合推理: 使用单个 Transformer 编码器完成五项任务(反转、排序、删除、交换、复制)。
- 计算机视觉: 在 MNIST、FashionMNIST、CIFAR-10、CIFAR-100 和 TinyImageNet 数据集上进行图像分类,并进行异常检测。此外,本文还基于外部研究成果,对医学图像进行了分割。
- 语言建模: 使用来自 SlimPajama 数据集的 100 亿个词元,对 3 亿参数和 12 亿参数的 GPT 风格模型进行预训练。所有模型均使用相同的超参数从头开始训练,以确保可控的比较。语言建模实验使用了 8 个 NVIDIA A100 40GB GPU,并采用 bfloat16 混合精度。使用 EleutherAI lm-eval-harness 在 HellaSwag、PIQA、ARC-Easy、ARC-Challenge 和 WinoGrande 数据集上对语言建模模型进行了下游任务评估。
5. 主要发现和结果
该研究对合成、视觉和语言任务中的投影共享变体进行了系统评估,揭示了不同的性能特征和显著的效率提升。
综合任务:( 表 2)
- 所有测试的Transformer变体在合成任务中均表现出良好的性能。
- Q=KV 变体的性能与标准 QKV 变压器相当。
- Q=K=V 变体的性能明显下降。
- 引入二维位置编码(以“+”后缀表示,例如 (Q=KV)+)显著提升了所有变体的性能,部分缓解了受限配置下的性能损失。例如,(Q=KV)+ 的平均准确率略高于 QKV(0.870 对 0.851)。
- QKV变换器通常表现出更快的收敛速度,尽管最终精度往往具有竞争力。
视觉任务:( 表 3、图 2、附录 A.3.2、A.3.3)
- 分类: (Q=KV)+ Transformer 在 MNIST、FashionMNIST 和 CIFAR 数据集上的表现与 QKV Transformer 相当。Q=K=V Transformer 在 CIFAR 数据集上也表现出色。在更大的 TinyImageNet 数据集上,Q=K=V Transformer 尽管仅使用了一种投影,但在所有测试的变体中取得了最佳结果,展现了在精度损失极小的情况下实现高效训练的能力。此外,共享投影的变体训练时间也显著缩短(例如,Q=K=V 比 QKV 每 epoch 快 20%)。
- 设置异常检测: 所有模型均表现出相当的性能,其中 (Q=KV)+ 略有优势。
- 图像分割(来自附录中的 Hwa 等人 (2025)): QK=V 变体(简称 KV)即使在复杂的医学图像分割任务中也与标准 QKV 注意力机制保持竞争力,同时参数数量和 MAC 减少了约 10%。
自然语言处理任务(语言建模):( 表 4、5、6、7、8、9、10、11、14、15、16,图 8、9、10、11、12)
-
语言模型质量(3亿参数):
- QK=V 被证明是最有效的投影共享变体,与 QKV 基线相比,在困惑度仅下降 3.1% 的情况下,实现了 50% 的键值缓存减少(验证困惑度分别为 5.27 和 5.11)。该变体在整个训练过程中与基线性能非常接近。
- Q=KV 展现了具有竞争力的训练质量(困惑度降低了 4.9%),但与标准 QKV 相比,在推理内存方面没有优势,因为它仍然需要分别缓存 K 和 V 张量。
- Q=K=V 导致质量大幅下降,困惑度降低了 25.4%,这表明这种程度的约束对于语言建模来说过于严格。
- 不同变体的训练吞吐量相似,速度差异很小(Q=K=V 最多快 8.7%)。
-
参数计数和计算:
- 虽然投影共享导致注意力参数大幅减少(25-50%),但这转化为整体模型参数的适度减少(例如,QK=V:总参数减少 6.9%),因为注意力投影约占 Transformer 总参数的三分之一。
- 类似地,由于 MLP 层和语言建模头的显著贡献,推理计算成本 (MAC) 也略有节省(例如,QK=V:序列长度为 2048 时总 MAC 减少 5.4%)。
- 核心发现是,主要好处在于 推理记忆效率 ,而不是减少整体参数或 FLOPs。
-
KV高速缓存分析:
- QK =V 和 Q=K=V 变体通过仅存储 K 张量并在生成期间将其用作 V,实现了 KV 缓存内存减少 50%。
- Q=KV 没有提供缓存节省。
- 此次 50% 的缓存减少被认为是实际部署的关键,它能够实现 2 倍更长的上下文窗口、2 倍更高的吞吐量,并在生产 LLM 服务中大幅节省成本(例如,在典型的部署场景中每年节省 72,000 美元)。
- 实证推断的墙钟基准测试证实了这些理论上的节省,与 QKV 相比,QK=V 的峰值内存减少 6.5%–6.9%,解码吞吐量提高 4.4%–5.3%。
-
随序列长度缩放:
- 随着序列长度的增加,注意力在总计算中所占的比例越来越大,从而放大了简化投影变体的效率提升。不同变体的相对排名在不同的序列长度下保持稳定。
-
扩展到 12 亿个参数:
- 研究结果可以推广到更大的(12亿参数)模型,相对质量排名与3亿参数实验的结果保持一致。
- MQA 与 QKV 的性能几乎持平(困惑度增加 1.06%),缓存减少 97%。
- GQA-8 提供了良好的质量效率平衡(困惑度提高 0.52%,缓存减少 76%)。
- QK=V 在保持合理质量(困惑度提升 2.48%)的同时,缓存节省了 50%。值得注意的是,QK=V 在参数量为 12 亿时困惑度下降幅度低于 3 亿时,这表明更大的模型可能对投影约束更具鲁棒性。
-
投影与头部共享相结合:
- 该研究表明,投影共享与头部共享是完全互补的。
- Q-GQA-4 实现了 87.5% 的缓存减少,困惑度下降了 3.9%(3 亿个模型)。
- Q-MQA 在 3 亿个模型规模下实现了 96.9% 的缓存缩减,同时困惑度下降了 4.8%。这些组合方法在保持实用模型质量的同时,显著压缩了内存。在 12 亿个模型规模下,Q-MQA 实现了 98.5% 的缓存缩减,困惑度下降了 4.16%。推理基准测试表明,Q-MQA 实现了 12.8%–13.6% 的内存缩减和 11.7%–13.2% 的吞吐量提升。
-
下游任务评估(12亿个模型):
- 尽管困惑度差距高达 2.48%,但 QK=V 在五个标准基准测试(HellaSwag、PIQA、ARC-Easy、ARC-Challenge 和 WinoGrande)中,下游准确率平均仅下降了 0.41%。这表明困惑度的下降并不总是会导致特定任务能力的相应损失。
- 组合后的 Q-GQA-8 在平均下游准确率上略高于 QKV,同时实现了 87.5% 的缓存减少,强化了协同效应。
-
建筑洞察:
- QK=V之所以有效,是因为键和值可以占据相似的表征空间,并且注意力机制在低秩机制下运行,同时保持了Q与K/V对之间的不对称性。对 训练好的QKV模型的分析表明,K和V投影矩阵之间具有较高的余弦相似度(0.73)和相似的有效秩,而Q与K和V的相似度均较低。这表明K=V约束得到了K和V固有冗余或共享功能的有力支持。
- Q=KV 不适用于因果语言建模,因为它强制使用对称的注意力模式 (KK^T),破坏了因果依赖关系所需的方向性。
- Q=K=V 结合了这两种病理 ,导致灾难性的质量下降。
- 理论附录还详细说明了线性注意力中的 QKV 崩溃如何导致类似于具有自适应观察的状态空间模型的递归形式,这表明可编程记忆和动力系统之间存在连续性。
6. 重要性和潜在影响
本研究系统地将投影共享描述为Transformer注意力机制中权重绑定的一种形式,并展示了其显著优势,尤其是在推理效率方面。研究结果在多个维度上具有重要意义:
首先,QK=V 被确定 为一种高效且可扩展的投影共享策略, 这是一项主要贡献。该变体在模型质量影响相对较小(12 亿参数下困惑度下降 2.48%)的情况下,实现了 KV 缓存内存减少 50%。这代表了 Transformer 架构的实质性进步,为实践者在效率-质量帕累托前沿上提供了一个新的起点。
其次,该研究强调 推理内存优势 而非单纯的参数或浮点运算次数减少,这对于实际部署至关重要。在LLM中,KV缓存通常是自回归生成过程中的主要内存瓶颈。通过将该缓存减半,QK=V可以实现:
- 扩展上下文窗口: 模型可以在现有硬件限制内处理更长的序列。
- 吞吐量提高: 同一硬件可以服务更多并发用户,从而提高系统利用率。
- 降低运营成本: 更低的内存占用直接转化为基于云的 LLM 服务费用的大幅节省。
第三,投影共享与头部共享技术(GQA/MQA)的 互补性证明 ,为实现前所未有的效率水平开辟了道路。Q-GQA-8 和 Q-MQA 等组合方法分别实现了 88% 和 98.5% 的缓存缩减,且质量损失可控。这种协同效应对于 边缘部署和设备端 AI 尤为重要,它使得大型语言模型能够在资源受限的硬件(例如移动设备或物联网平台)上高效运行,而这在以前由于内存需求高而难以实现。
第四,该研究为查询、键和值投影的作用提供了 宝贵的架构见解 。实证研究表明,键和值可以有效地共享表征空间(高余弦相似度、相似的有效排名),这支持了 K=V 约束的有效性。相反,研究强调了查询在建立非对称注意力模式中的独特作用,解释了为什么将查询与键统一(Q=KV)会导致序列任务的性能下降。这种更深入的理解可以指导未来的架构设计,使其不再局限于简单的试错法。
最后,这些优势在参数数量从 3 亿到 12 亿范围内持续扩展,并且观察到更大的模型可能对投影约束更稳健,这表明投影共享的影响在更大的规模(例如,70 亿以上参数)下可能会更加显著。QK=V 时困惑度下降与下游任务性能之间的解耦进一步强化了其应用价值,表明可以在不损失实际模型能力的情况下实现内存节省。
总之,这项工作为注意力机制中模型复杂度与性能的权衡提供了一个实用且原则性的框架,并带来了直接、可量化的推理记忆优势。它使低阶模型能够在性能较低的硬件上运行,从而拓宽了人工智能的普及范围,并为降低人工智能计算的能耗提供了途径,有助于实现更可持续、更普及的人工智能部署。