My work log: implementing Flash Attention using the CuTe DSL.
1. The Math
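For context, Algorithm 1 below tiles the standard attention computation. I'm writing it here without the $1/\sqrt{d}$ scaling, matching Algorithm 1 in the FlashAttention-2 paper (the scaling can be folded into $Q$). Given $Q, K, V \in \mathbb{R}^{N \times d}$:

$$S = QK^\top \in \mathbb{R}^{N \times N}, \qquad P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}, \qquad O = PV \in \mathbb{R}^{N \times d},$$

with softmax applied row-wise. Numerically, the softmax of a row $x \in \mathbb{R}^N$ is computed in "safe" form:

$$m(x) = \max_k x_k, \qquad f(x) = \big(e^{x_1 - m(x)}, \dots, e^{x_N - m(x)}\big), \qquad \ell(x) = \textstyle\sum_k f(x)_k, \qquad \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}.$$

The decomposition that makes blockwise computation possible: for a concatenation $x = [x^{(1)}\ x^{(2)}]$,

$$m(x) = \max\big(m(x^{(1)}), m(x^{(2)})\big), \qquad \ell(x) = e^{m(x^{(1)}) - m(x)}\,\ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\,\ell(x^{(2)}),$$

so the row max and row sum can be maintained incrementally as blocks of $S$ arrive; these are exactly the running statistics $m_i^{(j)}$ and $\ell_i^{(j)}$ in Algorithm 1. The logsumexp $L = m + \log \ell$ is what the backward pass needs to recompute $P$ without storing it.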
2. The Algorithms
Algorithm 1: FlashAttention-2 Forward Pass

Require: $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, block sizes $B_c$, $B_r$.

1. Set $T_r = \lceil N / B_r \rceil$ and $T_c = \lceil N / B_c \rceil$. Divide $Q$ into $T_r$ blocks $Q_1, \dots, Q_{T_r}$ of size $B_r \times d$ each, and divide $K, V$ into $T_c$ blocks $K_1, \dots, K_{T_c}$ and $V_1, \dots, V_{T_c}$ of size $B_c \times d$ each.
2. Divide the output $O \in \mathbb{R}^{N \times d}$ into $T_r$ blocks $O_1, \dots, O_{T_r}$ of size $B_r \times d$ each, and divide the logsumexp $L$ into $T_r$ blocks $L_1, \dots, L_{T_r}$ of size $B_r$ each.
3. for $1 \le i \le T_r$ do
4.   Load $Q_i$ from HBM to on-chip SRAM.
5.   Initialize $O_i^{(0)} = (0)_{B_r \times d}$, $\ell_i^{(0)} = (0)_{B_r}$, $m_i^{(0)} = (-\infty)_{B_r}$.
6.   for $1 \le j \le T_c$ do
7.     Load $K_j, V_j$ from HBM to on-chip SRAM.
8.     Compute $S_i^{(j)} = Q_i K_j^\top \in \mathbb{R}^{B_r \times B_c}$.
9.     Compute $m_i^{(j)} = \max\big(m_i^{(j-1)}, \mathrm{rowmax}(S_i^{(j)})\big)$, $\tilde{P}_i^{(j)} = \exp\big(S_i^{(j)} - m_i^{(j)}\big)$ (pointwise), $\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}}\,\ell_i^{(j-1)} + \mathrm{rowsum}(\tilde{P}_i^{(j)})$.
10.    Compute $O_i^{(j)} = \mathrm{diag}\big(e^{m_i^{(j-1)} - m_i^{(j)}}\big)\,O_i^{(j-1)} + \tilde{P}_i^{(j)} V_j$.
11.  end for
12.  Compute $O_i = \mathrm{diag}\big(\ell_i^{(T_c)}\big)^{-1} O_i^{(T_c)}$.
13.  Compute $L_i = m_i^{(T_c)} + \log\big(\ell_i^{(T_c)}\big)$.
14.  Write $O_i$ to HBM as the $i$-th block of $O$.
15.  Write $L_i$ to HBM as the $i$-th block of $L$.
16. end for
17. return $O, L$
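Before touching the CuTe DSL kernel itself, a host-side reference to diff against is useful. Below is a minimal NumPy sketch of Algorithm 1 (the function name `flash_attention_fwd_ref` and the default block sizes are my own choices, not from any library): it performs the same two-level blocking and online-softmax updates on the CPU, so a kernel's $O$ and $L$ can be checked against it.

```python
import numpy as np

def flash_attention_fwd_ref(Q, K, V, Br=64, Bc=64):
    """NumPy reference for Algorithm 1 (FlashAttention-2 forward pass).

    Q, K, V: (N, d) arrays. Returns (O, L), where L is the per-row logsumexp.
    Hypothetical testing helper only: no 1/sqrt(d) scaling, no masking.
    """
    N, d = Q.shape
    O = np.zeros((N, d))
    L = np.zeros(N)

    for i0 in range(0, N, Br):                       # outer loop over Q blocks (step 3)
        Qi = Q[i0:i0 + Br]                           # load Q_i (step 4)
        rows = Qi.shape[0]
        Oi = np.zeros((rows, d))                     # O_i^{(0)} = 0      (step 5)
        li = np.zeros(rows)                          # ell_i^{(0)} = 0
        mi = np.full(rows, -np.inf)                  # m_i^{(0)} = -inf

        for j0 in range(0, N, Bc):                   # inner loop over K/V blocks (step 6)
            Kj = K[j0:j0 + Bc]                       # load K_j, V_j (step 7)
            Vj = V[j0:j0 + Bc]
            Sij = Qi @ Kj.T                          # S_i^{(j)} = Q_i K_j^T (step 8)
            mi_new = np.maximum(mi, Sij.max(axis=1)) # running row max (step 9)
            Pij = np.exp(Sij - mi_new[:, None])      # P~_i^{(j)}, pointwise
            scale = np.exp(mi - mi_new)              # correction factor for the old max
            li = scale * li + Pij.sum(axis=1)        # running row sum
            Oi = scale[:, None] * Oi + Pij @ Vj      # rescale old O, add new block (step 10)
            mi = mi_new

        O[i0:i0 + Br] = Oi / li[:, None]             # O_i = diag(ell)^{-1} O_i (step 12)
        L[i0:i0 + Br] = mi + np.log(li)              # L_i = m_i + log(ell_i)   (step 13)

    return O, L
```

As a sanity check, the result should match the direct computation `S = Q @ K.T; P = np.exp(S - S.max(1, keepdims=True)); O = (P / P.sum(1, keepdims=True)) @ V` up to floating-point tolerance.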