In this follow-up post to Nathan Chen's Triton Flash Attention Kernel Walkthrough: The Forward Pass, we dive into gradient computation for queries, keys, and values in the backward pass.
So you’ve seen the forward pass of FlashAttention and know how to compute the attention output without ever materializing the full $N \times N$ score matrix, thanks to tiling and online softmax.
But in deep learning, computing the output is only half the battle. To train the model, we need gradients. We need the backward pass.
Implementing a backward pass for FlashAttention is notoriously more difficult than the forward pass. Why? Because we didn’t save the massive attention score matrix. Instead, we recompute attention scores on the fly using the saved LogSumExp (lse) from the forward pass.
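Concretely, for every query row $i$ the forward pass saves $\mathrm{lse}_i = \log \sum_j e^{S_{ij}}$, so any attention probability can be rebuilt from a freshly recomputed score:

$P_{ij} = \dfrac{e^{S_{ij}}}{\sum_k e^{S_{ik}}} = e^{S_{ij} - \mathrm{lse}_i}$

(The kernels below do the same thing in base 2, using exp2 and the constant RCP_LN2 $\approx 1/\ln 2$ to convert.)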
In this walkthrough, we will dissect the backward pass of FLA’s excellent implementation in three phases.
Before we jump into the code, let’s explicitly define what we are working with. The kernel is designed to handle Grouped Query Attention (GQA), which makes the shapes slightly more involved.
We operate on the following tensors:
- q (Queries): Shape [B, T, HQ, K]. B: batch size. T: sequence length. HQ: number of Query heads. K: head dimension for keys/queries.
- k (Keys): Shape [B, T, H, K]. H: number of Key/Value heads. Note that HQ must be a multiple of H.
- v (Values): Shape [B, T, H, V]. V: head dimension for values.
- do (Output’s gradient): Shape [B, T, HQ, V].

Our goal is to produce dq, dk, and dv, matching the shapes of q, k, and v respectively.
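To make the shapes concrete, here is a small sketch with hypothetical sizes (the numbers are illustrative, not requirements of the kernel):

import torch

# Hypothetical sizes for illustration only
B, T, HQ, H, K, V = 2, 1024, 32, 8, 128, 128
assert HQ % H == 0   # GQA: query heads must be a multiple of key/value heads
G = HQ // H          # each key/value head is shared by G query heads

q  = torch.randn(B, T, HQ, K)   # queries
k  = torch.randn(B, T, H,  K)   # keys (fewer heads than q under GQA)
v  = torch.randn(B, T, H,  V)   # values
do = torch.randn(B, T, HQ, V)   # gradient of the output, same shape as o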
Inside the ParallelAttentionFunction, the backward method is called automatically by PyTorch’s autograd engine when we call loss.backward().
@torch.compile
class ParallelAttentionFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, cu_seqlens):
        ctx.dtype = q.dtype
        RCP_LN2: float = 1.4426950216
        g_cumsum = chunk_global_cumsum(g, cu_seqlens=cu_seqlens, scale=RCP_LN2) if g is not None else None
        o, lse = parallel_attn_fwd(
            q=q,
            k=k,
            v=v,
            g_cumsum=g_cumsum,
            scale=scale,
            cu_seqlens=cu_seqlens,
        )
        ctx.save_for_backward(q, k, v, o, g_cumsum, lse)
        ctx.cu_seqlens = cu_seqlens
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, o, g_cumsum, lse = ctx.saved_tensors
        dq, dk, dv, dg = parallel_attn_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            g_cumsum=g_cumsum,
            lse=lse,
            do=do,
            scale=ctx.scale,
            cu_seqlens=ctx.cu_seqlens,
        )
        if dg is not None:
            dg = chunk_global_cumsum(dg, cu_seqlens=ctx.cu_seqlens, reverse=True)
        return dq.to(q), dk.to(k), dv.to(v), dg, None, None
The context object ctx gives us back the tensors we saved during the forward pass. In particular, we have o (the forward output) and lse (the Log-Sum-Exp statistics). We pass these, along with the incoming gradient do, into our launcher parallel_attn_bwd.
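To see the autograd wiring end to end, here is a minimal usage sketch, reusing the example tensors from the shape sketch above. The assumptions are mine: a CUDA device, and that g=None and cu_seqlens=None are valid for the dense, no-decay case; scale is the usual $1/\sqrt{K}$.

# Assumes a CUDA device and that the example tensors above have been moved to it (e.g. q = q.cuda())
q, k, v = (t.requires_grad_() for t in (q, k, v))
scale = q.shape[-1] ** -0.5
o = ParallelAttentionFunction.apply(q, k, v, None, scale, None)  # g=None, cu_seqlens=None (assumed OK)
o.float().sum().backward()      # this call is what triggers ParallelAttentionFunction.backward
dq, dk, dv = q.grad, k.grad, v.grad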
The launcher function parallel_attn_bwd handles the dispatch. Unlike the forward pass, which is usually one big kernel, the backward pass here is split into three distinct operations:
1. Preprocess: compute the term delta.
2. Compute the gradients of the Queries (dq) and the log-decay g (dg).
3. Compute the gradients of the Keys (dk) and Values (dv).

def parallel_attn_bwd(
    # arguments omitted here
):
    # extracting shapes
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H  # group size in GQA

    # determine block sizes based on GPU architecture (Hopper vs Ampere)
    if check_shared_mem('hopper'):
        BT = 128
        BS = 64
    # ...

    # preprocess delta
    delta = parallel_attn_bwd_preprocess(o, do)

    # prepare output tensors
    # dq: [B, T, HQ, K]
    # dk: [B, T, HQ, K]
    # dv: [B, T, HQ, V]
    # since we replicate G copies of keys/values in GQA,
    # the shapes of dk and dv might not be the same as k and v, respectively
    dq = torch.empty(..., device=q.device)
    dk = torch.empty(..., device=q.device)
    dv = torch.empty(..., device=q.device)

    # launch separate kernels for dQ and dKV
    grid = (NV, NT, B * HQ)
    parallel_attn_bwd_kernel_dq[grid](
        q=q, k=k, v=v, lse=lse, delta=delta, do=do, dq=dq, ...
    )
    parallel_attn_bwd_kernel_dkv[grid](
        q=q, k=k, v=v, lse=lse, delta=delta, do=do, dk=dk, dv=dv, ...
    )

    # Handle GQA reduction
    # We computed gradients for replicated keys/values. Now we sum them up.
    dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
    dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
    return dq, dk, dv, dg_cumsum
The design choice to split dq and dkv calculation into two kernels (unlike some implementations that fuse them) simplifies the logic regarding causal masking loops and reduces the amount of SRAM needed per thread block.
Before calculating gradients, we need a term called delta. The Softmax gradient contains a row-wise summation term $\delta_i = \sum_j O_{ij} \, dO_{ij}$ (the dot product of each output row with its incoming gradient), which we pre-compute once to avoid recomputing it constantly in the main loops below. (Note that recomputing this term on the fly would be more expensive than recomputing the attention scores.)
@triton.jit
def parallel_attn_bwd_kernel_preprocess(
    o, do, delta,
    B: tl.constexpr, V: tl.constexpr,
):
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)  # B is block size for V dimension
    m_d = o_d < V
    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)

    # compute a tile of delta
    # delta = sum(output * grad_output)
    b_delta = tl.sum(b_o * b_do)
    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
This kernel computes the dot product of o and do for every token and stores it in delta. This term is critical for correcting the gradient of the Softmax function.
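For intuition, the entire preprocess step is equivalent to a one-line reduction in PyTorch (a reference sketch, not the actual launcher code):

# Reference: one delta value per (batch, token, query head);
# each program of the Triton kernel above computes one entry of this tensor
delta_ref = (o.float() * do.float()).sum(dim=-1)   # [B, T, HQ]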
Now comes the heavy lifting. parallel_attn_bwd_kernel_dq computes gradients with respect to q.
The kernel loads a block of queries q (BT x BK), the incoming gradient do (BT x BV), the row-wise LogSumExp lse, and our precomputed delta into SRAM. These stay in registers/SRAM for the duration of the inner loop because we are computing dq for this specific block of BT timesteps.
# [BT, BK] - Query Block
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BT, BV] - Gradient of Output Block
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT] - LogSumExp and Delta
b_lse = tl.load(p_lse, boundary_check=(0,))
b_delta = tl.load(p_delta, boundary_check=(0,))
# Initialize accumulator for dq
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
We iterate over blocks of Keys (k) and Values (v). This looks very similar to the forward pass, but we are doing it to recover the attention probability matrix $P$.
for i_s in range(0, i_t * BT, BS):
    # Omitted: Load k [BK, BS] and v [BV, BS]

    # Recompute Attention Scores [BT, BS]
    b_s = tl.dot(b_q, b_k) * scale * RCP_LN2

    # Recompute Probability P = exp(Score - LSE)
    # We mask out future tokens if causal
    b_s = tl.where((o_q[:, None] >= o_k[None, :]) & m_k[None, :], b_s, float('-inf'))
    b_p = exp2(b_s - b_lse[:, None])

    # Compute Gradient of Attention Scores (dS)
    # [BT, BV] @ [BV, BS] -> [BT, BS]
    b_dp = tl.dot(b_do, b_v)
    # dS = P * (dP - delta)
    b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])

    # Accumulate Gradient for Query (dQ)
    # [BT, BS] @ [BS, BK] -> [BT, BK]
    b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
The code concisely implements the backward pass of Softmax-Attention.
- b_p: Using q, k, and lse, we recompute the attention probabilities on the fly.
- b_ds: This represents a tile of $\frac{\partial L}{\partial S}$ (the gradient of the attention scores). The formula P * (dot(do, v) - delta) is the standard efficient softmax gradient implementation.¹
- b_dq: Since $S = QK^T$, the gradient flow tells us $dQ = dS \cdot K$. In Triton, this is tl.dot(b_ds, tl.trans(b_k)) (see the reference sketch just below).
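To check the intuition, here is a naive, self-contained PyTorch reference for dq on a single (batch, query-head) slice. It works in natural-log space, whereas the kernel uses base 2 via exp2/RCP_LN2, and it materializes the full [T, T] matrices that the kernel deliberately avoids; it is a verification sketch, not the real code path.

import torch

T, K, V = 128, 64, 64            # hypothetical slice sizes
scale = K ** -0.5
q  = torch.randn(T, K)
k  = torch.randn(T, K)
v  = torch.randn(T, V)
do = torch.randn(T, V)

s = (q @ k.t()) * scale                                   # attention scores
s = s.masked_fill(~torch.tril(torch.ones(T, T, dtype=torch.bool)), float('-inf'))
lse = torch.logsumexp(s, dim=-1)                          # what the forward pass saves
p = torch.exp(s - lse[:, None])                           # recomputed probabilities (b_p)
o = p @ v                                                 # forward output, needed only for delta
delta = (o * do).sum(dim=-1)                              # the preprocess kernel's output

dp = do @ v.t()                                           # dL/dP
ds = p * (dp - delta[:, None])                            # dL/dS, the "delta trick" (b_ds)
dq_ref = (ds @ k) * scale                                 # dL/dQ (b_dq); scale comes from S = scale * Q K^T

On this tiny slice, dq_ref should match torch.autograd.grad of the same attention computation up to floating-point error.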
The second kernel, parallel_attn_bwd_kernel_dkv, is the “transpose” of the first. In dq, we fixed a block of Queries (rows) and iterated over Key/Value columns. Here, we fix a block of Keys/Values (columns) and iterate over the Queries (rows) that attended to them.
# Load the Key and Value blocks we want gradients for
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# Initialize accumulators
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
# Iterate over Query blocks (i_s loop)
# Note: We only loop over queries that come AFTER the current key (causal masking)
# If I am a Key at time 10, only Queries at time >= 10 cared about me.
for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
    # Omitted: Load q, do, lse, delta ...

    # Recompute Score (Transposed)
    # Note: We compute dot(K, Q^T) effectively
    b_s = tl.dot(b_k, tl.trans(b_q)) * scale * RCP_LN2

    # Recompute P
    b_p = tl.where((o_k[:, None] <= o_q[None, :]) & m_q[None, :], exp2(b_s - b_lse[None, :]), 0)

    # Compute dV
    # dV += P^T * dO
    # [BT, BS] @ [BS, BV] -> [BT, BV]
    b_dv += tl.dot(b_p.to(b_do.dtype), b_do)

    # Compute dS
    # We need dP = dO * V^T.
    # In code: b_dp = dot(b_v, trans(b_do))
    b_dp = tl.dot(b_v, tl.trans(b_do))
    b_ds = b_p * (b_dp - b_delta[None, :])

    # Compute dK
    # dK += dS^T * Q
    # [BT, BS] @ [BS, BK] -> [BT, BK]
    b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
The logic mirrors the dq kernel, but the matrices are transposed. For example, to get dv, we are essentially performing $P^T \cdot dO$.
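Continuing the single-slice reference sketch from the dq section (reusing p, ds, q, do, and scale defined there), the key/value gradients are just the transposed products:

dv_ref = p.t() @ do                 # dL/dV = P^T dO  (what b_dv accumulates tile by tile)
dk_ref = (ds.t() @ q) * scale       # dL/dK = dS^T Q, with the same scale factor as dq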
By keeping k and v in SRAM and streaming q, do, and delta from HBM, we efficiently compute the gradients for the Key and Value matrices. Finally, the launcher script will take these results and reduce them if we are using GQA, summing the gradients across the replicated heads.
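As for that final GQA reduction, the einops reduce call in the launcher is just a reshape-and-sum over the group dimension; a minimal PyTorch equivalent for intuition (sizes are hypothetical):

import torch
from einops import reduce

B, T, H, G, K = 2, 16, 4, 8, 32                 # hypothetical sizes; HQ = H * G
dk_rep = torch.randn(B, T, H * G, K)            # per-query-head gradients, keys replicated G times

dk_einops = reduce(dk_rep, 'b t (h g) k -> b t h k', g=G, reduction='sum')
dk_manual = dk_rep.view(B, T, H, G, K).sum(dim=3)
assert torch.allclose(dk_einops, dk_manual)     # same thing: sum the G replicas of each KV head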
This code walkthrough is inspired by Nathan Chen’s amazing Triton Flash Attention Kernel Walkthrough: The Forward Pass. Big shoutout to Nathan!
¹ The Delta Term derivation: If we define attention output $O_i = \sum_j P_{ij} V_j$, the gradient of the loss function $L$ with respect to the pre-softmax scores $S_{ij}$ involves the Jacobian of the Softmax. Specifically, $\frac{\partial L}{\partial S_{ij}} = P_{ij} (\frac{\partial L}{\partial P_{ij}} - \sum_k P_{ik} \frac{\partial L}{\partial P_{ik}})$.
Since $\frac{\partial L}{\partial P_{ij}} = dO_i \cdot V_j$, we can substitute: $\frac{\partial L}{\partial S_{ij}} = P_{ij} ( (dO_i \cdot V_j) - \sum_k P_{ik} (dO_i \cdot V_k) )$.
The term $\sum_k P_{ik} (dO_i \cdot V_k)$ can be rewritten by factoring out $dO_i$: $dO_i \cdot (\sum_k P_{ik} V_k) = dO_i \cdot O_i$.
This term $dO_i \cdot O_i$ is exactly what we calculate in the preprocess kernel and store as delta. So, $dS_{ij} = P_{ij} ( (dO_i \cdot V_j) - \text{delta}_i )$. This matches the code: b_ds = b_p * (b_dp - b_delta).
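If you prefer to convince yourself numerically, a throwaway autograd check of this identity (my own sketch, not part of the FLA code) could look like:

import torch

T, V = 5, 3
s = torch.randn(T, T, requires_grad=True)
v = torch.randn(T, V)
do = torch.randn(T, V)                          # stand-in for dL/dO

p = torch.softmax(s, dim=-1)
o = p @ v
(o * do).sum().backward()                       # autograd's dL/dS for a loss with dL/dO = do

delta = (o * do).sum(dim=-1, keepdim=True)      # the preprocess term, one value per row
ds_manual = p * (do @ v.t() - delta)            # dS = P * (dP - delta)
assert torch.allclose(s.grad, ds_manual, atol=1e-6)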