本文1W字,FlashAttention 巧妙地重新组织计算方式,通过使注意力计算具备I/O感知能力,最大限度减少慢速内存操作,它能更高效地获得与标准注意力机制相同的结果。我们深入探讨了它如何利用平铺技术将数据保存在片上内存,通过分块进行softmax和矩阵乘法运算,避免将庞大的 $n×n$ 矩阵写入全局内存。这一创新将内存使用量降低至与序列长度呈线性关系,在实际应用中还将注意力计算速度提升了数倍。有些内容是翻译自 HuggingFace 和一些论文,更多 LLM 架构文章点击查看:LLM 架构专栏
大模型架构专栏文章阅读指南
欢迎加入大模型交流群:加群链接 https://docs.qq.com/doc/DS3VGS0NFVHNRR0Ru#
公众号【柏企阅文

Transformer 通过注意力机制捕捉数据中的长距离依赖关系,彻底革新了深度学习领域。然而,Transformer 核心的强大自注意力机制却伴随着高昂的计算成本。在这里,以直观且精确的方式探讨注意力机制的工作原理、其计算成本高昂的原因,以及 FlashAttention 算法如何应对这些挑战,从而使注意力计算更快、内存使用更高效。

理解注意力

机器学习中的注意力机制是一种对输入不同部分的信息进行加权和组合的方法。在 Transformer(例如著名的 “Attention Is All You Need” 模型)中,使用的是缩放点积注意力(scaled dot - product attention)。其核心是,注意力函数将一个查询(query)和一组键值对(key - value pairs)映射为一个输出。输出是值的加权和,每个值的权重由查询与相应键的匹配程度决定。简单来说,查询是在请求信息,键表明每条信息的主题,而值则是实际的信息内容。注意力机制会找出与查询最 “相关” 的键,并突出显示对应的信息值。

举个例子,想象一个检索系统:当你在 YouTube 上搜索视频时,搜索文本就是查询,每个视频的标题或描述是键,而视频本身就是值。系统会找出与查询最相似的键,并返回相应的视频。在 Transformer 的自注意力机制中,这个过程基于学习到的向量表示:序列中的每个位置(比如句子中的每个单词)都会生成一个查询向量,所有位置也会生成键向量和值向量。每个位置(查询)会查看其他所有位置的键,以决定给予多少注意力,然后相应地对值进行加权求和。这使得模型能够针对每个元素,聚焦于序列中最相关的部分。

从数学角度来看,缩放点积注意力通常表示为:
$$Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V$$

其中:

  • $Q$、$K$、$V$ 分别是查询、键和值的矩阵(矩阵的每一行对应序列中的一个位置)。
  • $d_k$ 是用于缩放的键向量的维度。

下面我们详细拆解一下它的工作原理:

  1. 计算 $QK^T$ :我们计算查询和键之间的点积,这一步衡量了每个查询与每个键之间的相似程度,得到的分数显示了每个序列位置之间的关注程度。
  2. 按 $\sqrt{d_k}$ 缩放:将点积结果除以键维度的平方根,这一步是为了稳定数值,防止随着 $d_k$ 的增加,数值变得过大。
  3. Softmax 操作
    • 对缩放后的分数应用 Softmax 函数。
    • Softmax 函数将分数转换为概率分布。
    • 确保每个查询在所有键上的注意力权重总和为 1。
  4. 乘以 $V$
    • 注意力权重用于计算值向量的加权和。
    • 每个查询得到所有值向量的加权组合。
    • 与查询更相似的键所对应的值,会被赋予更高的权重。

在实际应用中,Transformer 通常会并行使用多个注意力 “头”(即多组不同的 $Q$、$K$、$V$,每组都有不同的学习投影),然后将结果进行组合。这种多头注意力机制允许模型同时关注不同类型的信息,但每个头内部的基本计算都是上述的缩放点积注意力。

为什么注意力计算成本高昂?

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

