2017年,《Attention is All You Need》一文发表,向世界展示了Transformer模型可以依靠注意力(Attention)层取得优异的性能。八年后,我们见证了这些模型借助注意力的力量通过了图灵测试以及其他更多成就。尽管注意力非常强大,但它也有一定的代价。随着输入变长,计算注意力所需的内存呈二次方增长。这种内存的增加会带来多种后果,我们将重点关注其对硬件的影响。

随着需要处理的内存增多,我们最终会在GPU内部的内存传输上遇到瓶颈。在Tri Dao提出Flash Attention(一种内核融合技术,使运行注意力计算在内存使用上显著提高效率)之前,我们曾在这个问题上停滞了一段时间。这一见解可能是大语言模型(LLM)推理中最重要的优化。

详细阐述这一发现背后的原理。《From Online Softmax to FlashAttention》

分块矩阵乘法的结合律

我们将从一个非常简单的GPU计算开始:矩阵乘法。矩阵乘法要求我们迭代地将操作矩阵的行和列读入内存。下面的简单示例展示了我们如何将第一个矩阵的行读入内存,以便与第二个矩阵的列相乘。

矩阵乘法示例

下面的代码以一种简单直接的方式实现了这一点,它会反复将相同的行或列读入内存,导致较高比例的缓存未命中,从而使计算速度变慢。

import numpy as np
def naive_matmul(A, B):
    n, m = A.shape
    m2, p = B.shape
    assert m == m2, "维度不兼容"
    C = np.zeros((n, p))
    for i in range(n):
        for j in range(p):
            for k in range(m):
                C[i, j] += A[i, k] * B[k, j]
    return C

为了提高矩阵乘法中的缓存性能,我们使用分块矩阵乘法。这种技术将矩阵分解为适合CPU缓存的较小子矩阵(分块)。我们不是按顺序计算乘积,而是一次处理一个分块,尽可能多地在更快的内存中重用数据。通过选择合适的分块大小,我们可以显著减少缓存未命中,并加快计算速度。下面的Python示例演示了这种方法。

import numpy as np
def tiled_matmul(A, B, tile_size):
    n, m = A.shape
    m2, p = B.shape
    assert m == m2, "维度不兼容"
    C = np.zeros((n, p))
    for ii in range(0, n, tile_size):
        for jj in range(0, p, tile_size):
            for kk in range(0, m, tile_size):
                for i in range(ii, min(ii + tile_size, n)):
                    for j in range(jj, min(jj + tile_size, p)):
                        for k in range(kk, min(kk + tile_size, m)):
                            C[i, j] += A[i, k] * B[k, j]
    return C

分块矩阵乘法带来的速度提升取决于你的硬件和矩阵的大小。对于我们的机器学习用例,你可以期望看到20倍到100倍的速度提升。所以,当涉及到优化矩阵计算(注意力是矩阵计算的一种特殊形式)时,关键在于采用分块矩阵乘法。

那么,为什么在注意力计算中很难做到这一点呢?虽然我们可以在Q和K矩阵之间进行分块矩阵乘法,但在进行最后一次矩阵乘法之前,我们需要对得到的矩阵进行Softmax操作。因此,优化注意力的一个关键步骤是弄清楚如何处理这个Softmax操作。首先,我们需要了解这个方程所涉及的复杂性。

SoftMax和溢出问题

Softmax通过使较大的数字更突出,较小的数字不那么突出,将一个数字列表转换为概率。它通过对每个数字取指数来扩大差异,然后将每个结果除以所有这些指数的总和,使所有结果相加等于1。

有趣的是,这个方程对我们的浮点数所能表示的最大值范围造成了压力。16位浮点数所能表示的最大值是65,504,所以如果x大于11,这将导致溢出并破坏我们的计算。

为了解决这个问题,我们在张量中找到最大值(m),并从指数中减去这个值。这确保了x永远不会大于浮点数所能表示的范围,但它确实带来了下溢错误的风险。值得庆幸的是,因为在Softmax中非常小的数字会变为0,这实际上变成了一个舍入误差,不会影响我们的计算。

为什么安全Softmax和Softmax是一样的呢?我们实际上是将分数的分子和分母都乘以$e^{-m}$,这就相当于乘以1。

3-Pass Safe SoftMax

为了有效地计算Softmax,我们可以将其分为三个步骤。因为我们要对输入张量迭代三次,所以我们称之为三趟方法。

我们首先找到最大值(max_val)以处理溢出问题。然后,我们使用max_val计算所有的指数,同时将这些值累加到总和中。最后,我们使用总和对所有的指数进行归一化。下面的Python代码展示了这一过程:

import numpy as np
def softmax_3pass(input_array):
    n = len(input_array)
    output = np.zeros(n, dtype=float)
    
    # 第一趟:找到最大值
    max_val = input_array[0]
    for i in range(1, n):
        if input_array[i] > max_val:
            max_val = input_array[i]
    
    # 第二趟:计算exp(x - max)并求和
    sum_val = 0.0
    for i in range(n):
        output[i] = np.exp(input_array[i] - max_val)
        sum_val += output[i]
    
    # 第三趟:归一化
    for i in range(n):
        output[i] /= sum_val
    
    return output

