标准自注意力计算的时间复杂度与空间复杂度与序列长度呈现2次方关系,因此Transformer在长序列上处理速度很慢且需要大量内存。同时,随着硬件的进步,计算能力已经超过了内存的读写能力,即内存的读写限制了注意力的计算。FlashAttention是一个考虑内存读写的精确注意力计算算法,通过分片的方式减少了GPU中HBM与SRAM之间的读写次数,从而提高注意力计算的速度与内存效率。如图1左所示,GPU内存层级中SRAM与HBM的读写速度。

图1 FlashAttention与基于Pytorch实现注意力的加速

在进行自注意力计算时,为了减少从HBM读写的次数,需要解决:

  • 在不知道整体输入的情况下计算softmax。
  • 不要存储大的用于反向传播的注意力过渡矩阵。

对于第一个挑战,作者们通过把整个输入分成blocks以重构注意力计算,且创建了输入blocks的通道,最终实现了增量式的softmax缩放计算;对于第二个挑战,作者们存储了前向过程中softmax标准化因子,用于backward pass的注意力过渡矩阵再计算,从而不需要存储注意力过渡矩阵就能够实现梯度的计算,这种方式可视为selective gradient checkpointing。如图1右所示,即使重新计算增加了计算量,但是运行速度与内存效率都得到了很大的提高。FlashAttention的性能:

  • 高效的模型训练
  • 高质量模型
  • Benchmarking Attention

标准注意力计算

给定输入序列$\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{N\times d}$,$N$为序列长度,$d$为注意力头的维度那么注意力计算的输出为
$$
\begin{aligned}
\mathbf{S}=\mathbf{Q}\mathbf{K}^{T}\in\mathbb{R}^{N\times N},\quad\mathbf{P}=softmax(\mathbf{S})\in\mathbb{R}^{N\times N},\quad \mathbf{O}=\mathbf{P}\mathbf{V}\in\mathbb{R}^{N\times d}
\end{aligned}\tag{1}
$$
标准注意力计算的伪代码可见算法0

FlashAttention

给定HBM中输入序列$\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{N\times d}$,把其分为blocks,再加载到SRAM。然后,对这些blocks,进行计算注意力输出,可见算法1所示。

Tiling:按块计算注意力,为了降低注意力计算过程中出现$NaN$情况的风险,利用缩放技术分解大的softmax。那么,对于向量$x\in\mathbb{R}^B$,其softmax计算为
$$
\begin{aligned}
m(x):=\underset{i}{max}\quad x_i,\quad f(x):=[e^{x_1-m(x)}\ldots e^{x_{B}-m(x)}] \\
l(x):=\sum_{i}f(x)_{i},\quad softmax(x):=\frac{f(x)}{l(x)}
\end{aligned}\tag{2}
$$
其中,缩放技术就是每个元素都与该行最大元素做差。

对于向量$x^{(1)},x^{(2)}\in\mathbb{R}^B$,两个向量的concat的softmax计算为
$$
\begin{aligned}
m(x)=m([x^{(1)}\quad x^{(2)}])=max(m(x^{(1)}), m(x^{(2)})),\\
f(x)=[e^{m(x^{(1)})-m(x)}f(x^{(1)})\quad e^{m(x^{(2)})-m(x)}f(x^{(2)})] \\
l(x)=l([x^{(1)}\quad x^{(2)}])=e^{m(x^{(1)})-m(x)}l(x^{(1)})+e^{m(x^{(2)})-m(x)}l(x^{(2)}),\\ softmax(x)=\frac{f(x)}{l(x)}
\end{aligned}\tag{3}
$$
由此,每次只需要计算一个block的softmax,不断的进行$l$与$m$的统计,避免了HBM与SRAM之间重复的数据传递。

综上所述,FlashAttention只需要进行一次从HBM读取数据到SRAM,然后进行所有计算,最后把计算结果写入HBM,从而大大提高了内存效率。

引用方法

请参考:

            
                li,wanye. "FlashAttention:快速且高效的精确注意力计算". wyli'Blog (Apr 2024). https://www.robotech.ink/index.php/archives/439.html            
        

或BibTex方式引用:

            
                @online{eaiStar-439,
   title={FlashAttention:快速且高效的精确注意力计算},
   author={li,wanye},
   year={2024},
   month={Apr},
   url="https://www.robotech.ink/index.php/archives/439.html"
}

标签: Attentions

添加新评论