问题出在对这些矩阵的运算上。如果序列长度为 $n$,向量维度为 $d$,那么矩阵 $QK^T$ 的大小为 $n \times n$。计算这个矩阵需要大约 $n^2 \times d$ 次乘加运算,而生成最终的输出 $Softmax(QK^T)V$,还需要将一个 $n \times n$ 的矩阵与 $V$ 矩阵相乘(又大约需要 $n^2 \times d$ 次运算)。用大 O 符号表示,注意力机制的时间复杂度为 $O(n^2 \times d)$,对于长序列来说,这个计算量是巨大的。在内存方面,存储 $n \times n$ 的注意力矩阵(甚至只是存储权重)需要 $O(n^2)$ 的空间。总之,自注意力机制的计算和内存需求会随着序列长度的增加呈二次方增长。也就是说,如果将序列长度翻倍,注意力计算所需的时间会变为原来的四倍,内存使用量也会变为原来的四倍。

从实际角度来看,对于长度 $n = 1024$ 的序列(这在语言模型中很常见),注意力分数矩阵有超过一百万个元素。如果序列长度为 4096,生成的矩阵元素约为 1600 万。处理如此大规模的矩阵成为了一个主要瓶颈。尽管点积注意力因其可以借助矩阵乘法库进行优化而被选用,但长序列带来的巨大规模仍是一个严峻的问题。正如我们接下来会看到的,这种二次方的缩放特性对模型的训练和部署都有重要影响。

注意力机制面临的挑战

标准注意力机制的二次方复杂度带来了几个关键挑战:

  1. 内存使用(二次空间复杂度):随着序列长度 $n$ 的增加,存储中间注意力结果(分数或概率的 $n \times n$ 矩阵)所需的内存会以 $n²$ 的速度增长。对于长序列而言,这很容易耗尽现代 GPU 的内存。例如,长度为 16k 的序列会生成一个包含 2.56 亿个元素的注意力矩阵——即使采用半精度(每个元素 2 字节)存储,单层就需要约 0.5GB 的内存!由于这些注意力激活操作,训练大型 Transformer 模型时经常会遇到内存限制问题。
  2. 计算时间(二次时间复杂度):$O(n²d)$ 的时间成本意味着,对于长输入,注意力计算会变得非常缓慢。尽管 Transformer 能够有效地并行计算,但整体计算量仍会随着 $n²$ 急剧增加。这在大规模模型的训练过程中成为了时间瓶颈——对于长序列或多层模型,很大一部分训练时间可能都花在了注意力计算上。研究人员发现,尽管自注意力机制的并行性比循环网络更好,但在处理长序列时,其运行时间的增长情况却比循环网络更糟糕。在实际应用中,这要么限制了模型可行的上下文长度,要么迫使用户接受更长的运行时间。
  3. 内存带宽和 I/O 瓶颈:现代 GPU 具有分层内存结构。它们拥有高带宽内存(HBM),例如 NVIDIA A100 上有 40GB 的 HBM,数据传输速度约为 1.5 - 2.0TB/s;同时还有容量小得多的片上 SRAM(共享内存/缓存),A100 上每个流式多处理器的片上 SRAM 约为 192KB,总计约 20MB,但其吞吐量极高(聚合可达约 19TB/s)。标准注意力算法无法在如此小的片上内存中完成全部计算操作,因此需要依赖向 HBM 读写中间结果。每一步计算(计算 $QK^T$、应用 softmax、乘以 $V$)都涉及大量数据在 HBM 和其他内存之间的传输。这就产生了问题,因为许多深度学习操作受内存带宽限制,而非计算能力限制——GPU 花在数据传输上的时间比进行数学计算的时间更多。实际上,对于注意力计算来说,内存访问往往是速度的主要限制因素。由于计算核心(张量核心)经常需要等待从较慢的内存中加载数据,其利用率并未得到充分发挥。
  4. 训练大型模型时的 GPU 限制:由于上述问题,训练包含长序列的大型 Transformer 模型面临实际瓶颈。注意力矩阵会占用大量内存,导致批量大小不得不减小,或者采用梯度检查点技术(这会降低训练速度)。长上下文(例如数千个词元)会使每个训练步骤明显变慢,这就是为什么以往模型训练时的上下文长度通常较为有限(BERT 为 512,GPT - 3 为 2048 等)。要将上下文长度扩展到 8k、16k 甚至更长,要么需要巨大的计算资源,要么需要新技术的支持。虽然有人提出了一些近似方法(如稀疏或截断注意力模式)来降低计算复杂度,但这些方法往往会牺牲模型质量,并且实现起来比较复杂。值得注意的是,许多近似注意力方法在实际应用中并没有显著提高运行速度。额外的开销或并行性的降低常常抵消了理论上的优势。因此,迫切需要一种高效且精确的注意力机制,既能缓解内存和时间瓶颈,又不会影响模型的准确性。