虽然这个方法是正确且易于理解的,但它效率不高,并且仍然阻碍我们在整个注意力计算中使用分块矩阵乘法。我们希望使这个方法更高效,并减少所需的趟数。

2-Pass Safe Softmax

为了进行循环融合,我们需要找到一种方法,在进行某一趟计算时捕获另一趟计算中发现的信息。如果我们关注第一趟循环,我们会发现我们不一定需要张量中的绝对最大值,只需要一个足够大的值来防止到目前为止的溢出即可。

通过只寻找局部最大值而不是全局最大值,我们可以将第一趟和第二趟融合在一起。然而,为了确保我们仍然在两边都只乘以1,每当我们找到一个新的最大值时,我们需要缩放分母。这有一些微妙之处,所以我将在下面详细介绍第一趟中的if语句。

import numpy as np
def softmax_online(input_array):
    n = len(input_array)
    output = np.zeros(n, dtype=float)
    
    # 用第一个元素初始化运行最大值
    m = input_array[0]
    # 运行总和(初始值为e^(x_0 - m_0) = 1.0)
    d = 1.0
    
    # 预计算以计算最终的最大值和总和
    for i in range(1, n):
        if input_array[i] > m:
            # 当找到新的最大值时调整总和
            d = d * np.exp(m - input_array[i]) + 1.0
            m = input_array[i]
        else:
            # 将这个元素的贡献加到总和中
            d += np.exp(input_array[i] - m)
    
    # 最后一趟计算Softmax输出
    for i in range(n):
        output[i] = np.exp(input_array[i] - m) / d
    
    return output

一旦我们知道存在一个比我们当前所使用的缩放值更大的最大值,我们就知道需要调整分母来考虑这一点。我们将所有先前的值乘以旧最大值减去新最大值的指数。这会抵消所有旧的最大值,并将其替换为新的最大值。然后我们需要添加最新的缩放值,即0的指数,也就是1。

我们称这种两趟Softmax为在线Softmax,因为我们在计算过程中确定全局统计信息(分母和最大值)。正是这种在线的见解引导我们走向了Flash Attention!事实上,我们可以找到一种方法将Flash Attention减少到一趟!

Flash Attention

请记住,注意力计算可以分为三个步骤。首先,我们在查询(Queries)和键(Key)之间进行矩阵乘法。然后,我们对该值进行Softmax操作以得到我们的注意力模式。最后,我们将注意力模式与值(Values)进行矩阵乘法以得到我们的输出。

我们已经看到,矩阵乘法可以通过分块矩阵乘法来加速,并且Softmax可以通过在线Softmax减少到两趟。在第一趟中,我们对Q和K进行矩阵乘法,并确定Softmax的最大值X以及分母(基本上是进行Softmax的第一趟计算)。第二趟则完成我们的Softmax(构建注意力模式A),并进行矩阵乘法以得到输出。

现在,如果我们能够即时计算出A的值,那么我们就可以将这两趟合并。虽然当前A的公式依赖于全局统计信息,但我们可以使用在线Softmax中的相同技巧,将A的每个元素重写为局部统计信息的结果。然后,当我们找到一个新的最大值时,像之前一样缩放这些值。

下面展示了实现这一点的新公式:

为了展示这个算法可能的实现细节,这里有一个简化版本:

import numpy as np
def flash_attention(Q, K, V, k):
    """
       参数:
    Q: 查询矩阵
    K: 键矩阵(在计算中进行转置)
    V: 值矩阵
    k: 查询的行索引
    
    返回:
    处理后的输出向量O[k,:] - 相当于softmax(Q[k,:] @ K) @ V
    """
    N = K.shape[1]  # 从K矩阵获取维度
    
    # 初始化变量
    m_i_minus_1 = float('-inf')  # m_{i-1}的初始值
    d_i_minus_1 = 0.0  # d'_{i-1}的初始值
    o_i_minus_1 = np.zeros_like(V[0, :])  # o'_{i-1}的初始值
    
    for i in range(N):
        # 使用Q的第k行和K的第i列计算x_i
        x_i = np.dot(Q[k, :], K[:, i])
        
        # 更新最大值
        m_i = max(m_i_minus_1, x_i)
        
        # 计算d'_i
        d_i = d_i_minus_1 * np.exp(m_i_minus_1 - m_i) + np.exp(x_i - m_i)
        
        # 计算o'_i
        o_i = (o_i_minus_1 * d_i_minus_1 * np.exp(m_i_minus_1 - m_i) / d_i) + (np.exp(x_i - m_i) / d_i) * V[i, :]
        
        # 更新上一次的值以用于下一次迭代
        m_i_minus_1 = m_i
        d_i_minus_1 = d_i
        o_i_minus_1 = o_i
    
    # 结果是o'_N
    return o_i_minus_1

参考文献:
[1] Ye, Z.,《From Online Softmax to Flash Attention》(2023), 华盛顿大学
[2] Xu, P.,《Tiled Matrix Multiplication》(2019)