Chapter 20. Long Context, From 4K to 10M Tokens
In 2020, GPT-3 could process 2,048 tokens at a time, roughly 1,500 words. By March 2026, GPT-5.4 handles 1.05 million tokens, LLaMA 4 Scout claims 10 million, and Grok 4 Fast processes 2 million. Going from 2,048 to 10 million tokens is a nearly 5,000x increase in just six years. But making long context work is not simply a matter of increasing a number in a configuration file. The core attention mechanism in transformers scales quadratically with sequence length: doubling the context window quadruples the attention computation. This chapter explains the techniques that make long-context inference practical, from Flash Attention’s memory-efficient computation to Ring Attention’s multi-GPU distribution, and examines why longer context windows do not always mean better performance.
The Quadratic Attention Problem
In Chapter 7, you learned that attention computes a score between every pair of tokens in the sequence. For a sequence of n tokens, this means computing an n x n attention score matrix. Each element of this matrix requires a dot product between a query vector and a key vector. The total number of dot products is n^2.
This quadratic scaling has two consequences:
Computation scales as O(n^2). Doubling the sequence length from 4,096 to 8,192 tokens quadruples the number of attention score computations, from roughly 16.8 million to 67.1 million.
Memory scales as O(n^2). The standard attention implementation materializes the full n x n score matrix in GPU memory. For a single attention head processing a 128,000-token sequence, this matrix has 128,000 x 128,000 = 16.4 billion entries. At float32 (4 bytes each), that is 65.5 GB for a single head in a single layer.
In practice, models have many heads and many layers. LLaMA 3.1 405B has 128 query heads in each of its 126 layers. Each head computes its own n x n score matrix. The memory needed for these intermediate matrices therefore multiplies with the number of heads. During training, activations are kept for the backward pass, so it also multiplies with the number of layers.
def attention_memory_cost(seq_len, num_heads, bytes_per_element=4):
    """
    Calculate the memory needed to store the full attention score matrix.
    This is the intermediate matrix that standard attention materializes.
    Returns memory per layer (all heads).
    """
    # Each head computes its own (seq_len x seq_len) score matrix
    elements_per_head = seq_len * seq_len
    bytes_per_head = elements_per_head * bytes_per_element
    bytes_per_layer = bytes_per_head * num_heads
    return bytes_per_layer
def format_bytes(b):
    # Binary units: 1 KB = 1024 bytes, 1 TB = 1024**4 bytes
    if b >= 1024**4:
        return f"{b / 1024**4:.1f} TB"
    elif b >= 1024**3:
        return f"{b / 1024**3:.1f} GB"
    elif b >= 1024**2:
        return f"{b / 1024**2:.1f} MB"
    else:
        return f"{b / 1024:.1f} KB"
print("Attention Score Matrix Memory (per layer, float32)")
print("=" * 65)
print(f"{'Seq Length':>12} {'Heads':>6} {'Per Head':>12} {'Per Layer':>12}")
print("-" * 65)
# Head counts are illustrative except for LLaMA 3.1 405B (128 query heads)
configs = [
    (2_048, 32, "GPT-3 era"),
    (8_192, 32, "GPT-4 launch"),
    (32_000, 32, "Early long-context"),
    (128_000, 128, "LLaMA 3.1 405B"),
    (1_000_000, 128, "1M context"),
]
for seq_len, heads, label in configs:
    per_head = attention_memory_cost(seq_len, num_heads=1)
    per_layer = attention_memory_cost(seq_len, heads)
    print(f"{seq_len:>12,} {heads:>6} {format_bytes(per_head):>12} "
          f"{format_bytes(per_layer):>12} ({label})")
At 2,048 tokens with 32 heads, the attention score matrix for one layer fits in about 512 MB. At 128,000 tokens with 128 heads, it requires over 7.6 TB per layer. At 1 million tokens, the numbers become absurd: over 465 TB per layer. No GPU comes anywhere close to this much memory, so the standard attention algorithm simply cannot run at these sequence lengths.
This is the fundamental problem that every technique in this chapter addresses: how to compute exact (or near-exact) attention without materializing the full n x n score matrix.
Flash Attention: The Breakthrough That Changed Everything
Flash Attention, introduced by Tri Dao et al. in May 2022 (arXiv:2205.14135, NeurIPS 2022), solved the memory problem by rethinking how attention is computed at the hardware level. The key insight is that the bottleneck in attention is not arithmetic (GPUs have plenty of compute power) but memory access: reading and writing the large intermediate matrices between GPU high-bandwidth memory (HBM) and the GPU’s compute units.
The GPU Memory Hierarchy
To understand Flash Attention, you need to know that GPUs have two levels of memory:
HBM (High Bandwidth Memory): The main GPU memory. An H100 has 80 GB of HBM with a bandwidth of about 3.35 TB/s. This is where model weights, the KV cache, and intermediate results are stored.
SRAM (on-chip memory): A much smaller but much faster memory located directly on the compute chip. Each of the H100's 132 streaming multiprocessors has 256 KB of combined L1 cache and shared memory, about 33 MB in total. Its aggregate bandwidth is roughly 33 TB/s, about 10x faster than HBM.
Standard attention computes the full n x n score matrix in HBM, then reads it back to compute softmax, then writes the softmax result back to HBM, then reads it again to multiply by the value matrix. Each of these read/write operations is slow because HBM bandwidth is the bottleneck.
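To put a number on "slow", the sketch below estimates how long standard attention spends just moving score matrices through HBM. It assumes:

- float32 scores;
- exactly the four passes described above: write S, read it for softmax, write P, and read it for the multiplication by V;
- the H100's roughly 3.35 TB/s HBM bandwidth, fully utilized.

It ignores all other traffic, so the result is a lower bound. The function name is hypothetical.

```python
def score_matrix_hbm_seconds(seq_len, num_heads=1, bytes_per_element=4,
                             passes=4, hbm_bytes_per_s=3.35e12):
    """Lower bound on the time standard attention spends moving its n x n
    score matrices through HBM (hypothetical helper; ignores Q, K, V, O).
    passes=4: write S, read S for softmax, write P, read P for P @ V."""
    matrix_bytes = seq_len * seq_len * bytes_per_element
    return passes * matrix_bytes * num_heads / hbm_bytes_per_s

for n in (4_096, 32_768, 128_000):
    ms = score_matrix_hbm_seconds(n) * 1e3
    print(f"{n:>8,} tokens: {ms:8.2f} ms per head per layer")
```

At 128,000 tokens this comes to roughly 78 ms per head per layer. Across 128 heads that is about 10 seconds per layer spent purely on memory traffic.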
Flash Attention eliminates these intermediate reads and writes by computing attention in tiles (small blocks) that fit entirely in SRAM. Instead of materializing the full n x n matrix, it processes the attention computation block by block, keeping intermediate results in fast SRAM and only writing the final output to HBM.
How Tiling Works
The algorithm divides the query, key, and value matrices into blocks. For each block of queries, it iterates over all blocks of keys and values, computing partial attention scores and accumulating the result using an online softmax algorithm. The online softmax trick allows computing the correct softmax normalization incrementally, without needing to see all scores at once.
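The online softmax trick is easiest to see on a single row of scores. In this minimal sketch (the function name and chunk size are illustrative), a running maximum m and a running sum l are updated one chunk at a time. Whenever the maximum grows, the old sum is multiplied by exp(m_old - m_new), which keeps it expressed relative to the current maximum. The final l equals the sum computed from all scores at once.

```python
import numpy as np

def online_softmax_stats(scores, chunk_size=3):
    """Return (m, l) with m = max(scores) and l = sum(exp(scores - m)),
    computed in one pass over chunks, keeping only m and l between chunks."""
    m, l = -np.inf, 0.0
    for start in range(0, len(scores), chunk_size):
        chunk = scores[start:start + chunk_size]
        m_new = max(m, float(chunk.max()))
        # Rescale the previous sum to the new max, then add this chunk's terms
        l = l * np.exp(m - m_new) + float(np.exp(chunk - m_new).sum())
        m = m_new
    return m, l

scores = np.array([1.0, 3.0, -2.0, 0.5, 4.0, 2.0, -1.0])
m, l = online_softmax_stats(scores)
reference = np.exp(scores - scores.max()).sum()  # all scores at once
print(m, np.isclose(l, reference))
```

With (m, l) in hand, each softmax probability is exp(score - m) / l, exactly as if the whole row had been seen at once.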
import numpy as np
def standard_attention(Q, K, V):
    """
    Standard attention: materializes the full N x N score matrix.
    Memory: O(N^2) for the score matrix.
    """
    d = Q.shape[1]
    # Step 1: Compute full score matrix (N x N), stored in HBM
    S = Q @ K.T / np.sqrt(d)  # O(N^2) memory
    # Step 2: Softmax over each row (subtract the row max for stability)
    S_max = np.max(S, axis=-1, keepdims=True)
    P = np.exp(S - S_max)
    P = P / np.sum(P, axis=-1, keepdims=True)  # O(N^2) memory
    # Step 3: Multiply by V
    O = P @ V  # O(N x d) memory
    return O
def flash_attention(Q, K, V, block_size=64):
    """
    Flash Attention (simplified): computes exact attention
    WITHOUT materializing the full N x N score matrix.
    Memory: O(N x block_size), i.e. O(N) for a fixed block size, instead of O(N^2).
    Uses online softmax to accumulate results block by block.
    Simplification: all N queries are processed together; real kernels
    also split the queries into blocks so that each tile fits in SRAM.
    """
    N, d = Q.shape
    O = np.zeros_like(Q)  # Output accumulator
    l = np.zeros((N, 1))  # Softmax denominator (running sum of exp)
    m = np.full((N, 1), -np.inf)  # Running max for numerical stability
    # Process keys/values in blocks
    num_kv_blocks = (N + block_size - 1) // block_size
    for j in range(num_kv_blocks):
        # Load one block of K and V into SRAM
        kv_start = j * block_size
        kv_end = min(kv_start + block_size, N)
        K_block = K[kv_start:kv_end]  # (block_size, d)
        V_block = V[kv_start:kv_end]  # (block_size, d)
        # Compute scores for ALL queries against this K block
        # This is a (N x block_size) matrix, NOT (N x N)
        S_block = Q @ K_block.T / np.sqrt(d)  # (N, block_size)
        # Online softmax update
        m_new = np.maximum(m, np.max(S_block, axis=-1, keepdims=True))
        # Rescale previous accumulator
        exp_diff = np.exp(m - m_new)
        l = l * exp_diff
        O = O * exp_diff
        # Add contribution from this block
        P_block = np.exp(S_block - m_new)  # (N, block_size)
        l = l + np.sum(P_block, axis=-1, keepdims=True)
        O = O + P_block @ V_block  # (N, d)
        m = m_new
    # Final normalization
    O = O / l
    return O
# Verify both produce the same result
np.random.seed(42)
N, d = 256, 64
block_size = 64
Q = np.random.randn(N, d).astype(np.float32)
K = np.random.randn(N, d).astype(np.float32)
V = np.random.randn(N, d).astype(np.float32)
out_standard = standard_attention(Q, K, V)
out_flash = flash_attention(Q, K, V, block_size=block_size)
print(f"Max difference: {np.max(np.abs(out_standard - out_flash)):.2e}")
print(f"Results match: {np.allclose(out_standard, out_flash, atol=1e-5)}")
print()
print(f"Standard attention peak memory: {N}x{N} = {N*N:,} elements ({N*N*4/1024:.0f} KB)")
print(f"Flash attention peak memory: {N}x{block_size} = {N*block_size:,} elements ({N*block_size*4/1024:.0f} KB)")
print(f"Memory reduction: {N // block_size}x")
The critical point: Flash Attention computes the exact same result as standard attention. It is not an approximation. The output is mathematically identical and differs only by floating-point rounding. The only difference is how the computation is organized in memory.
The Memory Savings
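The memory reduction from Flash Attention is dramatic. Standard attention requires O(n^2) memory for the score matrix. Flash Attention requires only O(n) memory (proportional to the sequence length, not its square), because it never materializes the full score matrix. The simplified code above holds one (n x block_size) block of scores at a time. Production kernels also split the queries into blocks, so the largest intermediate is a (block_size x block_size) tile that fits in SRAM. The only per-token state written to HBM is the output and the softmax statistics.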
For concrete numbers:
| Sequence Length | Standard Attention (per head) | Simplified Flash Attention (per head) | Reduction |
|---|---|---|---|
| 2,048 | 16 MB | 512 KB | 32x |
| 8,192 | 256 MB | 2 MB | 128x |
| 32,768 | 4 GB | 8 MB | 512x |
| 131,072 | 64 GB | 32 MB | 2,048x |
| 1,048,576 | 4 TB | 256 MB | 16,384x |
These numbers assume float32 and a block size of 64. The Flash Attention column counts the (n x block_size) block of scores held by the simplified algorithm above; production kernels, which also tile the queries, need less. The reduction factor equals n / block_size, which grows linearly with sequence length. This is why Flash Attention is essential for long-context models: without it, a 1 million token context window would require terabytes of memory just for the attention score matrices.
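The per-head comparison can be recomputed in a few lines. This sketch assumes float32 scores. For the tiled version it counts one (n x block_size) block of scores, the working set of the simplified algorithm earlier in this section. The function name is hypothetical.

```python
def attention_score_memory(seq_len, block_size=64, bytes_per_element=4):
    """Per-head score memory in bytes: full n x n matrix vs. one
    n x block_size block (simplified tiled algorithm), plus their ratio."""
    standard = seq_len * seq_len * bytes_per_element
    tiled = seq_len * block_size * bytes_per_element
    return standard, tiled, standard // tiled  # ratio = n / block_size

for n in (2_048, 8_192, 32_768, 131_072, 1_048_576):
    std, tiled, ratio = attention_score_memory(n)
    print(f"{n:>10,}  {std / 2**20:>12,.0f} MiB  "
          f"{tiled / 2**20:>8,.2f} MiB  {ratio:>7,}x")
```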
The Evolution: Flash Attention 1 Through 4
Flash Attention has gone through four major versions, each optimized for the GPU hardware of its era:
FlashAttention (v1) (Dao et al., arXiv:2205.14135, NeurIPS 2022): The original paper. Introduced IO-aware tiling for attention on NVIDIA A100 GPUs. Achieved 2 to 4x speedup over standard attention with no approximation. Reduced memory from O(n^2) to O(n). This paper is one of the most impactful systems contributions to the LLM era.
FlashAttention-2 (Dao, arXiv:2307.08691, ICLR 2024): Improved parallelism and work partitioning. Reduced non-matmul FLOPs (the overhead operations that are not matrix multiplications). Achieved up to 225 TFLOPs/s per A100 GPU (72% model FLOPs utilization), roughly 2x faster than FlashAttention-1. Enabled training with 2x longer sequences.
FlashAttention-3 (Shah et al., arXiv:2407.08608, NeurIPS 2024): Optimized for NVIDIA H100 (Hopper) GPUs. Exploited the asynchronous execution of the Tensor Cores and the Tensor Memory Accelerator (TMA) to overlap computation with data movement. Introduced FP8 support for attention computation. Achieved up to 840 TFLOPs/s with BF16 (85% utilization) and 1.3 PFLOPs/s with FP8 on H100. FlashAttention-2 reaches only 35% utilization on H100, and FlashAttention-3 runs 1.5 to 2x faster.
FlashAttention-4 (Dao et al., arXiv:2603.05451, March 2026): Optimized for NVIDIA B200 (Blackwell) GPUs. The Blackwell architecture has fundamentally different performance characteristics: tensor core throughput doubled compared to Hopper, but other functional units (shared memory bandwidth, exponential units) scaled more slowly. FlashAttention-4 addresses this asymmetry with redesigned pipelines, software-emulated exponential functions, and a new 2-CTA MMA mode. Achieves up to 1,613 TFLOPs/s with BF16 on B200 (71% utilization), up to 1.3x faster than cuDNN 9.13 and 2.7x faster than Triton. Notably, FlashAttention-4 is implemented entirely in CuTe-DSL, a Python-embedded domain-specific language for GPU kernels, achieving 20 to 30x faster compile times compared to the traditional C++ template-based approach used in earlier versions.
def flash_attention_timeline():
    """
    Performance progression of Flash Attention across GPU generations.
    """
    # Caveat: the FlashAttention-1 row is approximate. FlashAttention-2's
    # 225 TFLOPs/s is end-to-end GPT training throughput, while the
    # FlashAttention-3/4 rows are attention-kernel throughput, so the
    # column shows a trend rather than a like-for-like comparison.
    versions = [
        ("FlashAttention-1", "A100", "NeurIPS 2022", "~150", "~48%", "2205.14135"),
        ("FlashAttention-2", "A100", "ICLR 2024", "225", "72%", "2307.08691"),
        ("FlashAttention-3", "H100", "NeurIPS 2024", "840", "85%", "2407.08608"),
        ("FlashAttention-4", "B200", "March 2026", "1,613", "71%", "2603.05451"),
    ]
    print(f"{'Version':<20} {'GPU':<6} {'Venue':<14} {'TFLOPs/s':>9} "
          f"{'Util.':>6} {'arXiv':>12}")
    print("-" * 72)
    for name, gpu, venue, tflops, util, arxiv in versions:
        print(f"{name:<20} {gpu:<6} {venue:<14} {tflops:>9} "
              f"{util:>6} {arxiv:>12}")
    print()
    print("Each version targets the GPU architecture of its era.")
    print("Performance roughly doubles with each GPU generation,")
    print("but only if the software is redesigned to match the hardware.")
flash_attention_timeline()
The progression from about 150 TFLOPs/s on A100 to 1,613 TFLOPs/s on B200 represents a roughly 10x improvement in four years. But each step required a new design of the algorithm, tailored to the specific memory hierarchy, instruction set, and execution model of the target GPU. Flash Attention is not a single algorithm; it is a family of algorithms that co-evolve with hardware.
Source: Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,” arXiv:2205.14135, May 2022. NeurIPS 2022. Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,” arXiv:2307.08691, July 2023. ICLR 2024 (iclr.cc/virtual/2024/poster/17889, openreview.net/forum?id=mZn2Xyh9Ec). 225 TFLOPs/s per A100, 72% utilization. Shah et al., “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision,” arXiv:2407.08608, July 2024. NeurIPS 2024 (neurips.cc/virtual/2024/poster/93328, openreview.net/forum?id=tVConYid20). 840 TFLOPs/s BF16 (85% utilization), 1.3 PFLOPs/s FP8 on H100. Note: the original arXiv preprint (v1) reported 740 TFLOPs/s FP16 (75% utilization) and ~1.2 PFLOPs/s FP8; the NeurIPS camera-ready version updated these to 840 TFLOPs/s BF16 (85% utilization) and 1.3 PFLOPs/s FP8. Dao et al., “FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling,” arXiv:2603.05451, March 2026. Up to 1,613 TFLOPs/s BF16 on B200 (71% utilization), 1.3x over cuDNN 9.13, 2.7x over Triton. Implemented in CuTe-DSL (Python-embedded GPU kernel DSL) with 20-30x faster compile times vs C++ templates (arxiv.org/abs/2603.05451, tridao.me/blog/2026/flash4, together.ai/blog/flashattention-4, blog.ai.princeton.edu). Note: the arXiv abstract reports 1,613 TFLOPs/s while Tri Dao’s blog post (tridao.me/blog/2026/flash4) reports 1,605 TFLOPs/s; this chapter uses the arXiv figure as the primary source.
Ring Attention: Distributing Attention Across GPUs
Flash Attention solves the memory problem on a single GPU, but there is a second constraint: even with Flash Attention, the total computation for attention still scales as O(n^2). For very long sequences (hundreds of thousands or millions of tokens), the computation itself takes too long on a single GPU. Ring Attention solves this by distributing the attention computation across multiple GPUs.
The Core Idea
Ring Attention, introduced by Liu et al. (arXiv:2310.01889, ICLR 2024), arranges GPUs in a logical ring. The input sequence is split into chunks, with each GPU holding one chunk. Each GPU computes attention for its chunk of queries against all chunks of keys and values by passing key-value blocks around the ring.
Here is how it works, step by step:
Split the sequence. If you have 8 GPUs and a 1 million token sequence, each GPU gets 125,000 tokens. GPU 0 holds tokens 0 through 124,999, GPU 1 holds tokens 125,000 through 249,999, and so on.
Each GPU starts with its local K/V block. GPU i computes attention scores between its local queries and its local keys/values.
Pass K/V blocks around the ring. After computing local attention, each GPU sends its K/V block to the next GPU in the ring and receives a K/V block from the previous GPU. Now each GPU computes attention between its local queries and the received K/V block.
Repeat until all K/V blocks have been seen. After P-1 passes (where P is the number of GPUs), every GPU has computed attention between its queries and all keys/values in the entire sequence.
Overlap communication with computation. The key optimization is that while a GPU is computing attention with the current K/V block, it is simultaneously sending/receiving the next K/V block. This hides the communication latency behind the computation.
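The steps above can be checked numerically on one machine. This is a minimal NumPy sketch, not a distributed implementation:

- the "devices" are loop iterations;
- there is no causal mask;
- communication is modelled only by which K/V chunk a device holds at each step. With each device sending to the next one in the ring, device i holds chunk (i - step) mod P after `step` hops.

Each device merges the blocks it sees with the same online-softmax update Flash Attention uses. Function names are illustrative.

```python
import numpy as np

def full_attention(Q, K, V):
    """Reference: standard (non-causal) softmax attention."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def ring_attention_sim(Q, K, V, num_devices=4):
    """Single-process simulation of Ring Attention. Device i owns query
    chunk i; after `step` hops it holds K/V chunk (i - step) mod P."""
    d = Q.shape[1]
    q_chunks = np.array_split(Q, num_devices)
    k_chunks = np.array_split(K, num_devices)
    v_chunks = np.array_split(V, num_devices)
    outputs = []
    for i, q in enumerate(q_chunks):
        o = np.zeros((q.shape[0], V.shape[1]))   # unnormalized output
        l = np.zeros((q.shape[0], 1))            # running softmax denominator
        m = np.full((q.shape[0], 1), -np.inf)    # running row max
        for step in range(num_devices):
            j = (i - step) % num_devices         # K/V chunk held at this step
            s = q @ k_chunks[j].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            scale = np.exp(m - m_new)            # rescale earlier partial sums
            p = np.exp(s - m_new)
            l = l * scale + p.sum(axis=-1, keepdims=True)
            o = o * scale + p @ v_chunks[j]
            m = m_new
        outputs.append(o / l)
    return np.concatenate(outputs)

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 512, 64))
print(np.allclose(ring_attention_sim(Q, K, V, num_devices=4),
                  full_attention(Q, K, V)))
```

The order in which a device sees the K/V chunks does not matter, because the online softmax update is exact for any order.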
def simulate_ring_attention(seq_len, num_gpus, head_dim=128):
    """
    Simulate Ring Attention to show how work is distributed.
    Each GPU processes its chunk of queries against all K/V blocks.
    FLOP and byte counts are for a single attention head.
    """
    chunk_size = seq_len // num_gpus
    print(f"Ring Attention: {seq_len:,} tokens across {num_gpus} GPUs")
    print(f"Chunk size per GPU: {chunk_size:,} tokens")
    print()
    # Each GPU performs num_gpus rounds of local attention
    # (counting the QK^T matmul only; P @ V adds the same amount again)
    local_attention_flops_per_round = 2 * chunk_size * chunk_size * head_dim
    total_flops_per_gpu = local_attention_flops_per_round * num_gpus
    # Compare with single-GPU computation
    single_gpu_flops = 2 * seq_len * seq_len * head_dim
    print(f"FLOPs per GPU: {total_flops_per_gpu:,.0f}")
    print(f"Single GPU FLOPs: {single_gpu_flops:,.0f}")
    print(f"Speedup: {single_gpu_flops / total_flops_per_gpu:.1f}x")
    print()
    # Communication: each round sends one K block and one V block
    bytes_per_round = 2 * chunk_size * head_dim * 2  # 2 for K+V, 2 bytes per bfloat16 value
    total_comm_bytes = bytes_per_round * (num_gpus - 1)
    print(f"Communication per round: {bytes_per_round / 1024**2:.1f} MB")
    print(f"Total communication: {total_comm_bytes / 1024**2:.1f} MB per GPU")
    print()
    # Show the ring schedule (at most 8 GPUs and 8 steps are displayed)
    shown = min(num_gpus, 8)
    print("Ring schedule (which K/V block each GPU processes at each step):")
    print(f"{'Step':>6}", end="")
    for gpu in range(shown):
        print(f" GPU {gpu:>2}", end="")
    print()
    print("-" * (6 + 7 * shown))
    for step in range(shown):
        print(f"{step:>6}", end="")
        for gpu in range(shown):
            # Each GPU sends to gpu + 1 and receives from gpu - 1, so after
            # `step` hops it holds the block that started on GPU (gpu - step)
            kv_block = (gpu - step) % num_gpus
            print(f"  KV {kv_block:>2}", end="")
        print()
    if num_gpus > 8:
        print(f"  ... ({num_gpus - 8} more steps)")
simulate_ring_attention(1_000_000, 8)
The speedup from Ring Attention is approximately equal to the number of GPUs. With 8 GPUs, each GPU processes 1/8 of the queries against all keys/values, so the total work per GPU is 1/8 of the single-GPU work. As long as each step's computation takes longer than transferring the next K/V block, the communication is hidden behind computation and the effective speedup is close to linear.
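Whether communication really hides behind computation depends on chunk size. Per head, one ring step does about 4 * c^2 * d FLOPs (QK^T and PV for c queries against c keys). It also transfers 2 * c * d K/V elements. Compute therefore takes at least as long as the transfer when c >= F * b / (2 * B), where:

- F is the device's FLOP/s;
- B is the link bandwidth in bytes/s;
- b is the bytes per element.

With 2-byte bfloat16 elements this reduces to c >= F / B, which matches the minimum block-size condition given by Liu et al. The sketch below plugs in illustrative H100-class figures: about 989 TFLOP/s dense BF16 and 450 GB/s per direction over NVLink. Treat these as assumptions, not measurements. The function name is hypothetical.

```python
import math

def min_overlap_chunk(flops_per_s, link_bytes_per_s, bytes_per_element=2):
    """Smallest chunk c (tokens per device) for which one ring step's
    attention compute (4*c*c*d FLOPs) lasts at least as long as sending
    the next K/V block (2*c*d elements). The head dim d cancels out."""
    return math.ceil(flops_per_s * bytes_per_element / (2 * link_bytes_per_s))

# Illustrative H100-class numbers (assumptions, not measurements)
c_min = min_overlap_chunk(flops_per_s=989e12, link_bytes_per_s=450e9)
chunk = 1_000_000 // 8  # chunk size from the 8-GPU example above
print(c_min, chunk >= c_min)
```

A 125,000-token chunk is far above the roughly 2,200-token threshold, so under these assumptions the transfers can be fully overlapped.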
Why Ring Attention Enables Million-Token Contexts
The combination of Flash Attention and Ring Attention is what makes million-token context windows practical:
- Flash Attention eliminates the O(n^2) memory requirement on each GPU, reducing it to O(n).
- Ring Attention distributes the O(n^2) computation across multiple GPUs, reducing the per-GPU computation to O(n^2 / P) where P is the number of GPUs.
Together, they allow processing sequences that are P times longer than what a single GPU can handle, where P is the number of GPUs in the ring. Liu et al. demonstrated training on sequences exceeding 100 million tokens using Ring Attention, which is 512 times longer than what was possible with prior memory-efficient transformers.
Source: Liu et al., “Ring Attention with Blockwise Transformers for Near-Infinite Context,” arXiv:2310.01889, October 2023. ICLR 2024 (openreview.net/forum?id=WsRHpHH4s0). Enables training on sequences 512x longer than prior methods, exceeding 100 million tokens. Communication is fully overlapped with computation.
Sparse Attention: Not Every Token Needs to Attend to Every Other Token
Flash Attention and Ring Attention make full attention computationally feasible at long sequence lengths, but they do not change the fundamental O(n^2) scaling. Sparse attention takes a different approach: instead of computing attention between every pair of tokens, it restricts each token to attend to only a subset of other tokens. This reduces the complexity from O(n^2) to O(n * k), where k is the number of tokens each position attends to. When k is a fixed constant, such as a window size, the total cost grows linearly with n. Patterns in which k grows slowly with n (for example, logarithmically) still stay far below quadratic.
Sliding Window Attention
The simplest form of sparse attention is sliding window attention, introduced in the Longformer paper (Beltagy, Peters, and Cohan, arXiv:2004.05150, 2020). Each token attends only to a fixed-size window of neighboring tokens, typically a few hundred to a few thousand tokens on each side.
The intuition is that most useful information for understanding a token comes from its local context. When you read the word “bank” in a sentence, the words immediately before and after it (“river bank” vs. “bank account”) are far more informative than words thousands of tokens away. Sliding window attention captures this locality.
import numpy as np
def sliding_window_attention(Q, K, V, window_size=256):
    """
    Sliding window attention: each token attends only to
    tokens within a fixed window around it (both directions here).
    Complexity: O(n * window_size) instead of O(n^2).
    """
    N, d = Q.shape
    O = np.zeros_like(Q)
    half_w = window_size // 2
    for i in range(N):
        # Define the window: tokens from (i - half_w) to (i + half_w)
        start = max(0, i - half_w)
        end = min(N, i + half_w + 1)
        # Compute attention only within the window
        K_window = K[start:end]  # (end - start, d); at most window_size + 1 rows
        V_window = V[start:end]  # (end - start, d)
        scores = Q[i] @ K_window.T / np.sqrt(d)  # (end - start,)
        weights = np.exp(scores - np.max(scores))
        weights = weights / weights.sum()
        O[i] = weights @ V_window
    return O
# Compare the number of score computations
seq_len = 128_000
window = 4_096
full_attention_elements = seq_len * seq_len
window_attention_elements = seq_len * window
print(f"Full attention: {full_attention_elements:>15,} score computations")
print(f"Window (w={window:,}): {window_attention_elements:>15,} score computations")
print(f"Reduction: {full_attention_elements / window_attention_elements:.0f}x")
The limitation of pure sliding window attention is that tokens cannot directly attend to distant tokens. A token at position 50,000 cannot see a token at position 1,000, no matter how important that distant token might be. However, information can propagate indirectly through multiple layers: if each layer has a window of 4,096 tokens, then after L layers, information can propagate up to L x 4,096 tokens. With 80 layers and a window of 4,096, the effective receptive field is 327,680 tokens, which covers a substantial portion of most context windows.
Mistral 7B (released September 27, 2023) was one of the first widely used models to adopt sliding window attention in production, using a window size of 4,096 tokens. Stacking layers creates a much larger effective receptive field, so information could flow across far more tokens than the 4,096-token attention window alone would suggest.
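The L x window estimate can be checked by composing attention masks. This sketch makes three assumptions:

- the window is causal, with token i attending to positions i - W through i, so each layer reaches W positions further back;
- "information can reach" means a path exists through the stacked masks;
- the sizes are toy values.

The function names are hypothetical.

```python
import numpy as np

def causal_window_mask(n, window):
    """mask[i, j] is True when token i may attend to token j (i - window <= j <= i)."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j >= i - window)

def earliest_reachable(n, window, layers):
    """Earliest position whose information can reach the last token."""
    mask = causal_window_mask(n, window).astype(int)
    reach = mask.copy()
    for _ in range(layers - 1):
        # One more layer: a path i -> k (earlier layers) then k -> j (this mask)
        reach = ((reach @ mask) > 0).astype(int)
    return int(np.argmax(reach[-1] > 0))

n, window = 64, 4
for layers in (1, 2, 3):
    print(layers, n - 1 - earliest_reachable(n, window, layers))
```

Each extra layer extends the reach by exactly one window. This is the L x window rule applied above to 80 layers and a 4,096-token window.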
Global + Local Attention
To address the limitation of pure sliding window attention, several models combine local (sliding window) attention with global attention on selected tokens. The two most influential approaches are:
Longformer (Beltagy, Peters, and Cohan, arXiv:2004.05150, 2020): Combines sliding window attention with global attention on a small number of designated tokens (such as the [CLS] token or task-specific tokens). Global tokens attend to all other tokens and are attended to by all other tokens. This allows long-range information to flow through the global tokens while keeping the overall complexity linear.
BigBird (Zaheer et al., arXiv:2007.14062, NeurIPS 2020): Combines three attention patterns: (1) sliding window attention for local context, (2) global tokens that attend to the entire sequence, and (3) random attention where each token attends to a small number of randomly selected tokens. The random connections are theoretically motivated: Zaheer et al. proved that BigBird’s sparse attention is a universal approximator of sequence functions and is Turing complete, preserving the theoretical properties of full attention. BigBird could handle sequences up to 8x longer than standard transformers on the same hardware.
import numpy as np
def create_attention_mask(seq_len, window_size=256, num_global=4,
                          num_random=32):
    """
    Create a BigBird-style sparse attention mask combining:
    1. Sliding window (local context)
    2. Global tokens (attend to/from everything)
    3. Random connections
    Returns a boolean mask where True = attend.
    """
    mask = np.zeros((seq_len, seq_len), dtype=bool)
    half_w = window_size // 2
    # 1. Sliding window
    for i in range(seq_len):
        start = max(0, i - half_w)
        end = min(seq_len, i + half_w + 1)
        mask[i, start:end] = True
    # 2. Global tokens (first num_global tokens attend to everything)
    mask[:num_global, :] = True  # Global tokens attend to all
    mask[:, :num_global] = True  # All tokens attend to global tokens
    # 3. Random connections
    rng = np.random.RandomState(42)
    for i in range(seq_len):
        random_indices = rng.choice(seq_len, size=num_random, replace=False)
        mask[i, random_indices] = True
    return mask
# Compare sparsity
seq_len = 4096
mask = create_attention_mask(seq_len, window_size=256, num_global=4,
                             num_random=32)
total_elements = seq_len * seq_len
active_elements = mask.sum()
sparsity = 1 - active_elements / total_elements
print(f"Sequence length: {seq_len:,}")
print(f"Full attention: {total_elements:,} connections")
print(f"Sparse attention: {active_elements:,} connections")
print(f"Sparsity: {sparsity*100:.1f}% of connections removed")
print(f"Speedup potential: {total_elements / active_elements:.1f}x")
Sparse Attention in Modern Models
While Longformer and BigBird were designed for encoder models (BERT-style), the principles of sparse attention have influenced modern decoder-only LLMs. However, most frontier models as of March 2026 use full attention (with Flash Attention for efficiency) rather than sparse attention patterns. The reason is that full attention with Flash Attention is fast enough for current context lengths (up to 1 to 2 million tokens), and sparse attention introduces complexity in implementation and potential quality degradation.
That said, sparse attention remains important for two reasons:
Extremely long contexts. For sequences beyond 2 million tokens (like LLaMA 4 Scout’s claimed 10 million token window), some form of sparsity is likely necessary because even Flash Attention cannot make the O(n^2) computation fast enough.
Inference efficiency. During the decode phase (Chapter 18), each new token must attend to all previous tokens. For a 1 million token context, this means 1 million attention score computations per head per layer per generated token. Sparse attention during decode can significantly reduce this cost.
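The decode-phase cost is simple arithmetic. This sketch counts the attention scores computed for one generated token, using the LLaMA 3.1 405B head and layer counts quoted earlier in this chapter. That model does not use sliding window attention, so the 4,096-token window is a hypothetical comparison point. The function name is illustrative.

```python
def decode_scores_per_token(context_len, num_heads, num_layers, window=None):
    """Attention scores computed for ONE generated token: the new query is
    scored against every cached key, or only the most recent `window` keys."""
    keys = context_len if window is None else min(context_len, window)
    return keys * num_heads * num_layers

# LLaMA 3.1 405B shape (128 query heads, 126 layers); window is hypothetical
full = decode_scores_per_token(1_000_000, num_heads=128, num_layers=126)
windowed = decode_scores_per_token(1_000_000, 128, 126, window=4_096)
print(f"{full:,} vs {windowed:,} ({full / windowed:.0f}x fewer)")
```

With full attention, every generated token costs about 16 billion score computations at a 1 million token context. A 4,096-token window would cut that by roughly 244x.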
Source: Beltagy, Peters, and Cohan, “Longformer: The Long-Document Transformer,” arXiv:2004.05150, April 2020. Sliding window attention with global tokens, linear complexity. Zaheer et al., “Big Bird: Transformers for Longer Sequences,” arXiv:2007.14062, NeurIPS 2020. Combines sliding window, global, and random attention. Proved universal approximation and Turing completeness. Handles sequences 8x longer than standard transformers.
Context Parallelism: Splitting Long Sequences Across Devices
Ring Attention distributes the attention computation across GPUs, but modern training and inference systems use a more general technique called context parallelism (CP) that extends this idea to the entire transformer forward pass, not just attention.
How Context Parallelism Works
In context parallelism, the input sequence is split along the sequence dimension across multiple GPUs. Each GPU processes its chunk of the sequence through the full transformer stack. The key challenge is the attention layers, where each token needs to attend to all other tokens (including those on other GPUs). Context parallelism uses Ring Attention (or similar all-to-all communication patterns) specifically for the attention layers, while the feed-forward layers (which operate independently on each token) require no inter-GPU communication.
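The asymmetry between feed-forward and attention layers is easy to demonstrate. In this toy sketch, the sizes, the random weights, and a single-head attention with Q = K = V are illustrative assumptions. Splitting the sequence into chunks and processing each chunk independently reproduces the feed-forward output exactly. It does not reproduce the attention output, because each chunk's queries never see the other chunks' keys and values. That missing cross-chunk term is what Ring Attention's K/V exchange supplies.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_ff, P = 32, 16, 64, 4          # toy sizes; P = number of devices
x = rng.standard_normal((n, d))
W1 = rng.standard_normal((d, d_ff))
W2 = rng.standard_normal((d_ff, d))

def ffn(h):
    """Position-wise feed-forward: each row (token) is transformed independently."""
    return np.maximum(h @ W1, 0.0) @ W2

def self_attention(h):
    """Single-head attention with Q = K = V = h (simplified for illustration)."""
    s = h @ h.T / np.sqrt(h.shape[1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ h

chunks = np.array_split(x, P)  # one contiguous chunk per device
ffn_ok = np.allclose(np.concatenate([ffn(c) for c in chunks]), ffn(x))
attn_ok = np.allclose(np.concatenate([self_attention(c) for c in chunks]),
                      self_attention(x))
print(ffn_ok, attn_ok)  # FFN matches; chunk-local attention does not
```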
This is different from the parallelism strategies you learned about in Chapter 14:
- Data parallelism splits the batch across GPUs (each GPU processes different examples).
- Tensor parallelism splits the model’s weight matrices across GPUs (each GPU holds part of each layer).
- Pipeline parallelism splits the model’s layers across GPUs (each GPU holds a subset of layers).
- Context parallelism splits the sequence across GPUs (each GPU holds a chunk of the input).
These strategies are orthogonal and can be combined. A frontier model training run might use all four simultaneously: data parallelism across nodes, tensor parallelism within a node, pipeline parallelism across groups of nodes, and context parallelism for long sequences.
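Combining strategies amounts to viewing the cluster as a grid whose dimensions multiply to the total GPU count. The sketch below maps a global rank to (TP, CP, PP, DP) coordinates. It assumes an ordering in which TP varies fastest, so the 8 ranks of a TP group land on the same 8-GPU node. The ordering, the group sizes, and the function name are illustrative assumptions, not the layout of any particular framework.

```python
def rank_to_coords(rank, tp, cp, pp, dp):
    """Decompose a global rank into parallel-group coordinates.
    Assumed order, fastest-varying first: TP, CP, PP, DP."""
    assert 0 <= rank < tp * cp * pp * dp
    coords = {}
    for name, size in (("tp", tp), ("cp", cp), ("pp", pp), ("dp", dp)):
        coords[name] = rank % size
        rank //= size
    return coords

sizes = dict(tp=8, cp=4, pp=4, dp=2)  # 8 * 4 * 4 * 2 = 256 GPUs
print(rank_to_coords(0, **sizes))
print(rank_to_coords(8, **sizes))     # first rank of the next CP chunk
print(rank_to_coords(255, **sizes))
```

Ranks 0 through 7 differ only in their TP coordinate, so under this ordering the most communication-heavy group stays inside one node.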
DeepSpeed-Ulysses
DeepSpeed-Ulysses (Jacobs et al., arXiv:2309.14509, September 2023) is an alternative to Ring Attention for sequence parallelism. Instead of passing K/V blocks around a ring, Ulysses partitions the input along the sequence dimension and uses all-to-all communication to redistribute the data for attention computation.
The key difference from Ring Attention is in how the attention heads are handled. Ulysses starts with the sequence partitioned across GPUs, then uses an all-to-all exchange so that each GPU receives the full sequence for a subset of attention heads. Each GPU then computes full attention for its heads, rather than partial attention for all heads. A second all-to-all returns the output to the sequence-partitioned layout. Because each GPU runs an ordinary attention computation, the approach is largely agnostic to the attention implementation (dense, sparse, or Flash Attention). The trade-off is that the degree of sequence parallelism is limited by the number of attention heads.
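The head-for-sequence swap can be simulated in NumPy. In this sketch (function names and sizes are illustrative), the flow is:

1. Each of P devices starts with a contiguous chunk of the sequence for all heads.
2. A simulated all-to-all gives device j the full sequence for its H/P heads.
3. Each device runs ordinary multi-head attention.
4. Concatenating the per-device outputs along the head axis stands in for the second all-to-all that restores sequence sharding.

```python
import numpy as np

def mha(q, k, v):
    """Multi-head attention; q, k, v have shape (seq, heads, head_dim)."""
    s = np.einsum("nhd,mhd->hnm", q, k) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum("hnm,mhd->nhd", p, v)

def seq_to_head_all_to_all(seq_shards):
    """Simulated all-to-all. Input shard i: (n/P, H, d), a sequence chunk with
    all heads. Output shard j: (n, H/P, d), the full sequence for H/P heads."""
    P = len(seq_shards)
    H = seq_shards[0].shape[1]
    assert H % P == 0, "this sketch needs the head count divisible by P"
    hp = H // P
    return [np.concatenate([s[:, j * hp:(j + 1) * hp] for s in seq_shards], axis=0)
            for j in range(P)]

rng = np.random.default_rng(0)
n, H, d, P = 64, 8, 16, 4
q, k, v = rng.standard_normal((3, n, H, d))
# Each device starts with a contiguous sequence chunk (all heads), then re-shards
q_h, k_h, v_h = (seq_to_head_all_to_all(np.split(t, P)) for t in (q, k, v))
# Each device runs ordinary attention over the full sequence for its heads
per_device = [mha(q_h[j], k_h[j], v_h[j]) for j in range(P)]
out = np.concatenate(per_device, axis=1)  # stand-in for the return all-to-all
print(np.allclose(out, mha(q, k, v)))
```

The assertion in `seq_to_head_all_to_all` reflects the head-count limit mentioned above: with P larger than H, some devices would have no heads to compute.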
NVIDIA’s Megatron Core framework has integrated Dynamic Context Parallelism (Dynamic-CP), which dynamically selects the context parallelism size per microbatch to handle variable-length sequences efficiently. This achieved up to 1.48x speedup on real-world datasets compared to static context parallelism, because it avoids wasting compute on padding when sequences have different lengths.
Snowflake released Arctic Ulysses in April 2025, adapting the Ulysses sequence parallelism approach from training to inference. By splitting long input sequences across GPUs during the prefill phase, Arctic Ulysses reduces time-to-first-token by up to 6.8x for long-context requests while achieving up to 1.5x higher throughput than latency-optimized tensor parallelism, breaking the traditional tradeoff between latency and throughput that tensor parallelism imposes.
Source: Jacobs et al., “DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models,” arXiv:2309.14509, September 2023. Partitions sequences across GPUs with all-to-all communication. NVIDIA, “Speeding Up Variable-Length Training with Dynamic Context Parallelism and NVIDIA Megatron Core,” January 28, 2026 (developer.nvidia.com/blog). Dynamic-CP achieves up to 1.48x speedup on variable-length datasets. Snowflake, “Low-Latency and High-Throughput Inference for Long Context with Sequence Parallelism (aka Arctic Ulysses),” April 3, 2025 (snowflake.com/en/engineering-blog/ulysses-low-latency-llm-inference). Up to 6.82x TTFT reduction (Qwen 2.5-32B, TP=1 vs SP=8) and up to 1.46x throughput improvement over TP=8 baseline.
Real Context Windows in March 2026
Context windows have grown by nearly 5,000x since GPT-3 (from 2,048 to 10 million tokens at the extreme). Here is the timeline of this expansion, with verified numbers for the major models:
| Model | Release Date | Context Window | Notes |
|---|---|---|---|
| GPT-3 | June 2020 | 2,048 tokens | The starting point |
| GPT-3.5 Turbo | March 2023 | 4,096 tokens | Later expanded to 16K |
| GPT-4 | March 2023 | 8,192 tokens | 32K variant also available |
| Claude 2 | July 2023 | 100,000 tokens | 100K at launch (Claude first offered 100K in May 2023) |
| GPT-4 Turbo | November 2023 | 128,000 tokens | 128K became the new standard |
| Gemini 1.5 Pro | February 2024 | 1,000,000 tokens | First 1M context model |
| GPT-4o | May 2024 | 128,000 tokens | Maintained 128K standard |
| Claude 3.5 Sonnet | June 2024 | 200,000 tokens | 200K for Claude 3.x family |
| LLaMA 3.1 405B | July 2024 | 131,072 tokens | Open-weights 128K |
| Gemini 2.5 Pro | March 2025 | 1,000,000 tokens | 1M standard, 2M planned |
| LLaMA 4 Scout | April 2025 | 10,000,000 tokens | 10M; 95%+ needle-in-a-haystack (NIAH) to 8M |
| LLaMA 4 Maverick | April 2025 | 1,048,576 tokens | 1M context; 128K-512K typical |
| GPT-5 | August 2025 | 272,000 tokens | ~272K input + 128K output |
| Grok 4 Fast | September 19, 2025 | 2,000,000 tokens | 2M; 40% fewer thinking tokens than Grok 4 |
| Gemini 3 Pro | November 18, 2025 | 1,000,000 tokens | 1M context window |
| Claude Opus 4.6 | February 5, 2026 | 1,000,000 tokens | First Opus with 1M context |
| Claude Sonnet 4.6 | February 17, 2026 | 1,000,000 tokens | 1M context |
| Gemini 3.1 Pro | February 19, 2026 | 1,000,000 tokens | 1M context per official model card; 94.3% GPQA Diamond |
| GPT-5.4 | March 5, 2026 | 1,050,000 tokens | 1.05M via API; 272K standard |
```python
import numpy as np


def context_window_growth():
    """
    Print context-window growth relative to GPT-3, with a text bar
    that grows by one '#' per doubling (capped at 20).
    """
    # (model, approximate release date as a decimal year, context tokens)
    models = [
        ("GPT-3", 2020.5, 2_048),
        ("GPT-3.5 Turbo", 2023.2, 4_096),
        ("GPT-4", 2023.2, 8_192),
        ("Claude 2", 2023.5, 100_000),
        ("GPT-4 Turbo", 2023.9, 128_000),
        ("Gemini 1.5 Pro", 2024.1, 1_000_000),
        ("LLaMA 3.1", 2024.5, 131_072),
        ("LLaMA 4 Scout", 2025.3, 10_000_000),
        ("GPT-5", 2025.6, 272_000),
        ("Grok 4 Fast", 2025.7, 2_000_000),
        ("Gemini 3 Pro", 2025.9, 1_000_000),
        ("Claude Opus 4.6", 2026.1, 1_000_000),
        ("Gemini 3.1 Pro", 2026.1, 1_000_000),
        ("GPT-5.4", 2026.2, 1_050_000),
    ]
    print(f"{'Model':<20} {'Year':>6} {'Context':>12} {'vs GPT-3':>10}")
    print("-" * 52)
    base = 2_048
    for name, year, ctx in models:
        ratio = ctx / base
        bar = "#" * min(int(np.log2(ratio)) + 1, 20)
        print(f"{name:<20} {year:>6.1f} {ctx:>12,} {ratio:>9.0f}x {bar}")


context_window_growth()
```

The growth is roughly exponential: depending on which two models you compare, context windows doubled about every 5 to 9 months. But there is an important caveat: the advertised context window and the effective context window are often very different, as we will see in the next section.
Source: GPT-3 2,048 tokens (Brown et al., 2020). GPT-3.5 Turbo 4,096 tokens, later 16K (OpenAI, March 2023). GPT-4 8,192/32K tokens (OpenAI, March 2023). GPT-4 Turbo 128K (OpenAI, November 2023). Gemini 1.5 Pro 1M tokens (Google, February 2024, blog.google/technology/ai/long-context-window-ai-models). LLaMA 3.1 405B 131,072 tokens (Meta, July 2024, huggingface.co config.json max_position_embeddings=131072). LLaMA 4 Scout 10M tokens (Meta, April 2025, winbuzzer.com, deeplearning.ai). GPT-5 272K input (OpenAI, August 2025, thesyntaxdiaries.com). Grok 4 Fast 2M tokens (xAI, September 19, 2025, x.ai/news/grok-4-fast, grokmag.com, winbuzzer.com). Gemini 3 Pro 1M tokens (Google, November 18, 2025, businessworld.in, thenextgentechinsider.com). Gemini 3.1 Pro 1M tokens (Google, February 19, 2026, deepmind.google/models/model-cards/gemini-3-1-pro “token context window of up to 1M”; note: some third-party sources such as yingtu.ai and buildfastwithai.com claim 2M, but the official DeepMind model card, OpenRouter, llm-stats.com, adtools.org, and gemini31.com all state 1M). Claude Opus 4.6 1M tokens (Anthropic, February 5, 2026, nyu.edu, felloai.com). GPT-5.4 1.05M tokens (OpenAI, March 5, 2026; community.openai.com/t/gpt-5-4-deep-dive-pricing-context-limits-and-tool-search-explained/1375800, automatio.ai/models/gpt-5-4, digitalapplied.com; 272K standard window, 1.05M opt-in via API with 2x input and 1.5x output surcharge above 272K per news.ycombinator.com/item?id=47266670).
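How fast is "roughly exponential"? A quick back-of-the-envelope check computes the implied doubling time between GPT-3 and a few later models. It uses the approximate decimal release years from the table, which are rough values chosen for illustration. The answer depends noticeably on which endpoint you pick.

```python
import math

# (name, approximate release year as a decimal, context window in tokens)
GPT3 = ("GPT-3", 2020.5, 2_048)
LATER = [
    ("GPT-4 Turbo", 2023.9, 128_000),
    ("Gemini 1.5 Pro", 2024.1, 1_000_000),
    ("LLaMA 4 Scout", 2025.3, 10_000_000),
    ("GPT-5.4", 2026.2, 1_050_000),
]


def doubling_time_months(start, end):
    """Months per doubling, assuming smooth exponential growth from start to end."""
    _, y0, c0 = start
    _, y1, c1 = end
    doublings = math.log2(c1 / c0)
    return (y1 - y0) * 12 / doublings


for model in LATER:
    months = doubling_time_months(GPT3, model)
    print(f"GPT-3 -> {model[0]:<15} {months:4.1f} months per doubling")
```

Taking LLaMA 4 Scout as the endpoint gives a doubling time under five months, while GPT-5.4 gives one closer to eight. Any single doubling figure should therefore be read as approximate.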
The “Context Rot” Problem: Performance Degrades in the Middle
Having a large context window does not mean the model uses all of it equally well. One of the most important findings in long-context research is the “lost in the middle” phenomenon: language models perform significantly worse when relevant information is located in the middle of a long context, compared to when it is at the beginning or end.
The U-Shaped Attention Curve
Liu et al. (arXiv:2307.03172, TACL 2024) systematically tested how well language models retrieve information placed at different positions in their context window. They used a multi-document question answering task in which the one document containing the answer was placed at various positions among distractor documents (for example, 20 documents in total).
The results showed a clear U-shaped curve: models performed best when the relevant document was at the very beginning (position 1) or the very end (position 20) of the context, and worst when it was in the middle (around position 10). The accuracy drop was substantial: over 30% degradation when the answer moved from position 1 to position 10 in a 20-document context.
A common explanation points to how attention and positional encoding interact. Tokens at the beginning of the sequence benefit from primacy bias: every later token can attend to them, and they accumulate attention through the “attention sink” phenomenon described in Chapter 18. Tokens at the end benefit from recency bias: they are closest to the query position during generation. Tokens in the middle get neither advantage.
```python
import numpy as np


def simulate_position_bias(num_positions=20):
    """
    Simulate the U-shaped accuracy curve from the
    'Lost in the Middle' paper (Liu et al., 2024).
    These are illustrative values based on the paper's findings,
    not exact reproductions.
    """
    positions = np.arange(1, num_positions + 1)
    # Distance (in documents) from the nearer edge of the context
    distance_from_edge = np.minimum(positions - 1, num_positions - positions)
    # U-shape: accuracy falls smoothly (cosine ramp) as distance from the edges grows
    base_accuracy = 75   # Illustrative accuracy at the edges (%)
    middle_penalty = 30  # Illustrative maximum possible drop (percentage points)
    accuracy = base_accuracy - middle_penalty * (
        1 - np.cos(np.pi * distance_from_edge / (num_positions / 2))
    ) / 2

    print("Position of relevant document vs. retrieval accuracy")
    print("(Illustrative U-shaped curve based on Liu et al., 2024)")
    print("=" * 55)
    print(f"{'Position':>10} {'Accuracy':>10} {'Visual':>30}")
    print("-" * 55)
    for pos, acc in zip(positions, accuracy):
        bar = "#" * int(acc / 2)
        print(f"{pos:>10} {acc:>9.1f}% {bar}")
    print()
    print(f"Best accuracy: {accuracy.max():.1f}% (position {positions[accuracy.argmax()]})")
    print(f"Worst accuracy: {accuracy.min():.1f}% (position {positions[accuracy.argmin()]})")
    print(f"Drop: {accuracy.max() - accuracy.min():.1f} percentage points")


simulate_position_bias()
```

Implications for Real Applications
The lost-in-the-middle effect has practical consequences:
RAG (Retrieval-Augmented Generation): If you retrieve multiple documents and concatenate them into the context, the order matters. Documents placed in the middle of the context are less likely to be used by the model. A common mitigation is to place the most relevant documents at the beginning or end of the context.
Long document analysis: When asking a model to analyze a long document, information in the middle sections may be underweighted. Breaking the document into chunks and processing them separately can sometimes yield better results than feeding the entire document at once.
Multi-turn conversations: In long conversations, information from the middle turns (not the earliest or most recent) is most likely to be “forgotten” by the model, even though it is technically within the context window.
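The reordering mitigation for RAG can be expressed as a small function. The sketch below assumes the retriever has already sorted documents from most to least relevant. It alternates them between the front and the back of the context so the weakest matches end up in the middle. The function name is hypothetical, not a specific framework's API.

```python
def reorder_for_edges(docs_by_relevance):
    """Place the most relevant documents at the start and end of the context.

    docs_by_relevance: documents sorted from most to least relevant.
    Rank 1 goes first, rank 2 goes last, rank 3 second, rank 4
    second-to-last, and so on, so the least relevant land in the middle.
    """
    front, back = [], []
    for i, doc in enumerate(docs_by_relevance):
        if i % 2 == 0:
            front.append(doc)
        else:
            back.append(doc)
    return front + back[::-1]


ranked = ["doc1", "doc2", "doc3", "doc4", "doc5", "doc6"]
print(reorder_for_edges(ranked))
# ['doc1', 'doc3', 'doc5', 'doc6', 'doc4', 'doc2']
```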
Recent research suggests the effect is strongest when inputs occupy up to 50% of a model’s context window. Beyond that threshold, the primacy bias weakens while recency bias remains stable, effectively shifting from a U-shaped curve to a distance-based bias where performance improves the closer information is to the end of the input.
A July 2025 study by Chroma Research (“Context Rot: How Increasing Input Tokens Impacts LLM Performance”) tested 18 state-of-the-art LLMs and found that performance degrades significantly as input length increases, even on tasks as simple as repeating a string. The degradation is not limited to the “lost in the middle” effect; it is a broader phenomenon where adding more context, even irrelevant context, actively hurts performance. This challenges the assumption that models process context uniformly and suggests that the race to larger context windows may have diminishing returns for many practical applications.
Even the latest models are not immune. GPT-5.4, despite its 1.05 million token context window, is reported to deliver its most stable performance within the first 256K tokens; accuracy on high-complexity reasoning tasks drops noticeably beyond that threshold. This is consistent with the broader pattern: the advertised context window is a hard technical limit, but the effective context window for reliable performance is often much smaller.
Source: Liu et al., “Lost in the Middle: How Language Models Use Long Contexts,” arXiv:2307.03172, July 2023. Transactions of the Association for Computational Linguistics (TACL), 2024, 12:157-173 (direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00638). Over 30% accuracy drop when relevant information moves from position 1 to position 10 in a 20-document context (morphllm.com). Hong, Troynikov, and Huber, “Context Rot: How Increasing Input Tokens Impacts LLM Performance,” Chroma Research, July 2025 (research.trychroma.com/context-rot). Tested 18 LLMs; performance degrades with longer inputs even on simple tasks. GPT-5.4 performance most stable within first 256K tokens (automatio.ai/models/gpt-5-4, apiyi.com).
Needle-in-a-Haystack Tests
The Needle-in-a-Haystack (NIAH) test, popularized by Greg Kamradt in late 2023, has become the standard benchmark for evaluating long-context retrieval. The test is simple: insert a specific fact (the “needle”) at a chosen depth within a large body of irrelevant text (the “haystack”), then ask the model a question that can only be answered by finding that specific fact.
How the Test Works
```python
def create_niah_test(haystack_text, needle_text, needle_position_pct, question):
    """
    Create a Needle-in-a-Haystack test prompt.

    needle_position_pct: where to insert the needle (0.0 = beginning,
    0.5 = middle, 1.0 = end)
    """
    # Split the haystack into words (a simplification: real tests count tokens)
    haystack_words = haystack_text.split()
    total_words = len(haystack_words)

    # Calculate the insertion point
    insert_idx = int(total_words * needle_position_pct)

    # Insert the needle as a single element so it stays intact
    haystack_words.insert(insert_idx, needle_text)

    # Construct the prompt
    context = " ".join(haystack_words)
    prompt = f"{context}\n\nQuestion: {question}\nAnswer:"
    return prompt, insert_idx


# Example
needle = "The secret code for the vault is 7392."
question = "What is the secret code for the vault?"
haystack = " ".join(["The grass is green and the sky is blue."] * 500)  # 4,500 filler words

prompt, idx = create_niah_test(haystack, needle, 0.5, question)
print(f"Needle inserted at word {idx} of {len(haystack.split())}")

# The model must find this single sentence among thousands of
# irrelevant words. The test is run at multiple positions
# (0%, 10%, 20%, ..., 100%) and multiple context lengths
# (1K, 4K, 16K, 64K, 128K, etc.) to create a 2D heatmap
# of retrieval accuracy.
print("Needle-in-a-Haystack Test Structure:")
print(f"  Needle: '{needle}'")
print(f"  Question: '{question}'")
print()
print("The test creates a 2D grid:")
print("  X-axis: Context length (1K to max context window)")
print("  Y-axis: Needle position (0% = start, 100% = end)")
print("  Cell value: Whether the model correctly retrieved the needle")
print()
print("A perfect model shows green (correct) everywhere.")
print("Real models show degradation in the middle positions")
print("and at very long context lengths.")
```

What NIAH Tests Reveal
NIAH tests produce a 2D heatmap where the x-axis is context length and the y-axis is needle position. A perfect model would show 100% accuracy everywhere (all green). In practice, models show characteristic patterns:
Short contexts (< 32K tokens): Most modern models achieve near-perfect retrieval at all positions. The lost-in-the-middle effect is minimal at short context lengths.
Medium contexts (32K to 128K tokens): Some models begin to show degradation, particularly in the middle positions. This is where the U-shaped curve becomes visible.
Long contexts (> 128K tokens): Performance varies significantly between models. Some models maintain high accuracy up to their full context window; others degrade substantially.
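The heatmap grid described above comes from a simple double loop over lengths and depths. The sketch assumes a hypothetical `query_model` callable (any function from prompt string to answer string) and uses words as a stand-in for tokens. The `oracle` below is not a model. It just finds the code with a regular expression to exercise the scoring plumbing, so it passes every cell.

```python
import re


def build_haystack(num_words, filler="The quick brown fox jumps over the lazy dog."):
    """Repeat a filler sentence until the haystack has roughly num_words words."""
    repeats = max(1, num_words // len(filler.split()))
    return " ".join([filler] * repeats)


def niah_grid(query_model, needle, question, answer,
              lengths=(1_000, 4_000, 16_000), depths=(0.0, 0.25, 0.5, 0.75, 1.0)):
    """Run a needle-in-a-haystack sweep and return {(length, depth): passed}."""
    results = {}
    for length in lengths:
        words = build_haystack(length).split()
        for depth in depths:
            idx = int(len(words) * depth)
            context = " ".join(words[:idx] + [needle] + words[idx:])
            prompt = f"{context}\n\nQuestion: {question}\nAnswer:"
            # Score a cell as passed if the expected answer appears in the reply
            results[(length, depth)] = answer in query_model(prompt)
    return results


def oracle(prompt):
    """Stand-in 'model' that extracts the code with a regex; always succeeds."""
    match = re.search(r"vault is (\d+)", prompt)
    return match.group(1) if match else "I don't know"


grid = niah_grid(oracle, "The secret code for the vault is 7392.",
                 "What is the secret code for the vault?", "7392")
for (length, depth), ok in sorted(grid.items()):
    print(f"{length:>6} words, depth {depth:4.0%}: {'pass' if ok else 'FAIL'}")
```

Swapping `oracle` for a real model client turns `grid` into the heatmap data; published NIAH runs measure length in the model's own tokens rather than words.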
Claimed vs. effective context: LLaMA 4 Scout claims a 10 million token context window, but independent testing showed mixed results. One analysis (jangwook.net) reported over 95% NIAH retrieval accuracy up to 8 million tokens, dropping to 89% at the full 10 million token limit. However, on tasks requiring even minimal inference (not just literal retrieval), performance degradation was more pronounced. The gap between the advertised context window and the effective context window is a recurring theme in long-context models.
The NIAH test is useful but limited. It tests only simple factual retrieval (finding a single fact), not reasoning over long contexts (synthesizing information from multiple parts of the document). A model might pass NIAH perfectly but still fail at tasks that require integrating information from different sections of a long document.
Source: Greg Kamradt, “Needle in a Haystack” benchmark, November 2023 (github.com/gkamradt/LLMTest_NeedleInAHaystack). Langchain, “Multi Needle in a Haystack,” March 2024 (blog.langchain.com/multi-needle-in-a-haystack). LLaMA 4 Scout NIAH results: over 95% retrieval accuracy up to 8M tokens, 89% at 10M (jangwook.net/en/blog/en/llama4-maverick-scout-enterprise-strategy). Inference-dependent degradation on long-context tasks (deadneurons.substack.com/p/the-dirty-secret-of-million-token).
The Full Picture: How Long Context Actually Works in Practice
Let us put all the pieces together. When you send a 500,000-token prompt to a long-context model such as Claude Opus 4.6 or GPT-5.4, here is, in broad strokes, what a typical serving stack does under the hood:
Step 1: Tokenization and Prompt Caching
The input text is tokenized (Chapter 4) and checked against the prompt cache (Chapter 19). If the beginning of the prompt matches a cached prefix, the server loads the cached KV state and only processes the new tokens.
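Prefix matching can be sketched with a block-based cache. The sketch assumes each block of tokens is keyed by a hash of the entire prefix up to that block. This is similar in spirit to the automatic prefix caching in some open-source inference servers, but it is not any specific server's code. The set stands in for a map from block key to stored K/V tensors, and the function names are hypothetical.

```python
import hashlib

BLOCK_SIZE = 4  # tokens per cache block; small here so the example is readable


def block_keys(token_ids, block_size=BLOCK_SIZE):
    """Key each FULL block by a hash of every token up to the end of that block.

    Because the key covers the whole prefix, a block can only match when
    everything before it matches too. A trailing partial block gets no key.
    """
    h = hashlib.sha256()
    keys = []
    usable = len(token_ids) - len(token_ids) % block_size
    for start in range(0, usable, block_size):
        for tok in token_ids[start:start + block_size]:
            h.update(tok.to_bytes(4, "little"))  # assumes 0 <= token id < 2**32
        keys.append(h.hexdigest())
    return keys


def split_prompt(token_ids, cache):
    """Return (tokens served from cache, tokens that still need prefill)."""
    hits = 0
    for key in block_keys(token_ids):
        if key not in cache:
            break
        hits += 1
    n_cached = hits * BLOCK_SIZE
    return n_cached, token_ids[n_cached:]


cache = set()                          # stands in for {block key: stored K/V}
system_prompt = list(range(100, 112))  # 12 tokens = 3 full blocks
cache.update(block_keys(system_prompt + [7, 8, 9]))  # an earlier request

n_cached, to_prefill = split_prompt(system_prompt + [42, 43, 44, 45, 46], cache)
print(n_cached, to_prefill)  # 12 [42, 43, 44, 45, 46]
```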
Step 2: Prefill with Flash Attention
The new tokens (those not covered by the prompt cache) are processed through the transformer stack. At each attention layer, Flash Attention computes the attention scores in tiles, never materializing the full n x n score matrix. This keeps memory usage linear in the sequence length.
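The tiling idea can be checked on the CPU. This NumPy sketch follows the online-softmax recurrence that Flash Attention is built on, but it only illustrates the math. It ignores everything that makes the real kernels fast: SRAM placement, fused GPU kernels, and the backward pass. The function names are hypothetical.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: materializes the full n x n score matrix."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def tiled_attention(q, k, v, block=64):
    """Exact attention computed one (query tile, key tile) pair at a time.

    Only a block x block tile of scores exists at any moment. A running
    row max and running softmax denominator let earlier partial outputs
    be rescaled when a later tile contains a larger score (online softmax).
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    for qs in range(0, n, block):
        qb = q[qs:qs + block]
        rows = qb.shape[0]
        row_max = np.full(rows, -np.inf)
        row_sum = np.zeros(rows)
        acc = np.zeros((rows, v.shape[1]))
        for ks in range(0, n, block):
            s = qb @ k[ks:ks + block].T * scale         # one score tile
            new_max = np.maximum(row_max, s.max(axis=1))
            p = np.exp(s - new_max[:, None])
            correction = np.exp(row_max - new_max)      # 0 on the first tile
            row_sum = row_sum * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v[ks:ks + block]
            row_max = new_max
        out[qs:qs + block] = acc / row_sum[:, None]
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((300, 32)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))  # True
```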
If the sequence is distributed across multiple GPUs (context parallelism), Ring Attention or DeepSpeed-Ulysses handles the inter-GPU communication for the attention layers.
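Ring Attention reuses the same running-max and running-sum merge, but the key/value tiles now live on different devices and travel around the ring. Below is a single-process NumPy simulation under simplifying assumptions: one head, no causal mask, a sequence length divisible by the number of ranks, and Python lists standing in for GPUs and send/receive calls. The function names are hypothetical.

```python
import numpy as np


def reference_attention(q, k, v):
    """Standard single-head attention that materializes all scores."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def ring_attention(q, k, v, num_ranks):
    """Simulate Ring Attention on one machine (single head, non-causal).

    Rank r keeps query chunk r fixed. On each of num_ranks steps it
    attends to whichever K/V chunk it currently holds, folds the result
    into a running max / denominator / output, then passes that chunk to
    rank r + 1. After num_ranks steps every rank has seen every chunk.
    """
    n, d = q.shape
    assert n % num_ranks == 0
    c = n // num_ranks
    q_chunks = [q[r * c:(r + 1) * c] for r in range(num_ranks)]
    kv = [(k[r * c:(r + 1) * c], v[r * c:(r + 1) * c]) for r in range(num_ranks)]
    row_max = [np.full(c, -np.inf) for _ in range(num_ranks)]
    row_sum = [np.zeros(c) for _ in range(num_ranks)]
    acc = [np.zeros((c, d)) for _ in range(num_ranks)]

    for _ in range(num_ranks):
        for r in range(num_ranks):  # on real hardware these run in parallel
            k_blk, v_blk = kv[r]
            s = q_chunks[r] @ k_blk.T / np.sqrt(d)
            new_max = np.maximum(row_max[r], s.max(axis=1))
            p = np.exp(s - new_max[:, None])
            corr = np.exp(row_max[r] - new_max)
            row_sum[r] = row_sum[r] * corr + p.sum(axis=1)
            acc[r] = acc[r] * corr[:, None] + p @ v_blk
            row_max[r] = new_max
        # Ring "send": rank r receives the chunk rank r - 1 just used.
        # Real implementations overlap this transfer with the next compute step.
        kv = [kv[(r - 1) % num_ranks] for r in range(num_ranks)]

    return np.concatenate([acc[r] / row_sum[r][:, None] for r in range(num_ranks)])


rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((64, 16)) for _ in range(3))
print(np.allclose(ring_attention(q, k, v, num_ranks=4), reference_attention(q, k, v)))  # True
```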
Step 3: KV Cache Population
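As the prefill phase processes each layer, the K and V vectors for all tokens are stored in the KV cache (Chapter 18). To get a sense of scale, take a model shaped like LLaMA 3.1 405B (126 layers, 8 grouped-query KV heads, head dimension 128) and imagine feeding it a 500,000-token prompt; the real model tops out at 131,072 tokens. Its KV cache would consume approximately:

```python
# KV cache for a LLaMA 3.1 405B-shaped model at 500K tokens (hypothetical length)
layers, kv_heads, head_dim, bytes_per_value = 126, 8, 128, 2  # GQA with 8 KV heads; bfloat16
bytes_per_token = 2 * layers * kv_heads * head_dim * bytes_per_value  # 2 = K and V
total_bytes = bytes_per_token * 500_000
print(f"KV cache for 500K tokens: {total_bytes / 1024**3:.1f} GiB")
# Approximately 240 GiB (about 258 GB)
```

This is why long-context inference on a model this size requires multiple GPUs: the KV cache alone for a single 500K-token request is more than three times the 80 GB of an H100, on top of roughly 810 GB of BF16 weights.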
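Because the KV cache size is a simple product, it is easy to see how much grouped-query attention (GQA) and KV cache quantization save. The calculator below assumes the published LLaMA 3.1 405B shape (126 layers, 8 KV heads, head dimension 128) and ignores paging and fragmentation overhead. The multi-head attention (MHA) row is a hypothetical comparison, not a real model, and the function name is illustrative.

```python
GIB = 1024**3


def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """K and V for every layer, KV head, and token (ignores paging overhead)."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens


tokens = 500_000
# LLaMA 3.1 405B shape: 126 layers, head_dim 128, 128 query heads, 8 KV heads.
variants = {
    "GQA, BF16 (actual shape)": kv_cache_bytes(tokens, 126, 8, 128),
    "MHA, BF16 (hypothetical)": kv_cache_bytes(tokens, 126, 128, 128),
    "GQA, FP8 KV cache": kv_cache_bytes(tokens, 126, 8, 128, bytes_per_value=1),
}
for label, size in variants.items():
    print(f"{label:<26} {size / GIB:8.1f} GiB")
```

Grouped-query attention alone cuts the cache 16x relative to giving every query head its own K and V, and an FP8 cache halves it again. Savings like these (Chapter 18) are a large part of why half-million-token requests are practical.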
Step 4: Decode with Cached Attention
During the decode phase (generating the response), each new token’s query attends to all 500,000+ cached keys. Flash Attention is used here too, but the computation pattern is different: instead of an n x n attention matrix, it is a 1 x n vector of attention scores (one query against all keys). This is much faster per step, but the KV cache must be read from memory at every step, which is why the decode phase is memory-bandwidth-bound (as discussed in Chapter 18).
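A rough roofline-style calculation shows why this step is memory-bound. The sketch makes several assumptions:

- the LLaMA 3.1 405B shape (126 layers, 128 query heads, 8 KV heads, head dimension 128, BF16 cache)
- a hypothetical 500K-token context
- approximate H100 SXM figures: about 989 dense BF16 TFLOP/s and 3.35 TB/s of HBM bandwidth

It counts only attention over the cache, not weight reads or the MLP.

```python
# Decode-step attention for a LLaMA 3.1 405B-shaped model at 500K cached tokens.
layers, q_heads, kv_heads, head_dim, bytes_per_value = 126, 128, 8, 128, 2
context = 500_000

# Per cached token per layer: QK^T and PV each cost 2 FLOPs (multiply + add)
# per query head per head dimension.
flops = 2 * 2 * q_heads * head_dim * layers * context
# Per cached token per layer: read K and V for every KV head once.
bytes_read = 2 * kv_heads * head_dim * bytes_per_value * layers * context
intensity = flops / bytes_read

# Assumed H100 SXM figures: ~989e12 BF16 FLOP/s dense, ~3.35e12 bytes/s HBM.
peak_flops, hbm_bandwidth = 989e12, 3.35e12
ridge = peak_flops / hbm_bandwidth  # FLOPs/byte needed to be compute-bound

print(f"Attention FLOPs per decode step: {flops / 1e12:.2f} TFLOP")
print(f"KV cache bytes read per step:    {bytes_read / 1e9:.0f} GB")
print(f"Arithmetic intensity:            {intensity:.0f} FLOPs/byte")
print(f"H100 ridge point (approx.):      {ridge:.0f} FLOPs/byte")
print(f"Time to stream the cache once:   {bytes_read / hbm_bandwidth * 1e3:.0f} ms")
```

At 16 FLOPs per byte, decode attention sits far below the roughly 295 FLOPs/byte an H100 needs to keep its compute units busy, so its speed is set by how fast the cache can be streamed. The last line assumes a single GPU's bandwidth; sharding the cache across 8 GPUs would cut it by about 8x, ignoring communication.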
The Cost of Long Context
Long-context inference is expensive in three ways:
Prefill latency (TTFT): Processing 500,000 tokens through the transformer stack takes significant time, even with Flash Attention. TTFT for very long prompts can be 10 to 30 seconds or more.
Memory: The KV cache for long contexts consumes hundreds of gigabytes, limiting the number of concurrent requests the server can handle.
Per-token decode cost: Each generated token must attend to all cached keys, so the decode phase is slower for longer contexts.
API providers reflect these costs in their pricing. As noted in Chapter 19, GPT-5.4 charges 2x the standard input rate and 1.5x the standard output rate for requests exceeding 272,000 tokens. Anthropic previously charged 2x input and 1.5x output for requests exceeding 200,000 tokens on Claude, though this surcharge was removed for Opus 4.6 and Sonnet 4.6 on March 13, 2026.
```python
def long_context_cost_comparison():
    """
    Compare the cost of processing different context lengths.
    Uses GPT-5.4 pricing as an example.
    """
    # GPT-5.4 pricing (March 2026)
    standard_input = 2.50            # $/MTok
    standard_output = 15.00          # $/MTok
    long_context_multiplier = 2.0    # For inputs > 272K tokens
    output_long_multiplier = 1.5     # For outputs when input > 272K tokens
    threshold = 272_000
    output_tokens = 2_000            # Fixed response length

    print("GPT-5.4 Cost by Context Length")
    print("=" * 60)
    print(f"{'Input Tokens':>14} {'Input Cost':>12} {'Output Cost':>12} "
          f"{'Total':>10} {'Note':>10}")
    print("-" * 60)
    for input_tokens in [10_000, 50_000, 128_000, 272_000,
                         500_000, 1_050_000]:
        if input_tokens <= threshold:
            input_cost = input_tokens / 1_000_000 * standard_input
            output_cost = output_tokens / 1_000_000 * standard_output
            note = ""
        else:
            # Full session billed at 2x input, 1.5x output
            input_cost = input_tokens / 1_000_000 * standard_input * long_context_multiplier
            output_cost = output_tokens / 1_000_000 * standard_output * output_long_multiplier
            note = "2x/1.5x"
        total = input_cost + output_cost
        print(f"{input_tokens:>14,} ${input_cost:>10.4f} ${output_cost:>10.4f} "
              f"${total:>8.4f} {note:>10}")


long_context_cost_comparison()
```

Techniques That Extend Context Beyond Training Length
Models are trained on sequences up to a fixed maximum length, and the model configuration records the longest sequence the model supports (the max_position_embeddings field). But several techniques allow models to handle sequences longer than their training length at inference time, with varying degrees of success.
RoPE Scaling
As you learned in Chapter 6, most modern models use Rotary Position Embeddings (RoPE) to encode token positions. RoPE rotates pairs of dimensions in the query and key vectors by angles proportional to the token's position. Each pair rotates at its own frequency, and these frequencies are derived from a base parameter (10,000 in the original formulation).
To extend the context window beyond the training length, you can modify the RoPE frequencies. The simplest approach is linear scaling: divide all position indices by a scaling factor. If the model was trained on 4,096 tokens and you want to use 16,384 tokens, you divide all positions by 4 (so position 16,384 becomes position 4,096 in the model’s internal representation).
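Both the rotation and the linear-scaling trick can be checked numerically. This NumPy sketch uses the interleaved pairing of the original RoPE formulation. Many codebases instead pair the first half of the vector with the second half, which is equivalent up to a permutation of dimensions. The function names are illustrative, not a library API.

```python
import numpy as np


def rope_angles(position, dim=8, base=10_000.0, scale=1.0):
    """Rotation angle of each dimension pair at `position`.

    Pair i rotates with inverse frequency base ** (-2i / dim). Linear
    scaling (position interpolation) divides the position by `scale`.
    """
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)
    return (position / scale) * inv_freq


def apply_rope(x, position, base=10_000.0, scale=1.0):
    """Rotate interleaved pairs (x0, x1), (x2, x3), ... of vector x."""
    theta = rope_angles(position, x.shape[0], base, scale)
    cos, sin = np.cos(theta), np.sin(theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# The query-key score depends only on the offset between positions (3 here).
s_near = apply_rope(q, 10) @ apply_rope(k, 7)
s_far = apply_rope(q, 1_010) @ apply_rope(k, 1_007)
print(np.isclose(s_near, s_far))  # True

# Linear scaling by 4: position 16,384 gets exactly the angles of position 4,096.
print(np.allclose(rope_angles(16_384, scale=4.0), rope_angles(4_096)))  # True
```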
YaRN (Yet another RoPE extensioN) is a more sophisticated approach that applies different scaling factors to different frequency components of RoPE. Low-frequency components (which encode long-range position information) are scaled more aggressively, while high-frequency components (which encode local position information) are left unchanged. This preserves the model’s ability to distinguish nearby tokens while extending its reach to distant tokens.
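YaRN's frequency-dependent scaling can be sketched in a few lines. This is a simplified version of the "NTK-by-parts" interpolation from the YaRN paper, using the ramp bounds alpha=1 and beta=32 that the paper suggests for LLaMA models. The original context length (4,096) and scale factor (4) are illustrative, not any particular model's configuration, and the function names are hypothetical. `yarn_attention_factor` is YaRN's attention-temperature adjustment, which real implementations apply alongside the frequency change.

```python
import numpy as np


def rope_inv_freq(dim=128, base=10_000.0):
    """Standard RoPE inverse frequencies, one per pair of dimensions."""
    return base ** (-np.arange(0, dim, 2) / dim)


def yarn_inv_freq(dim=128, base=10_000.0, scale=4.0, orig_ctx=4_096,
                  alpha=1.0, beta=32.0):
    """Simplified YaRN 'NTK-by-parts' frequency adjustment.

    r = orig_ctx / wavelength counts how many full rotations a frequency
    completes inside the original training context:
      r < alpha : low frequency (long-range) -> interpolate fully (divide by scale)
      r > beta  : high frequency (local)     -> leave unchanged
      otherwise : blend linearly between the two
    """
    inv_freq = rope_inv_freq(dim, base)
    wavelength = 2 * np.pi / inv_freq
    r = orig_ctx / wavelength
    gamma = np.clip((r - alpha) / (beta - alpha), 0.0, 1.0)
    return (1 - gamma) * inv_freq / scale + gamma * inv_freq


def yarn_attention_factor(scale):
    """YaRN temperature term; q and k are each scaled by it (logits by its square)."""
    return 0.1 * np.log(scale) + 1.0


ratio = yarn_inv_freq() / rope_inv_freq()
print(f"highest-frequency pair: x{ratio[0]:.3f}")   # x1.000 (unchanged)
print(f"lowest-frequency pair:  x{ratio[-1]:.3f}")  # x0.250 (interpolated by 4)
print(f"pairs left unchanged:   {int(np.sum(ratio == 1.0))} of {ratio.size}")
print(f"attention factor for scale 4: {yarn_attention_factor(4.0):.3f}")
```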
DeepSeek-V3 uses YaRN with a scaling factor of 40 to extend its context from the base training length to 128K tokens. LLaMA 3.1 extended from 8K to 128K tokens using a combination of RoPE scaling and continued pre-training on long sequences.
Continued Pre-training on Long Sequences
RoPE scaling alone is not sufficient for high-quality long-context performance. Models also need to be trained (or fine-tuned) on long sequences to learn how to use the extended context effectively. This is typically done as a second phase of training after the main pre-training run:
- Pre-train on the standard context length (e.g., 8K tokens) for the majority of training.
- Continue pre-training on progressively longer sequences (e.g., 32K, then 128K) for a smaller number of steps, with adjusted RoPE parameters.
This two-phase approach is more efficient than training on long sequences from the start, because long-sequence training is much more expensive. Attention cost grows as O(n^2) per sequence, so the attention computation for one 128K-token sequence is roughly 256x that of one 8K-token sequence (16x the length, squared), or about 16x more attention compute per token. The per-token cost of the MLP and projection layers does not change, and Flash Attention reduces attention's memory footprint but not its FLOP count.
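The difference between per-sequence and per-token cost is easy to check with a rough FLOP count for the two attention matmuls (QK^T and PV). The model width below is an illustrative assumption; it cancels out of both ratios. Causal masking would roughly halve both FLOP counts without changing the ratios.

```python
def attention_matmul_flops(seq_len, d_model):
    """Rough FLOPs for QK^T plus PV in one layer (full, non-causal attention).

    Each matmul costs about 2 * seq_len**2 * d_model FLOPs (multiply + add).
    """
    return 2 * (2 * seq_len**2 * d_model)


d_model = 4_096  # illustrative width; it cancels out of both ratios
short_len, long_len = 8_192, 131_072

per_sequence = attention_matmul_flops(long_len, d_model) / attention_matmul_flops(short_len, d_model)
per_token = (attention_matmul_flops(long_len, d_model) / long_len) / (
    attention_matmul_flops(short_len, d_model) / short_len
)
print(f"128K vs 8K, attention FLOPs per sequence: {per_sequence:.0f}x")  # 256x
print(f"128K vs 8K, attention FLOPs per token:    {per_token:.0f}x")     # 16x
```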
Putting It All Together: The Long-Context Stack
Here is a summary of how the techniques covered in this chapter and earlier ones work together to enable long-context inference:

```text
Layer 1: Position Encoding (Chapter 6)
    RoPE with YaRN scaling enables the model to represent positions
    beyond its original training length.

Layer 2: Flash Attention (this chapter)
    Computes exact attention in O(n) memory instead of O(n^2)
    by tiling the computation to fit in GPU SRAM.

Layer 3: KV Cache (Chapter 18)
    Stores K and V vectors so they are computed only once per token,
    not recomputed at every generation step.

Layer 4: KV Cache Compression (Chapter 18)
    Quantization (FP8, INT4), eviction (H2O), and cross-layer sharing
    reduce the memory footprint of the KV cache.

Layer 5: Prompt Caching (Chapter 19)
    Reuses KV cache across API calls for shared prefixes,
    avoiding redundant prefill computation.

Layer 6: Context Parallelism (this chapter)
    Ring Attention or DeepSpeed-Ulysses distributes the attention
    computation across multiple GPUs for very long sequences.

Layer 7: Sparse Attention (this chapter)
    For extremely long sequences (10M+ tokens), restricts attention
    to local windows plus global tokens, reducing O(n^2) to O(n * k)
    for a fixed window size k.
```

Each layer addresses a different aspect of the long-context challenge. RoPE scaling handles positions. Flash Attention handles memory. The KV cache handles redundant computation. KV cache compression handles the cache's footprint. Prompt caching handles redundant prefill. Context parallelism handles distribution. Sparse attention handles the fundamental quadratic scaling. Together, they form a stack that has enabled context windows to grow from 2,048 tokens in 2020 to 10 million tokens in 2025.
Key Takeaways
The core challenge of long context is the quadratic scaling of attention: computing attention over n tokens requires O(n^2) score computations and, in the standard implementation, O(n^2) memory for the score matrix. At 1 million tokens, the score matrix for a single attention head would require about 4 TB of memory at float32 (10^12 entries x 4 bytes). This makes the standard attention algorithm completely impractical for long sequences.
Flash Attention (Dao et al., arXiv:2205.14135, NeurIPS 2022) solved the memory problem by computing attention in tiles that fit in GPU SRAM, never materializing the full n x n score matrix. It computes the exact same result as standard attention (no approximation) while reducing memory from O(n^2) to O(n). Flash Attention has gone through four versions: v1 for A100 (NeurIPS 2022), v2 for A100 with better parallelism (ICLR 2024, 225 TFLOPs/s, 72% utilization), v3 for H100 (NeurIPS 2024, 840 TFLOPs/s BF16, 85% utilization, 1.3 PFLOPs/s FP8), and v4 for B200 Blackwell (March 2026, 1,613 TFLOPs/s BF16, 71% utilization). Each version required a complete rewrite to match the target GPU’s memory hierarchy and instruction set.
Ring Attention (Liu et al., arXiv:2310.01889, ICLR 2024) distributes the attention computation across multiple GPUs arranged in a logical ring. Each GPU holds a chunk of the sequence and passes K/V blocks around the ring, overlapping communication with computation. This enables processing sequences that are P times longer than a single GPU can handle, where P is the number of GPUs. Liu et al. demonstrated training on sequences exceeding 100 million tokens.
Sparse attention restricts each token to attend to a subset of other tokens, reducing complexity from O(n^2) to O(n * k). Longformer (Beltagy et al., arXiv:2004.05150, 2020) introduced sliding window attention with global tokens. BigBird (Zaheer et al., arXiv:2007.14062, NeurIPS 2020) combined sliding window, global, and random attention, proving that this sparse pattern is a universal approximator. Most frontier models as of March 2026 use full attention with Flash Attention rather than sparse patterns, but sparse attention remains important for extremely long contexts (10M+ tokens).
Context parallelism splits the input sequence across GPUs, using Ring Attention or all-to-all communication (DeepSpeed-Ulysses, arXiv:2309.14509) for the attention layers. NVIDIA’s Dynamic Context Parallelism in Megatron Core achieves up to 1.48x speedup by dynamically adjusting parallelism per microbatch.
Context windows have grown from 2,048 tokens (GPT-3, 2020) to 10 million tokens (LLaMA 4 Scout, 2025), a nearly 5,000x increase in under five years. As of March 2026, the major frontier models offer: GPT-5.4 at 1.05M tokens (272K standard, 1.05M opt-in), Claude Opus 4.6 at 1M tokens, Gemini 3.1 Pro at 1M tokens (per official DeepMind model card), and Grok 4 Fast at 2M tokens.
The “lost in the middle” phenomenon (Liu et al., arXiv:2307.03172, TACL 2024) shows that models perform significantly worse when relevant information is in the middle of the context rather than at the beginning or end. Accuracy can drop by over 30%, tracing a U-shaped accuracy curve that is commonly attributed to primacy and recency biases. This means a larger context window does not automatically mean better performance.
Needle-in-a-Haystack tests evaluate whether models can retrieve specific facts from long contexts. While useful, they only test simple retrieval, not reasoning over long contexts. The gap between advertised and effective context windows is a recurring issue: LLaMA 4 Scout claims 10M tokens and achieves over 95% NIAH retrieval accuracy up to 8M tokens, but performance on tasks requiring inference (not just literal retrieval) degrades more substantially at long contexts.
Long-context inference is expensive in three ways: high prefill latency (TTFT of 10 to 30+ seconds for very long prompts), large KV cache memory (hundreds of GB for 500K+ token contexts), and slower per-token decode (each generated token must attend to all cached keys). API providers reflect these costs through long-context surcharges, though Anthropic removed its surcharge for Opus 4.6 and Sonnet 4.6 on March 13, 2026.
The long-context stack combines multiple techniques: RoPE scaling for position encoding, Flash Attention for memory-efficient computation, KV cache for avoiding redundant computation, KV cache compression for reducing memory, prompt caching for reusing work across calls, context parallelism for multi-GPU distribution, and sparse attention for extreme sequence lengths. Each layer addresses a different aspect of the challenge.
What’s Next
You now understand how models process sequences from 4K to 10 million tokens, the techniques that make this possible, and the limitations that remain. But language models are not limited to text. In Chapter 21, we will explore how models see images: the vision encoders that convert pixels into tokens, the cross-attention mechanisms that connect visual and textual information, and the real-world capabilities and limitations of visual understanding in modern LLMs.