总之,标准注意力机制虽然强大,但在扩展性方面表现不佳。其 $O(n²)$ 的特性给内存和计算带来了很大压力,尤其是在那些计算能力充足但快速内存有限的硬件上。这就为 FlashAttention 算法的出现奠定了基础,它正是为克服这些瓶颈而设计的。

FlashAttention 算法简介

FlashAttention 是由 Tri Dao 及其同事在 2022 年提出的一种新算法。它重新思考了注意力的计算方式,在保证计算结果与标准注意力完全相同的前提下,显著提高了计算效率。换句话说,FlashAttention 是一种优化算法,而非近似算法。“Flash”(闪存)这个名字暗示了速度,更确切地说,它强调的是对快速内存的利用(就像相机闪光灯存储光线一样):FlashAttention 是一种考虑 I/O 的算法,能够在注意力计算过程中最大限度地减少慢速内存 I/O 操作。

FlashAttention 的核心思想是通过平铺(tiling)和融合(fuse)注意力计算过程,减少不必要的内存读写操作。回顾一下标准注意力机制,它会先计算完整的 $n \times n$ 分数矩阵(并将其写入 GPU 内存),然后读取该矩阵以应用 softmax 函数,再写入概率值,最后再次读取概率值并与 $V$ 矩阵相乘。而 FlashAttention 则将计算过程分解为多个可管理的块,这样就无需将这些中间结果完整地写入慢速内存。它利用 GPU 的片上 SRAM(共享内存),每次处理一个数据块——本质上是为序列的一个子集计算部分注意力结果,并在计算过程中累积输出。通过这种方式,它避免了存储整个注意力矩阵,在任何时刻都只将小块数据保留在快速内存中。

更具体地说,FlashAttention 会将查询、键和值划分为多个数据块(例如,根据硬件配置,每次处理 128 个查询和 128 个键)。它将一个查询块 $Q$、对应的键块 $K$ 和值块 $V$ 加载到高速片上内存中,计算它们的点积(即子矩阵 $QK^T$),对该块数据应用 softmax 归一化处理,然后立即与值块 $V$ 相乘,得到部分输出。接着,它会处理下一个数据块,依此类推。该算法经过精心设计,确保最终结果与一次性完成完整注意力计算的结果完全相同——通过对每个块进行适当的归一化调整并累积贡献,正确计算出整个序列上的 softmax 值。

通过不生成完整的注意力矩阵(也不重复读写该矩阵),FlashAttention 将注意力计算的内存使用从 $O(n²)$ 降低到了 $O(n)$。实际上,它用一些额外的浮点运算(为每个块重新计算部分 softmax 和)换取了更少的内存操作,在现代 GPU 上,这是非常划算的。有研究结果表明,FlashAttention 在实际运行时间上比标准注意力计算快 2 - 4 倍,且不会降低计算精度。事实上,实验显示 FlashAttention 的运行速度比标准的 PyTorch 注意力计算快 2 - 4 倍,内存使用量仅为标准方法的 5% - 20%。这大大减少了内存占用——例如,原本需要 1GB 内存的计算任务,使用 FlashAttention 后可能只需要 50 - 200MB 内存,这使得我们可以采用更大的批量大小或处理更长的序列。

另一种理解 FlashAttention 的方式是:它执行的数学运算($QK^T$、softmax、$PV$)与标准注意力计算相同,但它将这些运算整合在一个融合的 GPU 内核中,以数据块的形式流式处理数据。与多个内核分别向内存写入大量输出不同,FlashAttention 的一个内核就能完成所有操作,并尽可能将数据保存在寄存器或共享内存中。如果不编写自定义 CUDA 代码,这种级别的优化在标准深度学习库中是无法实现的,而 FlashAttention 的开发者正是利用了平铺和重新计算 softmax 规范等技术,实现了这一优化。

FlashAttention 的工作机制:深入剖析

为了理解 FlashAttention 是如何实现性能提升的,我们深入探讨一下它的工作机制和内存访问模式,不过不会涉及底层代码细节。其主要原理是在适合快速内存的数据块中进行处理。GPU 拥有少量片上内存(比如每个流处理器上的 L1/共享内存和 L2 缓存),这些片上内存的速度比大容量的 HBM(GPU 显存)快得多。FlashAttention 通过拆分注意力计算过程,巧妙地利用了这一特性。

