Multi-Stage Retrieval: Coarse-to-Fine Ranking Architectures
Multi-stage retrieval applies a cascade of increasingly expensive ranking functions to progressively narrow a candidate set, balancing speed and accuracy. Instead of reranking all 1,000+ documents retrieved from a dense index (expensive), a multi-stage system first filters with a fast, coarse ranker (BM25), then applies medium-cost reranking (a fast reranker model) to the top-100, and finally applies the most expensive method (a cross-encoder or ColBERT) to the top-20. This approach can reduce latency by 50–70% (for example, from 1,000+ ms to 300–500 ms) while maintaining or improving accuracy compared to single-stage reranking. Multi-stage architectures are prevalent in production systems: large web search engines such as Google are generally described as using several ranking stages, Elasticsearch supports second-stage rescoring of the top hits, and modern RAG frameworks (LlamaIndex, LangChain) support cascading retrievers. This article teaches you to design, implement, and optimize multi-stage retrieval pipelines for your specific latency and accuracy constraints.
Multi-Stage Architecture: Classic Coarse-to-Fine
A typical 3-stage pipeline:
┌─────────────────────────────────────────────────────────────────┐
│ Stage 1: COARSE (Fast, Cheap) │
│ Method: BM25 keyword retrieval │
│ Candidates: 1,000–10,000 │
│ Latency: 10–50 ms │
│ Output: top-100 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Stage 2: MEDIUM (Balanced) │
│ Method: Dense embedding search OR lightweight reranker │
│ Candidates: 100 │
│ Latency: 50–100 ms │
│ Output: top-20 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Stage 3: FINE (Slow, Accurate) │
│ Method: Cross-encoder neural reranker │
│ Candidates: 20 │
│ Latency: 50–100 ms │
│ Output: top-5 (to LLM) │
└─────────────────────────────────────────────────────────────────┘
Stage 1 (Coarse): BM25 retrieval is very fast (<50 ms) and handles keyword precision. It filters out documents with no lexical overlap with the query, reducing noise. The output is typically top-100, a manageable candidate set.
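Stage 2 (Medium): Dense embedding scoring or a fast reranker model (TinyBERT) reranks the 100 candidates. Dense scoring at this stage is cheap: comparing one query embedding against 100 document embeddings is trivial, especially when the document embeddings are precomputed at index time. A fast reranker needs a throughput of roughly 1,000–2,000 pairs/sec to score 100 pairs in 50–100 ms; at 500 pairs/sec the same stage would take about 200 ms. Output: top-20 for the final stage.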
Stage 3 (Fine): A high-quality cross-encoder model (MiniLM, ELECTRA) reranks the 20 candidates. Latency is minimal because k is small; accuracy is maximized because only high-potential candidates are considered.
Total latency: 10 + 75 + 50 = 135 ms for a representative run (the stage ranges above sum to 110–250 ms), vs. 400+ ms for single-stage dense retrieval plus cross-encoder reranking of all 10,000 candidates.
Why Multi-Stage Works: Cascade Recall
Multi-stage systems work only if each stage has high cascade recall for relevant documents. Cascade recall is the fraction of relevant documents entering a stage that also appear in that stage's output. If relevant documents are filtered out in early stages, later stages cannot recover them, and final accuracy crashes.
For example:
- Stage 1 (BM25) retrieves top-100, including 85 out of 100 relevant documents for a query (85% cascade recall).
- Stage 2 (dense) filters to top-20 from those 100, keeping 18 of the 85 relevant (18/85 ≈ 21%).
- Stage 3 (cross-encoder) keeps top-5 from the 20, including 4 of the 18 (4/18 ≈ 22%).
Final cascade recall: 85% * 21% * 22% ≈ 4% of the original relevant documents reach the final stage. If there are 100 relevant documents total, only 4 are in the final top-5, a recall of 4%, which is unacceptable. Per-stage losses multiply, so a stage that looks tolerable on its own can still sink the end-to-end result.
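The arithmetic in this example can be checked in a few lines of Python. This is a minimal sketch: the function name `cascade_recall` and its `(relevant_in, relevant_out)` input format are illustrative assumptions, not part of any library.

```python
def cascade_recall(stage_counts: list[tuple[int, int]]) -> tuple[list[float], float]:
    """Return per-stage recall and the end-to-end product.

    stage_counts holds one (relevant_in, relevant_out) pair per stage:
    how many relevant documents entered the stage and how many survived.
    """
    per_stage = [out / inp for inp, out in stage_counts]
    end_to_end = 1.0
    for recall in per_stage:
        end_to_end *= recall
    return per_stage, end_to_end


# Numbers from the example: 100 relevant -> 85 -> 18 -> 4
per_stage, total = cascade_recall([(100, 85), (85, 18), (18, 4)])
print([round(r, 2) for r in per_stage])  # [0.85, 0.21, 0.22]
print(round(total, 2))                   # 0.04
```

Logging these per-stage ratios on a labeled evaluation set shows which stage is losing the most relevant documents, and therefore which top-k to raise first.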
The solution: set top-k values conservatively at each stage to ensure high cascade recall. A rule of thumb for the minimum:
top_k at stage i >= (# target docs for LLM) * (# stages remaining, counting stage i)
For a 5-doc final target and 3 stages, the minimums are 15, 10, and 5. In practice, be more generous in the early stages:
- Stage 1: the minimum is 5 * 3 = 15, but be generous. Use top-100.
- Stage 2: top-50 from stage 1's 100 (keeps 50%).
- Stage 3: top-5 from stage 2's 50 (keeps 10%).
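One way to read the rule of thumb, treated here as an assumption, is as a floor: each stage keeps at least the final target multiplied by the number of stages still to run, counting the current one. The helper name `min_top_k` is hypothetical.

```python
def min_top_k(target_docs: int, num_stages: int) -> list[int]:
    """Floor on top-k per stage: target_docs * stages remaining (current stage included)."""
    return [target_docs * (num_stages - i) for i in range(num_stages)]


floors = min_top_k(target_docs=5, num_stages=3)
print(floors)  # [15, 10, 5]

# The generous values chosen in the text all sit at or above the floor.
chosen = [100, 50, 5]
print(all(k >= floor for k, floor in zip(chosen, floors)))  # True
```

The floor only guards against obviously undersized stages. The values that actually preserve recall depend on how good each ranker is, so they should be confirmed on an evaluation set.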
Implementation: Multi-Stage Orchestration
import asyncio
import logging
import time

logger = logging.getLogger(__name__)


class MultiStageRetriever:
    def __init__(self, bm25_index, dense_retriever, reranker):
        self.bm25_index = bm25_index
        self.dense_retriever = dense_retriever
        self.reranker = reranker

    async def retrieve_multistage(self, query: str, top_k: int = 5) -> list[dict]:
        """
        Multi-stage retrieval: BM25 (coarse) → dense (medium) → cross-encoder (fine)
        """
        # perf_counter is monotonic, so it is safer than time.time() for latency
        start_time = time.perf_counter()

        # Stage 1: BM25 (coarse filter)
        logger.info(f"Stage 1: BM25 retrieval for '{query[:50]}'")
        stage1_start = time.perf_counter()
        bm25_results = await self._bm25_stage(query, top_k=100)
        stage1_latency = (time.perf_counter() - stage1_start) * 1000
        logger.info(f"  Stage 1 latency: {stage1_latency:.1f}ms, candidates: {len(bm25_results)}")
        if not bm25_results:
            logger.warning("Stage 1 returned no results")
            return []

        # Stage 2: Dense retrieval or lightweight reranker (medium)
        logger.info(f"Stage 2: Dense retrieval from top-{len(bm25_results)}")
        stage2_start = time.perf_counter()
        stage2_results = await self._dense_stage(query, bm25_results, top_k=50)
        stage2_latency = (time.perf_counter() - stage2_start) * 1000
        logger.info(f"  Stage 2 latency: {stage2_latency:.1f}ms, candidates: {len(stage2_results)}")
        if not stage2_results:
            logger.warning("Stage 2 returned no results, returning Stage 1 results")
            return bm25_results[:top_k]

        # Stage 3: Cross-encoder reranking (fine)
        logger.info(f"Stage 3: Cross-encoder reranking top-{len(stage2_results)}")
        stage3_start = time.perf_counter()
        final_results = await self._rerank_stage(query, stage2_results, top_k=top_k)
        stage3_latency = (time.perf_counter() - stage3_start) * 1000
        logger.info(f"  Stage 3 latency: {stage3_latency:.1f}ms, final candidates: {len(final_results)}")

        total_latency = (time.perf_counter() - start_time) * 1000
        logger.info(f"Total multi-stage latency: {total_latency:.1f}ms")

        # Attach latency metadata
        for result in final_results:
            result['retrieval_latency_ms'] = total_latency
        return final_results

    async def _bm25_stage(self, query: str, top_k: int = 100) -> list[dict]:
        """Stage 1: Fast BM25 retrieval"""
        # Assumes search() yields (doc_id, text, score) tuples.
        # The call is synchronous, so it blocks the event loop while it runs.
        results = self.bm25_index.search(query, top_k=top_k)
        return [{'doc_id': doc_id, 'text': text, 'score': score, 'stage': 1}
                for doc_id, text, score in results]

    async def _dense_stage(self, query: str, candidates: list[dict], top_k: int = 50) -> list[dict]:
        """Stage 2: Dense embedding re-scoring or lightweight reranker"""
        # Option A: Re-score candidates with dense embedding
        # (computationally cheap for 100 candidates)
        query_emb = self.dense_retriever.encode(query)
        for candidate in candidates:
            doc_emb = self.dense_retriever.encode(candidate['text'])
            candidate['dense_score'] = self.dense_retriever.compute_similarity(query_emb, doc_emb)

        # Sort by dense score into a new list (the caller's list keeps its order)
        ranked = sorted(candidates, key=lambda x: x['dense_score'], reverse=True)
        stage2_results = ranked[:top_k]
        for result in stage2_results:
            result['stage'] = 2
        return stage2_results

    async def _rerank_stage(self, query: str, candidates: list[dict], top_k: int = 5) -> list[dict]:
        """Stage 3: Cross-encoder neural reranking"""
        # Prepare pairs for cross-encoder
        pairs = [(query, candidate['text']) for candidate in candidates]

        # Score all pairs in one batch
        scores = self.reranker.predict(pairs)

        # Update candidates with reranker scores
        for candidate, score in zip(candidates, scores):
            candidate['reranker_score'] = float(score)

        # Sort by reranker score and take top-k
        ranked = sorted(candidates, key=lambda x: x['reranker_score'], reverse=True)
        final_results = ranked[:top_k]
        for result in final_results:
            result['stage'] = 3
        return final_results


# Example usage (bm25_index, dense_retriever, and reranker must be constructed beforehand)
retriever = MultiStageRetriever(bm25_index, dense_retriever, reranker)
results = asyncio.run(retriever.retrieve_multistage("What is transformer attention?", top_k=5))
for i, result in enumerate(results, 1):
    # Stage 1 fallback results carry only the BM25 'score'
    score = result.get('reranker_score', result['score'])
    print(f"{i}. {result['text'][:80]}... (score: {score:.3f})")
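The class above assumes three components with a specific interface: `bm25_index.search(query, top_k)` yielding `(doc_id, text, score)` tuples, `dense_retriever.encode` and `compute_similarity`, and `reranker.predict(pairs)`. The stand-ins below are hypothetical toy implementations of that interface (term overlap instead of real BM25, bag-of-words cosine instead of a neural encoder, Jaccard overlap instead of a cross-encoder). They are useful only for wiring up and testing the pipeline without downloading models.

```python
import math
from collections import Counter


class ToyLexicalIndex:
    """Stand-in for a BM25 index: scores by number of shared query terms (not real BM25)."""

    def __init__(self, docs: dict[str, str]):
        self.docs = docs

    def search(self, query: str, top_k: int = 100) -> list[tuple[str, str, float]]:
        q_terms = set(query.lower().split())
        scored = []
        for doc_id, text in self.docs.items():
            overlap = len(q_terms & set(text.lower().split()))
            if overlap:  # drop documents with no lexical overlap
                scored.append((doc_id, text, float(overlap)))
        scored.sort(key=lambda item: item[2], reverse=True)
        return scored[:top_k]


class ToyDenseRetriever:
    """Stand-in for an embedding model: bag-of-words counts with cosine similarity."""

    def encode(self, text: str) -> Counter:
        return Counter(text.lower().split())

    def compute_similarity(self, a: Counter, b: Counter) -> float:
        dot = sum(a[t] * b[t] for t in a)
        norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
        return dot / norm if norm else 0.0


class ToyReranker:
    """Stand-in for a cross-encoder: Jaccard overlap of each (query, text) pair."""

    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        scores = []
        for query, text in pairs:
            q, d = set(query.lower().split()), set(text.lower().split())
            scores.append(len(q & d) / len(q | d) if q | d else 0.0)
        return scores


docs = {
    "d1": "transformer attention weighs tokens",
    "d2": "bm25 is a keyword ranking function",
    "d3": "attention is all you need",
}
bm25_index = ToyLexicalIndex(docs)
dense_retriever = ToyDenseRetriever()
reranker = ToyReranker()
print([doc_id for doc_id, _, _ in bm25_index.search("what is transformer attention", top_k=2)])
# ['d1', 'd3']
```

Swapping these objects for real components later should not require changes to the orchestration code, as long as the replacements expose the same method names and return shapes.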
Latency Analysis: When to Use Multi-Stage
Multi-stage retrieval is beneficial when:
- Candidate set is large (500+): Dense retrieval or reranking all 500+ is expensive. Multi-stage filters early.
- Latency is critical (<500 ms): Multi-stage avoids costly single-stage reranking of large sets.
- Accuracy requirements are moderate: If you need top-5 or top-10, multi-stage achieves near-single-stage accuracy with 50–70% latency reduction.
Multi-stage is not beneficial when:
- Candidate set is small (< 50): Overhead of multiple stages exceeds benefit of early filtering.
- All candidates must be scored: Some use-cases require scoring all documents (e.g., clustering, scoring entire corpus).
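A back-of-the-envelope cost model makes these conditions concrete. The sketch below estimates latency from the number of (query, document) pairs each reranking stage must score. The throughput figures are assumptions chosen for illustration, not measurements, and the function names are hypothetical.

```python
def rerank_latency_ms(num_pairs: int, pairs_per_sec: float) -> float:
    """Time to score num_pairs (query, document) pairs at a given throughput."""
    return 1000.0 * num_pairs / pairs_per_sec


def cascade_latency_ms(stages: list[tuple[int, float]], first_stage_ms: float) -> float:
    """Fixed first-stage cost plus the per-pair cost of each reranking stage.

    stages holds (candidates_scored, pairs_per_sec) for each reranking stage.
    """
    return first_stage_ms + sum(rerank_latency_ms(n, tput) for n, tput in stages)


# Assumed throughputs: a light reranker at 2,000 pairs/sec,
# a cross-encoder at 400 pairs/sec, and a 20 ms first stage.
single_stage = cascade_latency_ms([(500, 400.0)], first_stage_ms=20.0)
multi_stage = cascade_latency_ms([(500, 2000.0), (50, 400.0)], first_stage_ms=20.0)
print(f"single-stage: {single_stage:.0f} ms, multi-stage: {multi_stage:.0f} ms")
# single-stage: 1270 ms, multi-stage: 395 ms
```

With only 50 candidates the same model gives 145 ms for the single-stage path, so the extra stage has little left to save, which matches the small-candidate-set caveat.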
Learned Cascade Optimization
Beyond manual multi-stage design, you can learn optimal cascade parameters from data:
from itertools import product


def learn_cascade_thresholds(eval_set: dict, bm25_index, dense_retriever, reranker):
    """
    Grid-search top-k thresholds for each cascade stage.
    Goal: Maximize accuracy (final NDCG@5) while penalizing latency above 200 ms.
    Assumes evaluate_cascade(...) is defined elsewhere and returns a dict
    with 'NDCG@5' and 'latency_ms' keys.
    """
    # Parameter grid: (top_k at stage 1, top_k at stage 2, top_k at stage 3)
    param_grid = list(product([50, 100, 200], [20, 50], [5, 10]))
    best_config = None
    best_score = float("-inf")  # penalized scores can be negative

    for k1, k2, k3 in param_grid:
        # Evaluate cascade with these thresholds
        metrics = evaluate_cascade(eval_set, bm25_index, dense_retriever, reranker, k1, k2, k3)
        ndcg = metrics['NDCG@5']
        latency = metrics['latency_ms']

        # Prefer higher accuracy; penalize only the latency in excess of 200 ms
        score = ndcg - 0.0001 * max(0.0, latency - 200)
        if score > best_score:
            best_score = score
            best_config = (k1, k2, k3)
        print(f"Config (k1={k1}, k2={k2}, k3={k3}): NDCG@5={ndcg:.4f}, Latency={latency:.0f}ms")

    print(f"\nBest config: k1={best_config[0]}, k2={best_config[1]}, k3={best_config[2]}")
    return best_config


# Run optimization
best_k1, best_k2, best_k3 = learn_cascade_thresholds(eval_set, bm25_index, dense_retriever, reranker)
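The optimization code calls `evaluate_cascade`, which is not defined in this article. One piece it would need is the metric itself. Below is a minimal sketch of NDCG@k with linear gain for a single query, assuming graded relevance labels keyed by document ID; the function names are illustrative.

```python
import math


def dcg_at_k(gains: list[float], k: int) -> float:
    """Discounted cumulative gain over the first k positions (linear gain)."""
    return sum(gain / math.log2(rank + 2) for rank, gain in enumerate(gains[:k]))


def ndcg_at_k(ranked_ids: list[str], relevance: dict[str, float], k: int = 5) -> float:
    """NDCG@k for one query; documents missing from `relevance` count as gain 0."""
    gains = [relevance.get(doc_id, 0.0) for doc_id in ranked_ids]
    ideal = sorted(relevance.values(), reverse=True)
    ideal_dcg = dcg_at_k(ideal, k)
    return dcg_at_k(gains, k) / ideal_dcg if ideal_dcg > 0 else 0.0


relevance = {"d1": 3.0, "d2": 1.0, "d3": 2.0}
print(ndcg_at_k(["d1", "d3", "d2"], relevance))        # ideal ordering: 1.0
print(ndcg_at_k(["d2", "d3", "d1"], relevance) < 1.0)  # True
```

An `evaluate_cascade` built on this would run the cascade for every query in the evaluation set, average `ndcg_at_k` over queries, and record mean or tail latency alongside it.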
Alternative Multi-Stage Architectures
Beyond BM25 → dense → cross-encoder, other pipelines work well:
Hybrid fusion → reranking:
- Stage 1: Parallel BM25 + dense, fuse with RRF → top-100
- Stage 2: Lightweight reranker (TinyBERT) → top-20
- Stage 3: Cross-encoder (MiniLM) → top-5
- Advantage: Hybrid fusion captures both keyword and semantic signals early, reducing stage 2's burden.
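The fusion step in stage 1 can be sketched as follows. Reciprocal Rank Fusion scores each document by summing 1 / (k + rank) over the ranked lists it appears in. The constant k = 60 is the commonly used default; the function name and the choice to fuse plain lists of document IDs are assumptions made for this example.

```python
def reciprocal_rank_fusion(
    rankings: list[list[str]], k: int = 60, top_n: int = 100
) -> list[tuple[str, float]]:
    """Fuse ranked ID lists: score(d) = sum over lists of 1 / (k + rank), ranks start at 1."""
    scores: dict[str, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    fused = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return fused[:top_n]


bm25_ranking = ["d1", "d2", "d3"]
dense_ranking = ["d3", "d1", "d4"]
fused = reciprocal_rank_fusion([bm25_ranking, dense_ranking])
print([doc_id for doc_id, _ in fused])  # ['d1', 'd3', 'd2', 'd4']
```

Because RRF uses only ranks, it needs no score normalization between BM25 and dense similarity, which is why it is a convenient first-stage fuser.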
Two-step reranking (ColBERT + Cross-Encoder):
- Stage 1: BM25 → top-100
- Stage 2: ColBERT (efficient late-interaction) → top-20
- Stage 3: Cross-encoder → top-5
- Advantage: ColBERT precomputes document token embeddings, so scoring a candidate costs far less than a cross-encoder forward pass. That makes it well suited to middle stages with large k.
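ColBERT's late-interaction scoring (often called MaxSim) can be illustrated in plain Python: each query token takes its best-matching document token by dot product, and those maxima are summed. This is a sketch of the scoring rule only, with made-up 2-dimensional vectors. A real ColBERT setup uses learned, normalized token embeddings and batched tensor operations.

```python
def maxsim_score(query_vecs: list[list[float]], doc_vecs: list[list[float]]) -> float:
    """Late-interaction score: for each query token, take the max dot product
    against all document tokens, then sum over query tokens."""
    total = 0.0
    for q in query_vecs:
        total += max(sum(qi * di for qi, di in zip(q, d)) for d in doc_vecs)
    return total


# Toy token embeddings, invented for illustration.
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # matches both query tokens
doc_b = [[1.0, 0.0], [1.0, 0.0]]              # matches only the first
print(maxsim_score(query, doc_a), maxsim_score(query, doc_b))  # 2.0 1.0
```

Since the document vectors do not depend on the query, they can be computed once at index time. Only the query encoding and the dot products happen per request.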
Key Takeaways
- Multi-stage retrieval applies a cascade of increasingly expensive ranking functions to progressively narrow candidate sets, reducing latency 50–70% vs. single-stage.
- Typical 3-stage: BM25 (coarse, 100 candidates) → dense/lightweight reranker (medium, 20 candidates) → cross-encoder (fine, 5 candidates).
- Cascade recall is critical: ensure each stage's top-k is high enough that relevant documents survive filtering. A rule of thumb for the minimum: top-k >= target_docs * stages remaining.
- Multi-stage is beneficial for large candidate sets (500+) and latency-critical applications (<500 ms). For small sets or when all candidates must be scored, single-stage is simpler.
- Learned cascade optimization can find near-optimal thresholds by evaluating configurations on an evaluation set and maximizing accuracy-latency trade-off.
Frequently Asked Questions
How do I prevent relevant documents from being filtered out in early stages?
Set conservative top-k values at each stage. A document must survive stage 1's top-k to enter stage 2. For stage 1, use top-100 to top-200 (generous thresholds); only use top-20 if you have high confidence in the stage 1 ranker.
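If you have labeled queries, you can choose the stage 1 cutoff from data rather than by feel. The sketch below finds the smallest k whose mean recall meets a target. The function names, the `(ranked_ids, relevant_ids)` input format, and the candidate k values are all assumptions for illustration.

```python
from typing import Optional


def recall_at_k(ranked_ids: list[str], relevant: set[str], k: int) -> float:
    """Fraction of relevant documents found in the first k results."""
    if not relevant:
        return 0.0
    return len(set(ranked_ids[:k]) & relevant) / len(relevant)


def smallest_k_for_recall(
    runs: list[tuple[list[str], set[str]]], target: float, candidates: list[int]
) -> Optional[int]:
    """Smallest k in `candidates` whose mean recall over all queries meets `target`."""
    for k in sorted(candidates):
        mean_recall = sum(recall_at_k(ids, rel, k) for ids, rel in runs) / len(runs)
        if mean_recall >= target:
            return k
    return None  # no candidate k reaches the target


# Two toy queries: (stage 1 ranking, set of relevant document IDs)
runs = [
    (["a", "b", "c", "d"], {"a", "d"}),
    (["x", "y", "z"], {"y"}),
]
print(smallest_k_for_recall(runs, target=0.7, candidates=[1, 2, 4]))  # 2
print(smallest_k_for_recall(runs, target=0.9, candidates=[1, 2, 4]))  # 4
```

A `None` result means the stage 1 ranker itself is the bottleneck, and raising k further will not help without improving or replacing that ranker.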
Can I use different ranking methods in my cascade?
Yes. A typical cascade might be BM25 → dense → cross-encoder (different methods at each stage). Alternatively, use the same method with different parameters (e.g., BM25 with k1=1.5 at stage 1, k1=2.0 at stage 2 for more aggressive term frequency weighting).
What if dense retrieval fails in stage 2?
Implement fallback: if stage 2 fails or times out, return top-k from stage 1. Similarly, if stage 3 (cross-encoder) fails, return stage 2's ranking. Graceful degradation ensures the system always returns results.
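A timeout-based fallback can be sketched with `asyncio.wait_for`. The wrapper name `with_fallback` and the two fake stages are hypothetical; the pattern is to await a stage with a deadline and return the previous stage's ranking if it times out or raises.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def with_fallback(stage_coro, fallback: list[dict], timeout_s: float) -> list[dict]:
    """Await a stage; on timeout or error, return the previous stage's ranking."""
    try:
        return await asyncio.wait_for(stage_coro, timeout=timeout_s)
    except asyncio.TimeoutError:
        logger.warning("Stage timed out after %.2fs, using fallback", timeout_s)
    except Exception:
        logger.exception("Stage failed, using fallback")
    return fallback


async def slow_stage(candidates: list[dict]) -> list[dict]:
    await asyncio.sleep(1.0)  # simulates a stage that exceeds its budget
    return candidates[::-1]


async def fast_stage(candidates: list[dict]) -> list[dict]:
    return candidates[::-1]


stage1 = [{"doc_id": "d1"}, {"doc_id": "d2"}]
timed_out = asyncio.run(with_fallback(slow_stage(stage1), fallback=stage1, timeout_s=0.05))
ok = asyncio.run(with_fallback(fast_stage(stage1), fallback=stage1, timeout_s=0.05))
print([d["doc_id"] for d in timed_out])  # ['d1', 'd2'] (fallback order)
print([d["doc_id"] for d in ok])         # ['d2', 'd1'] (stage output)
```

Note that `asyncio.wait_for` can only interrupt a stage at an `await` point. A stage that makes a blocking synchronous model call would need to run in a worker thread (for example via `asyncio.to_thread`) for the timeout to return on time.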
How do I measure latency savings from multi-stage?
Compare end-to-end latency: (1) retrieve candidates and rerank all (single-stage), (2) multi-stage cascade. On typical systems, multi-stage is 2–5x faster for large candidate sets due to early filtering. Benchmark on your corpus and query distribution.
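A small harness is enough for this comparison. The sketch below times any `fn(query)` callable over a list of queries and reports median and p95 latency. The two `fake_*` functions are placeholders that only sleep; replace them with calls into your actual single-stage and multi-stage pipelines.

```python
import statistics
import time


def benchmark_ms(fn, queries: list[str], warmup: int = 1) -> dict[str, float]:
    """Run fn(query) over queries and report median and p95 latency in ms."""
    for q in queries[:warmup]:
        fn(q)  # warm caches before timing
    samples = []
    for q in queries:
        start = time.perf_counter()
        fn(q)
        samples.append((time.perf_counter() - start) * 1000)
    samples.sort()
    p95_index = min(len(samples) - 1, int(0.95 * len(samples)))
    return {"median_ms": statistics.median(samples), "p95_ms": samples[p95_index]}


def fake_single_stage(query: str) -> None:
    time.sleep(0.004)  # placeholder for reranking every candidate


def fake_multi_stage(query: str) -> None:
    time.sleep(0.001)  # placeholder for the cascade


queries = ["q1", "q2", "q3", "q4", "q5"]
single = benchmark_ms(fake_single_stage, queries)
multi = benchmark_ms(fake_multi_stage, queries)
print(f"median speedup: {single['median_ms'] / multi['median_ms']:.1f}x")
```

Report tail latency (p95 or p99) alongside the median, since a cascade that is fast on average can still miss its budget on long queries or large candidate sets.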
Should I use Reciprocal Rank Fusion (RRF) in a cascade?
Yes. Use RRF at stage 1 (fuse BM25 + dense in parallel) → top-100 → stage 2 → stage 3. This combines the speed of hybrid fusion with the accuracy of cascaded reranking.
Further Reading
- Cascade Ranking in Information Retrieval (Li, 2008) — Foundational work on cascade optimization
- Learning to Rank in Cascades (Hofmann et al., 2013) — Learning optimal cascade thresholds
- ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction — Efficient middle-stage reranker
- Elasticsearch Multi-Stage Query Guide — Production implementation of cascades