KV Cache: Speed Up LLM Inference
Key-Value (KV) caching is the single most important optimization for the decode phase of LLM inference. During decoding, the model generates one new token at a time, and each new token attends to every previous token. Without a cache, the model would recompute the key and value vectors for the whole sequence at every step. KV caching stores the keys and values of previous tokens in GPU memory so the model can reuse them instead. This optimization reduces decode latency by roughly 2-3x in the setup measured later in this article (the gain grows with sequence length) and is enabled by default in production inference engines such as vLLM, TensorRT-LLM, and DeepSpeed. Understanding KV cache is essential for tuning batch sizes, managing GPU memory, and diagnosing decode bottlenecks.
How Transformer Attention Works
Before understanding KV caching, recall how transformer attention works. Given a query (Q), key (K), and value (V) matrix, attention computes:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
During the prefill phase (processing the entire prompt), the model computes all three matrices (Q, K, V) for all tokens. During the decode phase (generating one new token at a time), the model only needs to:
- Compute Q for the new token (one vector).
- Compute K and V for the new token (two vectors).
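- Attend over the whole sequence: multiply the new Q against the K^T of every previous token plus the new one, apply softmax, and use the resulting weights to sum the corresponding V vectors.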
The key insight: K and V for previous tokens never change. Recomputing them every iteration is wasteful. KV caching stores all previous K and V in memory and reuses them.
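The equivalence can be checked on a toy example. The sketch below is a minimal single-head, single-layer illustration in plain Python; the weight matrices, token embeddings, and function names are hypothetical values chosen for the demonstration, not part of any real model. It decodes the same sequence twice, once recomputing K and V for the whole prefix at every step and once appending to a cache, and compares the outputs.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, x):
    """Multiply a (d_out x d_in) weight matrix by a length-d_in vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    """Single-head attention for one query over lists of key/value vectors."""
    d_k = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]

# Hypothetical toy projection weights and token embeddings (assumptions).
W_q = [[0.5, -0.1], [0.2, 0.3]]
W_k = [[0.1, 0.4], [-0.3, 0.2]]
W_v = [[0.7, 0.0], [0.1, 0.6]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]

def decode_without_cache(tokens):
    """Recompute K and V for the whole prefix at every step."""
    outputs = []
    for t in range(len(tokens)):
        keys = [matvec(W_k, x) for x in tokens[: t + 1]]
        values = [matvec(W_v, x) for x in tokens[: t + 1]]
        outputs.append(attend(matvec(W_q, tokens[t]), keys, values))
    return outputs

def decode_with_cache(tokens):
    """Project only the new token; reuse cached K and V for earlier ones."""
    k_cache, v_cache, outputs = [], [], []
    for x in tokens:
        k_cache.append(matvec(W_k, x))
        v_cache.append(matvec(W_v, x))
        outputs.append(attend(matvec(W_q, x), k_cache, v_cache))
    return outputs

no_cache = decode_without_cache(tokens)
cached = decode_with_cache(tokens)
same = all(math.isclose(a, b)
           for row_a, row_b in zip(no_cache, cached)
           for a, b in zip(row_a, row_b))
print(same)
```

The uncached version performs a number of K/V projections that grows quadratically with sequence length, while the cached version performs one projection per token, which is where the decode savings come from.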
Memory Layout of KV Cache
For a model with num_layers transformer layers, num_heads attention heads, and head_dim dimensions per head, the KV cache for a batch of batch_size sequences each with seq_len tokens occupies:
KV_cache_size (bytes) =
num_layers * batch_size * seq_len * num_heads * head_dim * 2 (K and V) * bytes_per_element
For a 7B-parameter model (typical configuration: 32 layers, 32 heads, 128 dims per head) with batch size 8 and sequence length 2000 tokens, in FP16 (2 bytes per element):
KV_cache_size = 32 * 8 * 2000 * 32 * 128 * 2 * 2 = ~8.4 GB
This is a problem: the FP16 weights of a 7B model already take about 14 GB, and the cache grows linearly with both batch size and sequence length. At batch size 64 and 4096 tokens, the same formula gives roughly 137 GB, more than the 40-80 GB of VRAM on a single A100. For this reason, KV cache is carefully managed:
- Allocate pre-sized buffers: Preallocate maximum possible KV cache at startup (enough for max batch size and max sequence length) rather than growing dynamically.
- Free cache on sequence completion: When a sequence finishes generation (reaches max_tokens or stop token), release its cache.
- Use more efficient data types: Store KV cache in FP8 or INT8 instead of FP16, reducing memory by 50% (with minimal accuracy loss on most models).
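The size formula above can be checked numerically. The helper below is a small sketch; the function name `kv_cache_bytes` is hypothetical, and the shape it is called with (32 layers, 32 heads, head dimension 128) is an assumption matching a Llama-2-7B-style model.

```python
def kv_cache_bytes(num_layers, batch_size, seq_len, num_heads, head_dim,
                   bytes_per_element=2):
    """Bytes needed to hold K and V for every layer (the factor 2 is K and V)."""
    return (num_layers * batch_size * seq_len * num_heads * head_dim
            * 2 * bytes_per_element)

# Assumed 7B-style shape: 32 layers, 32 heads, head_dim 128.
fp16_bytes = kv_cache_bytes(32, 8, 2000, 32, 128, bytes_per_element=2)
fp8_bytes = kv_cache_bytes(32, 8, 2000, 32, 128, bytes_per_element=1)
per_token_bytes = kv_cache_bytes(32, 1, 1, 32, 128, bytes_per_element=2)

print(f"FP16, batch 8, 2000 tokens: {fp16_bytes / 1e9:.2f} GB")   # 8.39 GB
print(f"FP8,  batch 8, 2000 tokens: {fp8_bytes / 1e9:.2f} GB")    # 4.19 GB
print(f"FP16 per token, per sequence: {per_token_bytes} bytes")   # 524288 bytes
```

Because every factor enters multiplicatively, doubling the batch size or the sequence length doubles the cache, and switching from a 2-byte to a 1-byte element type halves it.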
Code Example: KV Cache Management
Below is a simplified example showing how a pre-allocated KV cache fills up during decode:
import torch


class KVCache:
    """
    Stores key and value caches for a transformer model.
    Reuses stored K/V during decode to avoid recomputation.
    """

    def __init__(self, num_layers: int, batch_size: int,
                 max_seq_len: int, num_heads: int, head_dim: int,
                 dtype: torch.dtype = torch.float16,
                 device: str = 'cuda'):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Pre-allocate buffers for K and V
        # Shape: (num_layers, batch_size, num_heads, max_seq_len, head_dim)
        self.k_cache = torch.zeros(
            (num_layers, batch_size, num_heads, max_seq_len, head_dim),
            dtype=dtype, device=device
        )
        self.v_cache = torch.zeros(
            (num_layers, batch_size, num_heads, max_seq_len, head_dim),
            dtype=dtype, device=device
        )
        # Track the current length per layer and per sequence, so that
        # calling update() once for every layer advances each layer by one token
        self.seq_len = torch.zeros(
            (num_layers, batch_size), dtype=torch.long, device=device
        )

    def update(self, layer_idx: int, k_new: torch.Tensor,
               v_new: torch.Tensor, batch_idx: int):
        """
        Append new K and V for a single token to one layer of the cache.
        k_new shape: (num_heads, head_dim)
        v_new shape: (num_heads, head_dim)
        """
        pos = self.seq_len[layer_idx, batch_idx].item()
        if pos >= self.max_seq_len:
            raise RuntimeError(
                f"Cache is full: sequence already holds {pos} tokens "
                f"(max {self.max_seq_len})"
            )
        self.k_cache[layer_idx, batch_idx, :, pos, :] = k_new
        self.v_cache[layer_idx, batch_idx, :, pos, :] = v_new
        self.seq_len[layer_idx, batch_idx] += 1

    def get(self, layer_idx: int, batch_idx: int) -> tuple:
        """Retrieve cached K and V for a sequence up to its current length."""
        seq_len = self.seq_len[layer_idx, batch_idx].item()
        k = self.k_cache[layer_idx, batch_idx, :, :seq_len, :]
        v = self.v_cache[layer_idx, batch_idx, :, :seq_len, :]
        return k, v

    def memory_usage_mb(self) -> float:
        """Memory held by the K and V buffers, in MB (1 MB = 1024**2 bytes)."""
        # element_size() is 2 for FP16 and 1 for 8-bit dtypes
        total_bytes = (
            self.k_cache.numel() + self.v_cache.numel()
        ) * self.k_cache.element_size()
        return total_bytes / (1024 ** 2)


# Example usage. This configuration needs about 16 GiB of free GPU memory;
# lower max_seq_len or batch_size to try it on a smaller device.
cache = KVCache(
    num_layers=32, batch_size=8, max_seq_len=4096,
    num_heads=32, head_dim=128
)
print(f"KV cache memory: {cache.memory_usage_mb():.1f} MB")  # 16384.0 MB (16 GiB)

# During decode, every layer appends K/V for each new token
for token_idx in range(100):
    for layer_idx in range(32):
        # Random stand-ins for the projected K/V; shape: batch, heads, dim
        k_new = torch.randn(8, 32, 128, device='cuda', dtype=torch.float16)
        v_new = torch.randn(8, 32, 128, device='cuda', dtype=torch.float16)
        for batch_idx in range(8):
            cache.update(layer_idx=layer_idx, k_new=k_new[batch_idx],
                         v_new=v_new[batch_idx], batch_idx=batch_idx)
This code pre-allocates all KV cache at startup and appends new K/V as tokens are generated.
KV Cache Trade-offs: Speed vs. Memory
KV caching is a classic speed-memory tradeoff:
| Technique | Decode Speed | Memory Overhead | Notes |
|---|---|---|---|
| No KV cache | 1x (baseline) | 0% | Recompute K/V every token; very slow |
| FP16 KV cache | 2.5x-3x | 20-30% | Standard; widely used |
| FP8 KV cache | 2.5x-3x | 10-15% | Reduced precision; minimal accuracy loss |
| KV cache + quantization | 3x-4x | 10-20% | Combined with model quantization |
On an A100 GPU with a 7B model and batch size 8, FP16 KV cache reduces decode latency from 15-20ms per token to 5-7ms per token.
Managing KV Cache Under Load
In production, managing KV cache requires careful resource planning:
- Eviction policies: When GPU memory fills up, which sequences' caches should be evicted? Common policies: Least-Recently-Used (LRU), First-In-First-Out (FIFO), or priority-based (higher-priority users keep cache). vLLM uses a scheduler-aware eviction policy.
- Selective caching: Cache only the final layer(s) if memory is extremely tight. Caching the last 4 layers instead of all 32 reduces cache memory by 7/8. The trade-off is speed: the uncached layers must recompute K and V for the whole sequence at every step, so this only makes sense when memory is the hard limit.
- Reuse caches across requests: If two requests have identical prompt prefixes (common in multi-turn chat), reuse the same cached K/V. This is prompt caching, covered in the next article.
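To make the LRU option from the list above concrete, here is a minimal bookkeeping sketch. The class `LRUCacheTable` and its methods are hypothetical names for illustration; it tracks only how many blocks each sequence holds and is not how vLLM or any other engine implements eviction.

```python
from collections import OrderedDict

class LRUCacheTable:
    """Tracks blocks held per sequence; evicts least recently used first."""

    def __init__(self, capacity_blocks):
        self.capacity_blocks = capacity_blocks
        self.used_blocks = 0
        self.entries = OrderedDict()  # seq_id -> blocks held, oldest first

    def touch(self, seq_id):
        """Mark a sequence as recently used (e.g. it just decoded a token)."""
        self.entries.move_to_end(seq_id)

    def allocate(self, seq_id, blocks):
        """Reserve blocks for seq_id and return the ids evicted to make room."""
        held = self.entries.get(seq_id, 0)
        if held + blocks > self.capacity_blocks:
            raise ValueError("request can never fit in the cache")
        if seq_id in self.entries:
            self.entries.move_to_end(seq_id)  # never evict the requester
        evicted = []
        while self.used_blocks + blocks > self.capacity_blocks:
            victim, freed = self.entries.popitem(last=False)  # oldest entry
            self.used_blocks -= freed
            evicted.append(victim)
        self.entries[seq_id] = held + blocks
        self.used_blocks += blocks
        return evicted

table = LRUCacheTable(capacity_blocks=10)
table.allocate("a", 4)
table.allocate("b", 4)
table.touch("a")                   # "a" is now more recent than "b"
evicted = table.allocate("c", 4)   # does not fit, so the oldest entry goes
print(evicted)                     # ['b']
```

A FIFO policy would be the same structure without `touch`, and a priority policy would pick the victim by priority instead of by position in the ordering.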
vLLM KV Cache Configuration
vLLM enables KV caching by default and exposes these tuning options:
from vllm import LLM

# Option names and accepted values vary between vLLM versions;
# check the engine arguments reference for the release you run.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    # GPU memory allocation (weights, activations, and KV cache)
    gpu_memory_utilization=0.9,  # Use up to 90% of GPU VRAM
    # KV cache settings
    kv_cache_dtype="auto",  # "auto" uses the model's dtype (usually float16);
                            # "fp8" stores the cache in 8-bit floating point
    # CPU memory (GiB per GPU) used as overflow when preempted
    # sequences have their cache swapped out
    swap_space=4,
    # Block manager settings
    block_size=16,  # Allocate KV cache in 16-token blocks
)
The block_size=16 parameter deserves explanation: instead of allocating contiguous memory for each sequence's cache, vLLM allocates it in fixed-size blocks (16 tokens). This reduces fragmentation and allows flexible cache reuse across sequences. A sequence of 100 tokens uses 7 blocks; a sequence of 110 tokens uses 7 blocks (unused space in the 7th block is wasted, but total fragmentation is low).
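The block arithmetic is a ceiling division. The short sketch below reproduces the numbers in the paragraph above; the function names are hypothetical and the code models only the counting, not vLLM's block manager.

```python
def blocks_needed(num_tokens, block_size=16):
    """Number of fixed-size blocks required to hold num_tokens (ceiling division)."""
    return (num_tokens + block_size - 1) // block_size

def wasted_slots(num_tokens, block_size=16):
    """Token slots allocated in the last block but not yet used."""
    return blocks_needed(num_tokens, block_size) * block_size - num_tokens

for n in (100, 110, 112, 113):
    print(n, blocks_needed(n), wasted_slots(n))
# 100 7 12
# 110 7 2
# 112 7 0
# 113 8 15
```

The waste per sequence is always less than one block, which is why the overall fragmentation stays low even with many concurrent sequences.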
Key Takeaways
- KV caching stores attention K/V matrices to avoid recomputing them during token generation.
- 2-3x decode speedup with 20-30% memory overhead; essential for production.
- Memory scales linearly with batch size and sequence length: large batches or long sequences consume significant cache.
- Use FP8 or INT8 for KV cache when memory is tight; minimal accuracy loss.
- Combine with prompt caching (next article) to reuse cache across requests with identical prefixes.
Frequently Asked Questions
Why does KV cache grow linearly with sequence length?
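During prefill, the model computes K and V for all tokens in one pass. During decode, it appends one new K/V pair per layer for each token generated. By the time the sequence is 2000 tokens long, every layer stores K and V vectors for 2000 positions. For a 7B model with 32 layers, 32 heads, and head dimension 128 in FP16, that is about 0.5 MB per token, or roughly 1 GB for a single 2000-token sequence. Across a large batch, the total can approach or exceed the roughly 14 GB taken by the model weights themselves.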
Can I reduce KV cache memory by using a smaller head dimension?
No, head dimension is a model architecture parameter. You cannot change it without retraining. However, you can use 8-bit quantization for KV cache (FP8, INT8) to reduce memory by 50%.
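As a rough illustration of where the 50% comes from, the sketch below applies symmetric per-vector INT8 quantization to a single key vector in plain Python. The function names and sample values are hypothetical, and real engines typically use FP8 formats or more careful scaling schemes; this only shows the basic idea of storing one byte per element plus a scale.

```python
def quantize_int8(vec):
    """Symmetric per-vector quantization: returns (int8 values, scale)."""
    max_abs = max(abs(x) for x in vec)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(x / scale))) for x in vec]
    return q, scale

def dequantize_int8(q, scale):
    """Recover approximate floating-point values from the stored integers."""
    return [v * scale for v in q]

key = [0.12, -0.98, 0.45, 0.003, -0.27]   # hypothetical cached key vector
q, scale = quantize_int8(key)
restored = dequantize_int8(q, scale)
max_err = max(abs(a - b) for a, b in zip(key, restored))

print(q)
print(f"max abs error: {max_err:.5f} (bound: {scale / 2:.5f})")
```

Each element now needs 1 byte instead of 2, and the rounding error per element is bounded by half the scale, which is why accuracy loss is usually small when values within a vector have similar magnitudes.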
Does KV cache help during prefill?
Not directly. During prefill, the model processes all prompt tokens in parallel and computes their K and V for the first time, so there is nothing to reuse; this is the step that populates the cache. The cache pays off once prefill completes and token generation begins.
What happens if KV cache is full?
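When vLLM runs out of free cache blocks, its scheduler preempts one or more running sequences to free space. Depending on the version and configuration, a preempted sequence's KV cache is either swapped out to CPU memory, which is slower but preserves the computed cache, or discarded and recomputed when the sequence resumes. In both cases the preempted sequence is paused until enough cache becomes available.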