FlashNote的阻塞计算。注意力是在完全片上(SRAM)的块(橙色虚线方块)中计算的,避免在HBM(蓝色虚线框)中实现满分矩阵。在本例中,键/值被分成两个块K^(1),V^(1)和K^(2),V^(2)。查询矩阵Q与第一个块交互以产生部分分数S^(1)=Q(K^(1))^T,这些分数被软最大化(产生P^(1))并与V^(1)相乘以贡献输出。然后它类似地处理第二个块S^(2)=Q(K^(2))^T。每个块的输出被组合(通过适当的重新缩放以考虑跨块的softmax归一化)以给出最终的关注结果,与标准计算相同

如上图所示,在整个计算过程中,该算法无需在内存中保存完整的 $S = QK^T$ 或 $P = Softmax(S)$ 矩阵——它每次只处理 $S$ 矩阵的一列或几列数据块。重要的是,即使 FlashAttention 是分块计算 softmax 的,它也能确保在整行(查询)上正确计算 softmax 值。它通过在处理每个数据块时,跟踪每个查询的运行最大值和指数之和(即 “softmax 累加器”)来实现这一点。这样,当所有数据块都处理完毕后,每个查询的 softmax 值就如同一次性看到了所有键一样准确。这种技术属于在线 softmax 计算的一种形式,并且它结合了缩放和归一化步骤,以保持数值稳定性(例如,在指数运算前减去最大值,每个数据块都进行这样的操作,并调整全局最大值)。

从内存角度来看,FlashAttention 从 HBM 读取查询、键和值向量,并将输出写回 HBM,但它不会写入大型中间矩阵。这极大地减少了内存流量。研究表明,对于典型的数据规模,FlashAttention 的 HBM 内存访问次数比标准方法最多可减少 9 倍。由于在现代加速器中,内存带宽往往是性能瓶颈,减少内存访问次数就能相应地提高计算速度。实际上,有分析指出,在 A100 GPU 上,HBM 访问是影响注意力计算运行时间的主要因素,通过减少这些访问,FlashAttention 能够实现更快的执行速度。

值得注意的是,像 A100 或 H100 这样的 GPU 具有极高的计算吞吐量(数百个 TFLOP),但由于内存限制,在注意力计算等工作负载中,这些计算能力往往无法得到充分利用。FlashAttention 通过将注意力计算重新设计为计算密集型(每个数据块重新进行一些计算)、减少内存密集型操作,让 GPU 能够进行更多的数学运算,减少等待内存数据的时间。这有效地释放了 GPU 的更多峰值性能。例如,NVIDIA H100 上的 FlashAttention - 2(FlashAttention 的后续迭代版本)能够达到 GPU 理论 FLOP/s 利用率的 75%左右,而原始方法仅为 35%左右——这充分证明了更好的内存使用方式能够转化为更高的计算利用率。

权衡与局限性

FlashAttention 主要的权衡在于其实现复杂度有所增加。它需要自定义内核或库的支持——不过好在目前它已经集成到了流行的深度学习框架和 Transformer 库中。因此,终端用户通常默认就能享受到它带来的优势。此外,由于每个数据块都需要重新计算部分 softmax 等操作,算术运算量会略有增加,但与内存效率提升带来的收益相比,这种开销是微不足道的。在硬件方面,FlashAttention 在那些具有一定量片上内存,且 SRAM 和 HBM 之间带宽差异较大的 GPU 上效果最为显著(目前所有英伟达的新型 GPU 都符合这一条件)。如果序列长度非常短(例如只有几十个词元),二次方计算成本本来就不高,此时可能并不需要使用 FlashAttention——不过在这种情况下,除了少量的内核启动开销外,使用它也没有太多弊端。另一个需要考虑的因素是数值稳定性:FlashAttention 在跨数据块对指数进行求和时,必须小心避免精度问题。该算法采用了保持最大值和缩放等技术来确保稳定性,开发者表示其精度与标准注意力计算相同。

FlashAttention:性能总结

