FlashAttention-2:利用有效并行化与分片机制实现高效注意力
FlashAttention利用非对称GPU显存层级的特性不仅提高了内存效率,也提高了训练速度。然而,随着上下文长度的增加,它没有优化的GEMM运算一样快,且只达到了理论最大FLOPs/s的25-40%。这种不高效主要是由GPU中不同线程块与线程束之间次优的work分片导致的低显存占有率或不必要的共享内存读写所引起的。为了处理这些问题,FlashAttention-2设计了更好的woker分片。
FlashAttention利用非对称GPU显存层级的特性不仅提高了内存效率,也提高了训练速度。然而,随着上下文长度的增加,它没有优化的GEMM运算一样快,且只达到了理论最大FLOPs/s的25-40%。这种不高效主要是由GPU中不同线程块与线程束之间次优的work分片导致的低显存占有率或不必要的共享内存读写所引起的。为了处理这些问题,FlashAttention-2设计了更好的woker分片。
标准自注意力计算的时间复杂度与空间复杂度与序列长度呈现2次方关系,因此Transformer在长序列上处理速度很慢且需要大量内存。同时,随着硬件的进步,计算能力已经超过了内存的读写能力,即内存的读写限制了注意力的计算。FlashAttention是一个考虑内存读写的精确注意力计算算法,通过分片的方式减少了GPU中HBM与SRAM之间的读写次数,从而提高注意力计算的速度与内存效率。
自回归解码器推理的成本很高,这是因为每个解码步骤加载解码器权重和所有注意力的keys与values的内存带宽很高。多查询注意力MAQ利用多个查询头但只有一个键与值,因此内存带宽的需求大大降低。然而,MQA会导致模型质量退化且训练不稳定。而且,为了优化质量与推理速度,单独训练一个模型不可行。
对于标准随机梯度下降,$L_2$正则化与权重衰退正则化的作用是相同。然而,对于自适应梯度下降算法,例如:Adam,这种等效不存在。确切的说,由于大部分深度学习库中正则化利用的是$L_2$,从而导致部分任务中利用带有动量的SGD进行优化产生的性能优于自适应梯度下降算法优化产生的模型。AdamW梯度下降算法通过对权重衰减与学习率设置进行解耦合,从而提升Adam算法的泛化性。
在深度神经网络中,LayerNorm用于帮助稳定化训练且提升模型的拟合能力。这是因为LayerNorm对输入和权重矩阵具有re-centering与re-scaling不变的特性。然而,随着网络加深,尤其是RNN,因计算量越来越大导致LayerNorm带来性能的成本越来越高。由此,在LayerNorm的re-centering不变属于不必要的假设下,RMSNorm基于均方根进行标准化,且使模型拥有re-scaling不变的特性和学习率的隐式自适应能力。
深度神经网络是由线性变形和激活函数构成。其中,激活函数对深度神经网络的训练成功很重要。激活函数ReLU因其简单性和可靠性,而得到了广泛的采用。虽然许多实践者提出了ReLU的替代版,但是这些激活函数对于不同的模型和数据集往往拥有不一致的表现。由此,Searching for Activation Functions作者们利用自动搜索技术,找到了Swish激活函数,其性能不仅优越于ReLU,且表现一致。