Input Drift Detection: Catch Distribution Shifts Early
Input drift detection is the art of spotting when user prompts shift in distribution before model outputs degrade. An input shift is often your first warning sign: if users start asking about topics your training data didn't cover, or they switch from formal English to code-mixed slang, your LLM's output quality is likely to suffer even if the model weights never change. This article teaches you to quantify input distributions, detect shifts using statistical tests, and act on them before users notice quality loss.
Why input drift matters for LLM applications
Input drift precedes output drift. When your user base changes (e.g., expansion into a new geographic market, or acquisition of a competitor's customers), or when new use cases emerge (e.g., your API suddenly serves code-generation requests instead of summarization), the nature of prompts you receive shifts. A summarization model trained on English news articles will struggle with noisy social-media text or mixed-language inputs if those arrive unexpectedly.
Unlike many classical supervised models, which can tolerate some input shift if they are robust to it, LLM applications are sensitive to prompt structure, style, and domain. A prompt that worked for casual users may be inadequate for technical users; a prompt engineered for GPT-3.5 may misfire on GPT-4, or vice versa. Detecting input drift early lets you update your system prompt, fine-tune on new examples, or segment users before output quality tanks.
Measuring input distributions: key metrics
To detect drift, first quantify inputs. Key metrics include:
Lexical features:
- Vocabulary diversity (unique tokens / total tokens)
- Average token count per prompt
- Readability score (e.g., Flesch-Kincaid grade level)
- Presence of code, URLs, or special characters
Semantic features:
- Embedding-based similarity to training distribution (cosine distance from training centroid)
- Topic distribution (using BERTopic or LDA)
- Intent distribution (if you have an intent classifier)
Metadata:
- User geography, language, device, or tier
- Time of day, day of week (temporal drift)
- Estimated user expertise (from query complexity)
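Readability and URL/code presence need a little more than `str.split()`. The sketch below computes lexical features per prompt using only the standard library. It assumes a vowel-group heuristic for syllables (so the Flesch-Kincaid grade is approximate) and two hypothetical regexes for URLs and code; `prompt_features` and `count_syllables` are illustrative names, not a library API. Keeping one feature dict per prompt, rather than a batch average, is what lets you run distribution tests such as KS later.

```python
import re

# Hypothetical heuristics: both regexes are rough and will misfire on some prompts
URL_RE = re.compile(r"https?://\S+|www\.\S+")
CODE_RE = re.compile(r"`{3}|\bdef\s+\w+\(|[{};]|=>")

def count_syllables(word):
    """Approximate syllables as groups of consecutive vowels (at least one)."""
    return max(1, len(re.findall(r"[aeiouy]+", word.lower())))

def prompt_features(prompt):
    """Per-prompt lexical features, kept unaggregated so they can feed a KS test."""
    words = re.findall(r"[A-Za-z']+", prompt)
    n_words = len(words)
    n_sentences = max(1, len(re.findall(r"[.!?]+", prompt)))
    if n_words:
        syllables = sum(count_syllables(w) for w in words)
        # Flesch-Kincaid grade level formula
        fk_grade = 0.39 * n_words / n_sentences + 11.8 * syllables / n_words - 15.59
        type_token_ratio = len({w.lower() for w in words}) / n_words
    else:
        fk_grade, type_token_ratio = 0.0, 0.0
    return {
        "n_words": n_words,
        "type_token_ratio": type_token_ratio,
        "fk_grade": fk_grade,
        "has_url": bool(URL_RE.search(prompt)),
        "has_code": bool(CODE_RE.search(prompt)),
    }

for p in ["What is machine learning?",
          "See https://example.com for details.",
          "def add(a, b): return a + b"]:
    print(prompt_features(p))
```

Syllable counting by vowel groups overcounts words like "machine", so treat the grade as a relative signal for comparing baseline and current windows, not an absolute reading level.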
```python
# Input drift metric collection
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

def measure_input_metrics(prompts, reference_centroid=None):
    """Compute summary metrics for a batch of prompts.

    reference_centroid: mean embedding of the baseline prompts. Without it,
    the batch's own centroid is used, which measures within-batch spread
    rather than distance from the baseline.
    """
    if not prompts:
        raise ValueError("prompts must be non-empty")

    # Lexical features ("tokens" here are whitespace-separated words)
    token_counts = [len(p.split()) for p in prompts]
    vocab_sizes = [len(set(p.lower().split())) for p in prompts]
    diversity = [v / t if t > 0 else 0.0 for v, t in zip(vocab_sizes, token_counts)]

    # Semantic features: with unit-length embeddings, 1 - dot product is cosine distance
    embeddings = model.encode(prompts, normalize_embeddings=True)
    centroid = embeddings.mean(axis=0) if reference_centroid is None else reference_centroid
    centroid = centroid / np.linalg.norm(centroid)
    cosine_distances = 1.0 - embeddings @ centroid

    # Crude code heuristic; ":" is excluded because ordinary prose uses it
    code_indicators = [any(c in p for c in "(){}[];") for p in prompts]

    return {
        "avg_token_count": float(np.mean(token_counts)),
        "avg_diversity": float(np.mean(diversity)),
        "avg_distance_from_centroid": float(np.mean(cosine_distances)),
        "pct_with_code": 100.0 * sum(code_indicators) / len(prompts),
    }

# Example usage
baseline_prompts = [
    "Summarize this article about AI",
    "What is machine learning?",
    "Explain neural networks",
]
current_prompts = [
    "def fibonacci(n): # Calculate nth Fibonacci number",
    "write a REST API in Python",
    "debug: why is my function slow?",
]
# Measure both batches against the baseline centroid
baseline_centroid = model.encode(baseline_prompts, normalize_embeddings=True).mean(axis=0)
baseline_metrics = measure_input_metrics(baseline_prompts, baseline_centroid)
current_metrics = measure_input_metrics(current_prompts, baseline_centroid)
print("Baseline:", baseline_metrics)
print("Current:", current_metrics)
```
Statistical tests for input drift
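Once you have metrics, apply statistical tests to detect shifts. These tests compare distributions of per-prompt values (e.g., each prompt's token count), not the batch averages returned by `measure_input_metrics`, so log the raw per-prompt features as well:
Kolmogorov-Smirnov (KS) test: Compare the empirical cumulative distribution functions (CDFs) of a numeric metric (e.g., token count) in the baseline vs. current data. The KS statistic is the largest vertical gap between the two CDFs, so it reacts to changes in location, spread, or shape; it is most sensitive near the middle of the distribution and comparatively weak at detecting changes confined to the tails.
Kullback-Leibler (KL) divergence: Measure how much one probability distribution differs from another; higher KL means more drift. KL is not a hypothesis test (it yields no p-value), so you must choose an alert threshold yourself, and for continuous metrics you must first bin the values. It is also asymmetric: D_KL(current || baseline) grows large when current inputs put probability mass where the baseline had little, which makes that direction useful for flagging inputs the baseline rarely covered. An empty baseline bin makes it infinite, so smooth the histograms.
Chi-squared test: For categorical metrics (e.g., intent distribution), test whether the baseline and current category counts come from the same distribution (a test of homogeneity on a 2 × k contingency table). Use raw counts, not percentages, and merge rare categories, since the test's approximation is unreliable when expected counts fall below about 5.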
```python
# Statistical drift detection
import numpy as np
from scipy.stats import ks_2samp, chi2_contingency

def detect_drift_ks_test(baseline_metric, current_metric, threshold=0.05):
    """
    KS test for numerical metrics.
    Returns (drifted, p_value, statistic); drifted is True if p_value < threshold.
    """
    statistic, p_value = ks_2samp(baseline_metric, current_metric)
    return p_value < threshold, p_value, statistic

def detect_drift_chi_squared(baseline_counts, current_counts, threshold=0.05):
    """
    Chi-squared test of homogeneity for categorical metrics (e.g., intent distribution).
    Inputs are raw counts per category, in the same category order.
    Returns (drifted, p_value).
    """
    contingency = np.array([baseline_counts, current_counts])
    chi2, p_value, dof, expected = chi2_contingency(contingency)
    return p_value < threshold, p_value

# Example
baseline_token_counts = [12, 15, 8, 20, 11, 9, 14]
current_token_counts = [45, 42, 50, 48, 55, 52]  # Much longer!
is_drifted, p_val, stat = detect_drift_ks_test(baseline_token_counts, current_token_counts)
print(f"KS test p-value: {p_val:.4f}, Drift detected: {is_drifted}")

# Intent distribution test
baseline_intents = [50, 30, 20]  # [question, request, bug-report]
current_intents = [20, 10, 70]  # Shift toward bug-reports
is_drifted, p_val = detect_drift_chi_squared(baseline_intents, current_intents)
print(f"Intent drift p-value: {p_val:.4f}, Drift detected: {is_drifted}")
```
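KL divergence has no ready-made two-sample test in SciPy, but `scipy.stats.entropy(pk, qk)` computes it once you have two histograms. The sketch below is a hypothetical helper, `kl_drift`, that bins one numeric metric over the combined range of both samples and adds a pseudo-count (0.5 here, an arbitrary assumption) to every bin so empty baseline bins stay finite. The pseudo-count and bin count both change the resulting value, so pick them once and keep them fixed when comparing windows.

```python
import numpy as np
from scipy.stats import entropy

def kl_drift(baseline_values, current_values, bins=10, pseudo_count=0.5):
    """
    Histogram estimate of D_KL(current || baseline), in nats, for one numeric metric.
    Bins span the combined range of both samples; a pseudo-count is added to
    every bin so empty baseline bins do not make the divergence infinite.
    """
    baseline_values = np.asarray(baseline_values, dtype=float)
    current_values = np.asarray(current_values, dtype=float)
    edges = np.histogram_bin_edges(
        np.concatenate([baseline_values, current_values]), bins=bins
    )
    current_hist, _ = np.histogram(current_values, bins=edges)
    baseline_hist, _ = np.histogram(baseline_values, bins=edges)
    # entropy(pk, qk) normalizes both inputs and returns sum(pk * log(pk / qk))
    return float(entropy(current_hist + pseudo_count, baseline_hist + pseudo_count))

baseline_token_counts = [12, 15, 8, 20, 11, 9, 14]
current_token_counts = [45, 42, 50, 48, 55, 52]
print(f"KL(current || baseline): {kl_drift(baseline_token_counts, current_token_counts):.3f}")
print(f"KL(baseline || baseline): {kl_drift(baseline_token_counts, baseline_token_counts):.3f}")
```

Because KL yields no p-value, one way to set a threshold is to compute it between several past baseline windows and alert when the current value exceeds what those stable windows produced.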
Embedding-based drift detection: semantic shift
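A complementary approach uses sentence embeddings to detect when prompts shift semantically, even if simple metrics (token count, etc.) don't change. Compute the mean embedding (centroid) and covariance of your baseline prompts, then measure how far each current prompt lies from that centroid using the Mahalanobis distance, which scales each direction by how much the baseline varies along it. Estimating a covariance for 384-dimensional embeddings reliably needs many more baseline prompts than dimensions; with small baselines, regularize the covariance with a small ridge term or fall back to cosine distance from the centroid.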
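```python
# Embedding-based drift detection using Mahalanobis distance
import numpy as np
from scipy.spatial.distance import mahalanobis
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

def compute_drift_score(baseline_prompts, current_prompts, ridge=1e-3):
    """
    Use Mahalanobis distance to detect multivariate embedding drift.

    ridge (an assumed default) keeps the covariance invertible when there are
    fewer baseline prompts than embedding dimensions. A pseudo-inverse would
    silently ignore directions the baseline never spans, hiding new topics.
    """
    # Baseline: mean and regularized covariance
    baseline_embeddings = model.encode(baseline_prompts)
    baseline_mean = np.mean(baseline_embeddings, axis=0)
    baseline_cov = np.cov(baseline_embeddings, rowvar=False)
    baseline_cov += ridge * np.eye(baseline_cov.shape[0])
    baseline_cov_inv = np.linalg.inv(baseline_cov)

    def distances_to_baseline(embeddings):
        return np.array([
            mahalanobis(e, baseline_mean, baseline_cov_inv) for e in embeddings
        ])

    # Outlier threshold comes from the baseline, not from the current batch
    baseline_distances = distances_to_baseline(baseline_embeddings)
    threshold = np.mean(baseline_distances) + 2 * np.std(baseline_distances)

    # Current: Mahalanobis distance of each prompt from the baseline
    distances = distances_to_baseline(model.encode(current_prompts))
    return {
        "mean_distance": float(np.mean(distances)),
        "baseline_mean_distance": float(np.mean(baseline_distances)),
        "std_distance": float(np.std(distances)),
        "max_distance": float(np.max(distances)),
        "pct_outliers": 100.0 * float(np.mean(distances > threshold)),
    }

# Toy example; a real baseline should contain hundreds of prompts or more
baseline = ["What is AI?", "Explain ML", "Tell me about deep learning"]
current = ["How do I trade crypto?", "Best NFT to buy?", "Blockchain explained"]
drift_score = compute_drift_score(baseline, current)
print(f"Drift score: {drift_score}")
```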
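Per-prompt Mahalanobis distances say how unusual individual prompts are, but not whether the batch as a whole moved further than chance would explain. One alternative, swapped in here rather than taken from the recipe above, is a permutation test on the cosine distance between the baseline and current centroids. The sketch assumes embeddings are already computed as NumPy arrays (one row per prompt); `centroid_cosine_distance` and `permutation_drift_test` are hypothetical names.

```python
import numpy as np

def centroid_cosine_distance(a, b):
    """Cosine distance between the mean embeddings of two groups of rows."""
    ca, cb = a.mean(axis=0), b.mean(axis=0)
    return 1.0 - float(ca @ cb / (np.linalg.norm(ca) * np.linalg.norm(cb)))

def permutation_drift_test(baseline_emb, current_emb, n_permutations=1000, seed=0):
    """
    Permutation test on centroid cosine distance. Returns (observed, p_value);
    a small p-value means random baseline/current splits rarely push the
    centroids as far apart as the real split does.
    """
    rng = np.random.default_rng(seed)
    observed = centroid_cosine_distance(baseline_emb, current_emb)
    pooled = np.vstack([baseline_emb, current_emb])
    n_base = len(baseline_emb)
    exceed = 0
    for _ in range(n_permutations):
        idx = rng.permutation(len(pooled))
        stat = centroid_cosine_distance(pooled[idx[:n_base]], pooled[idx[n_base:]])
        if stat >= observed:
            exceed += 1
    # Add-one correction keeps the p-value valid and never exactly zero
    p_value = (exceed + 1) / (n_permutations + 1)
    return observed, p_value

# Usage with real prompts, assuming `model` is the SentenceTransformer loaded earlier:
# observed, p = permutation_drift_test(model.encode(baseline), model.encode(current))
```

The test makes no assumption about the shape of the embedding distribution, at the cost of recomputing centroids once per permutation.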
Segmented input drift: when different user groups behave differently
Not all input drift is global. Different user cohorts may exhibit different drift patterns. For example:
- Enterprise users may migrate from simple queries to complex workflows with structured inputs (JSON, code).
- Mobile users may shift to shorter, more conversational prompts.
- International users may increase code-switching or non-English text.
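Segment your baseline and current data by user group, geography, or other attributes, then run drift tests separately. Opposite shifts can cancel out in aggregate: if enterprise prompts grow longer while mobile prompts grow shorter, the overall token-count distribution may barely move, and a single global test would miss both changes.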
```python
# Segmented drift detection
from collections import defaultdict

# Reuses detect_drift_ks_test from the statistical-tests example earlier.

def detect_segmented_input_drift(baseline_data, current_data, segment_key="user_tier"):
    """
    Detect input drift within user segments.
    baseline_data and current_data are lists of dicts with 'prompt' and segment_key.
    """
    # Group prompts by segment
    baseline_by_segment = defaultdict(list)
    for record in baseline_data:
        baseline_by_segment[record[segment_key]].append(record["prompt"])
    current_by_segment = defaultdict(list)
    for record in current_data:
        current_by_segment[record[segment_key]].append(record["prompt"])

    # Test each segment that appears in the current data
    results = {}
    for segment, current_prompts in current_by_segment.items():
        if segment not in baseline_by_segment:
            # A segment with no baseline is itself a shift worth flagging
            results[segment] = {"drifted": True, "p_value": None, "reason": "new segment"}
            continue
        baseline_tokens = [len(p.split()) for p in baseline_by_segment[segment]]
        current_tokens = [len(p.split()) for p in current_prompts]
        is_drifted, p_val, _ = detect_drift_ks_test(baseline_tokens, current_tokens)
        results[segment] = {"drifted": is_drifted, "p_value": p_val}
    return results

# Example (toy data; real segments need far more prompts each)
baseline = [
    {"prompt": "What is AI?", "user_tier": "free"},
    {"prompt": "Explain ML", "user_tier": "pro"},
    {"prompt": "Deep learning basics", "user_tier": "free"},
]
current = [
    {"prompt": "Build a Transformer from scratch", "user_tier": "free"},
    {"prompt": "Train a model with GPUs", "user_tier": "pro"},
]
results = detect_segmented_input_drift(baseline, current, segment_key="user_tier")
print(results)
```
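Running one test per segment multiplies the chance of a false alarm: with 20 independent segments and a 0.05 threshold, you expect about one spurious "drifted" flag even when nothing changed. One common remedy, not part of the recipe above, is the Benjamini-Hochberg procedure, which controls the false discovery rate across segments. A minimal pure-Python sketch, assuming you pass a `{segment: p_value}` mapping; `benjamini_hochberg` is a hypothetical helper name.

```python
def benjamini_hochberg(p_values, alpha=0.05):
    """
    Benjamini-Hochberg over a {segment: p_value} dict.
    Returns {segment: True if drift is still flagged after FDR control}.
    """
    ranked = sorted(p_values.items(), key=lambda item: item[1])
    m = len(ranked)
    # Find the largest rank k with p_(k) <= alpha * k / m
    cutoff = 0
    for k, (_, p) in enumerate(ranked, start=1):
        if p <= alpha * k / m:
            cutoff = k
    # Every segment ranked at or below the cutoff is flagged
    return {segment: rank <= cutoff for rank, (segment, _) in enumerate(ranked, start=1)}

# Usage with segmented results (skip segments without a p-value):
# p_values = {s: r["p_value"] for s, r in results.items() if r.get("p_value") is not None}
# flags = benjamini_hochberg(p_values)
print(benjamini_hochberg({"a": 0.001, "b": 0.02, "c": 0.04, "d": 0.5}))
```

In this example segment "b" (p = 0.02) stays flagged, whereas a Bonferroni cutoff of 0.05 / 4 = 0.0125 would drop it; Benjamini-Hochberg trades a few more false positives for noticeably more power.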
Dashboarding and alerting on input drift
Create a dashboard that visualizes input metrics over time. Key visualizations:
- Time-series plot: Token count, vocabulary diversity, embedding distance from centroid over rolling windows.
- Distribution plots: Histogram of current vs. baseline token counts side-by-side.
- Heatmap: Input metrics by user segment and time (reveals cohort-specific drift).
- Alert timeline: When KS tests or other detectors fired, correlated with system changes.
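The time-series and alert-timeline panels need one drift value per time window. A minimal sketch, assuming numeric timestamps (e.g., Unix seconds) and a per-prompt metric such as token count; `windowed_ks` is a hypothetical helper built on `scipy.stats.ks_2samp`, and each returned row maps onto one point in a time-series panel or one marker on the alert timeline.

```python
import numpy as np
from scipy.stats import ks_2samp

def windowed_ks(baseline_values, timestamps, values, window_size, alpha=0.05):
    """
    KS-test each fixed-size time window of a per-prompt metric against a
    fixed baseline sample. Returns one row per non-empty window.
    """
    order = np.argsort(timestamps)
    ts = np.asarray(timestamps)[order]
    vals = np.asarray(values, dtype=float)[order]
    rows = []
    start = ts[0]
    while start <= ts[-1]:
        in_window = (ts >= start) & (ts < start + window_size)
        if in_window.any():
            result = ks_2samp(baseline_values, vals[in_window])
            rows.append({
                "window_start": start,
                "n": int(in_window.sum()),
                "ks_stat": float(result.statistic),
                "p_value": float(result.pvalue),
                "alert": bool(result.pvalue < alpha),
            })
        start = start + window_size
    return rows

# Usage with daily windows over Unix-second timestamps (hypothetical variables):
# rows = windowed_ks(baseline_token_counts, prompt_timestamps, prompt_token_counts, 86400)
```

Plotting `ks_stat` alongside `p_value` helps, because on high-traffic days tiny shifts can produce small p-values while the statistic itself stays low.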
Key Takeaways
- Input drift is a leading indicator of output degradation; detect it before quality suffers.
- Measure inputs using lexical (token count, diversity), semantic (embeddings), and metadata features.
- Use statistical tests (KS test, chi-squared) and divergence measures (KL) to detect shifts; the tests give p-values, while KL needs an empirically chosen threshold.
- Embedding-based drift (Mahalanobis distance) captures semantic shifts that simple metrics miss.
- Segment drift detection by user cohort or domain to prevent false negatives in heterogeneous populations.
Frequently Asked Questions
How many examples do I need to reliably detect input drift?
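For the KS test, 30–50 examples per group is a rule of thumb, but more is better: with fewer than 20 examples, sensitivity drops sharply. Aim for 100+ if traffic allows. With very large windows (thousands of prompts), the opposite problem appears: trivially small shifts become statistically significant, so also watch the KS statistic itself as a measure of effect size.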
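To get a feel for how sample size affects the KS test, you can simulate it. The sketch below assumes normally distributed metric values and a mean shift measured in standard deviations; real prompt metrics such as token counts are usually skewed, so treat the output as rough guidance rather than a sizing rule. `ks_detection_rate` is a hypothetical helper.

```python
import numpy as np
from scipy.stats import ks_2samp

def ks_detection_rate(n, shift, trials=200, alpha=0.05, seed=0):
    """Fraction of simulated trials in which KS flags a mean shift of `shift` std devs."""
    rng = np.random.default_rng(seed)
    hits = 0
    for _ in range(trials):
        baseline = rng.normal(0.0, 1.0, size=n)
        current = rng.normal(shift, 1.0, size=n)
        if ks_2samp(baseline, current).pvalue < alpha:
            hits += 1
    return hits / trials

# Detection rate for a half-standard-deviation shift at several sample sizes
for n in (20, 50, 100, 200):
    print(n, ks_detection_rate(n, shift=0.5))
```

Replacing the synthetic normal draws with bootstrap resamples of your own baseline metric gives a more realistic estimate for your traffic.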
Can I use pre-trained embeddings for drift detection, or should I fine-tune?
Pre-trained embeddings (e.g., all-MiniLM-L6-v2) work well for generic drift detection. Fine-tune on your task data if drift signals seem noisy; a task-specific embedding space may be more sensitive to task-relevant shifts.
What if my baseline is small or biased?
Acknowledge the limitation in your monitoring config and start with looser alert thresholds. Collect live data for the first 2–4 weeks, rebuild the baseline from it, then set tighter thresholds. A biased baseline is better than no baseline.
Should input drift alerts page on-call, or just create a ticket?
Start with tickets; escalate to on-call only if input drift co-occurs with output degradation. Input drift alone is a signal to investigate, not necessarily a production incident.
How do I correlate input drift with root causes?
Log all system changes (model updates, prompt edits, API client versions). When drift is detected, query your logs for changes in the preceding days. BI and observability tools such as Looker or Grafana (via annotations) can overlay these change events on your drift charts so that coincidences stand out.
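The log query described above can be as simple as a time-window filter. A minimal sketch, assuming a change log stored as a list of dicts with a `datetime` under `"time"` and a 3-day default look-back (both assumptions); `changes_before_alert` is a hypothetical helper, and the entries below are made-up illustrations.

```python
from datetime import datetime, timedelta

def changes_before_alert(alert_time, change_log, lookback=timedelta(days=3)):
    """
    Return logged changes in the `lookback` window before a drift alert, newest first.
    change_log is assumed to be a list of dicts with a datetime under "time".
    """
    window_start = alert_time - lookback
    hits = [c for c in change_log if window_start <= c["time"] <= alert_time]
    return sorted(hits, key=lambda c: c["time"], reverse=True)

# Illustrative change log (made-up entries)
change_log = [
    {"time": datetime(2024, 5, 1, 9), "change": "system prompt v7"},
    {"time": datetime(2024, 5, 3, 14), "change": "API client 2.3 rollout"},
    {"time": datetime(2024, 4, 20), "change": "model update"},
]
for c in changes_before_alert(datetime(2024, 5, 4), change_log):
    print(c["time"], c["change"])
```

Listing the newest change first reflects a common heuristic: the most recent change is often the most likely culprit, though it is not proof of causation.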