FlashAttention 在多个维度上为 Transformer 模型带来了显著的效率提升:

  • 速度提升:将注意力计算时间缩短 2 - 4 倍。例如,GPT - 2 在训练时(序列长度为 1K),训练吞吐量提升了 3 倍;BERT - large 模型(序列长度为 512)端到端速度提升了 15%。而且,序列长度越长,性能提升越明显。
  • 内存效率:与传统注意力机制相比,内存使用量仅为其 5% - 20%。这使得在不受到 GPU 内存限制的情况下,可以采用更大的批量大小,处理更长的序列。
  • 硬件扩展性:在配备 FlashAttention - 2 的 H100 GPU 上,性能比 A100 上的经典注意力计算快约 6 倍。在某些配置下,配备 FlashAttention - 2 的 8 个 H100 GPU 比单个 A100 基线实现了 47 倍的加速。
  • 训练成本:显著降低了大型模型的训练成本。据估计,训练 GPT - 3(1750 亿参数)模型所需的 GPU 小时数可减少至 24.2 万,约为之前成本的十分之一。
  • FlashAttention - 2 的改进:在原始版本的基础上,FlashAttention - 2 优化了平铺策略和工作分区方式,通过更好地利用 GPU 并行性和减少内核启动开销,实现了高达 2 倍的更高吞吐量。这使得它在处理长序列和大型模型时更加有效。
  • FlashAttention - 3 的进展:FlashAttention - 3 利用混合精度(例如支持 FP8)以及像 H100 GPU 上的张量核心加速等硬件特定优化,进一步提升效率,计算能力超过 1 PFLOP/s,释放了前沿加速器的峰值性能。
  • 更长的上下文窗口:支持 32K - 100K +词元的实际上下文长度,与之前 2 - 4K 的限制相比有了大幅提升。

FlashAttention 已在主要的深度学习框架中被广泛采用,如今已成为 Transformer 实现中的标准配置。

结论

注意力机制通过实现灵活的信息整合,为神经网络带来了新的能力,但这是以高昂的计算成本和内存使用为代价的。一开始,我们了解了自注意力机制的工作原理:通过查询、键和值,让输入的每个部分都能关注其他部分。我们也看到,尽管这种机制很强大,但随着序列长度的增加,其扩展性较差(时间和内存复杂度均为二次方)。随着模型和数据集规模的不断扩大,这一问题愈发凸显。

FlashAttention 证明了巧妙地重新组织计算方式能够克服这些瓶颈。通过使注意力计算具备 I/O 感知能力,最大限度减少慢速内存操作,它能更高效地获得与标准注意力机制相同的结果。我们深入探讨了它如何利用平铺技术将数据保存在片上内存,通过分块进行 softmax 和矩阵乘法运算,避免将庞大的 $n×n$ 矩阵写入全局内存。这一创新将内存使用量降低至与序列长度呈线性关系,在实际应用中还将注意力计算速度提升了数倍。

其更广泛的影响意义重大。更快的注意力计算意味着我们可以更快地训练更大的模型,或者训练相同的模型用时更短。对于语言模型来说,这意味着更长的上下文窗口,使它们能够处理诸如长文档摘要或数千词元提示等任务,而这些任务此前要么不可行,要么处理速度非常慢。这也体现了深度学习基础设施中的一个普遍主题:随着模型规模的增长,单纯的算法复杂度并非唯一需要关注的问题,如何利用硬件(内存层次结构、并行性)同样至关重要。

FlashAttention 表明,通过优化内存访问,可以在完全不改变模型质量的情况下显著提升性能。在我们朝着越来越大的 AI 模型迈进的过程中,效率将决定实际可实现的目标,这一点尤为重要。

FlashAttention 的发展并未止步于最初的版本。FlashAttention - 2 已被广泛应用,通过更好的并行化和内存访问模式带来了更大的改进。FlashAttention - 3借助最新的 GPU 功能进一步突破极限,在 H100 GPU 上利用混合精度实现了超过 1 PFLOP/s 的计算能力。这些迭代版本如今已成为现代 Transformer 框架的标准组件,使得上下文长度从 2 - 4K 大幅扩展到超过 100K 词元。

总之,注意力机制或许是实现出色模型性能的关键,但像 FlashAttention 这样的算法,才是让这种性能变得高效且可扩展的必备要素。通过理解并接纳这些技术进步,我们能够训练出此前无法实现的下一代模型,同时保持注意力机制的精妙与强大。随着模型规模的持续增长,这些效率创新对于决定人工智能的发展边界将愈发关键。

推荐阅读

1. DeepSeek-R1 的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解 GRPO 和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!