Logit Masking: Control Token Probability Distributions
Logit masking is the foundational operation behind all token-level constrained decoding. When an LLM generates the next token, it outputs a vector of logits (raw, unnormalized scores), one for every token in its vocabulary (often 100,000 or more tokens). Logit masking sets the logits of forbidden tokens to negative infinity (in practice, sometimes a very large negative number). The softmax normalization step then assigns those tokens zero probability and redistributes all probability mass to the allowed tokens. Zeroing a logit would not work, because a logit of 0 still has a nonzero probability after softmax. This is how a model's free choice over the vocabulary becomes a structural constraint.
A logit is an unnormalized log-probability: it differs from the token's log-probability by a constant shared across the whole vocabulary. The softmax function converts logits to probabilities by exponentiating and normalizing. If you set a logit to −∞, its exponential becomes 0, so the token gets probability 0 after softmax and can never be sampled. This single operation, logit manipulation, is the atomic unit of constrained decoding.
From Logits to Probabilities: The Math
When an LLM processes a prompt and reaches the position to generate the next token, its final layer outputs a vector of logits: one scalar per vocabulary token. Let's say the vocab size is V (e.g., 128,000). The logits vector z has shape [V].
To convert logits to probabilities, the model applies softmax:
P(token_i) = exp(z_i) / sum(exp(z_j) for all j in vocab)
This ensures that the probabilities sum to 1 and preserves the relative ordering of the logits (the highest logit gets the highest probability).
Without masking: Suppose the logit vector is z = [2.5, 1.3, 0.8, 3.2, ...] for tokens ["apple", "is", "red", "dog", ...]. After softmax, "dog" (3.2) gets the highest probability, and every other token remains reachable, just less likely.
With masking: If your constraint requires the token to be a food word, you mask out "dog" by setting its logit to −∞. The masked vector becomes z = [2.5, 1.3, 0.8, −∞, ...]. When softmax is applied:
exp(−∞) = 0
P("dog") = 0 / (sum of positive exponentials) = 0 (exactly zero)
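Now the unmasked tokens share all the probability mass, and their relative odds are unchanged. (A real food-word constraint would also mask "is" and "red"; only "dog" is masked here to keep the example short.) The model's internal preference for "dog" is overridden: the token can no longer be generated.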
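The arithmetic can be checked in plain Python. This minimal sketch uses only the four example tokens (a real vocabulary is far larger) and, as in the text, masks only "dog". Comparing the two distributions also shows that the surviving tokens keep their relative odds:

```python
import math

def softmax(logits):
    """Softmax that tolerates -inf entries, provided at least one logit is finite."""
    m = max(z for z in logits if z != float("-inf"))  # largest finite logit
    exps = [math.exp(z - m) for z in logits]  # math.exp(-inf) returns 0.0
    total = sum(exps)
    return [e / total for e in exps]

tokens = ["apple", "is", "red", "dog"]
logits = [2.5, 1.3, 0.8, 3.2]

unmasked = softmax(logits)
# Mask "dog" by replacing its logit with -inf
masked = softmax([float("-inf") if t == "dog" else z for t, z in zip(tokens, logits)])

for t, p, q in zip(tokens, unmasked, masked):
    print(f"{t:>5}: {p:.3f} -> {q:.3f}")
```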
Implementing Logit Masking: Pseudocode
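Here is one step of a decoding loop with masking applied, written for a PyTorch model whose output has a .logits field (as Hugging Face causal language models do):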
# Simplified LLM token generation with logit masking
import torch

def generate_constrained_token(
    model,
    input_ids,
    constraint_mask,  # Boolean: True = valid, False = invalid
    temperature=1.0,
    greedy=False,
):
    """
    Generate one token, enforcing constraint_mask.

    Args:
        model: causal LM whose output has a .logits field
        input_ids: LongTensor [1, seq_len], token IDs generated so far
        constraint_mask: Boolean tensor [vocab_size], True for allowed tokens
        temperature: Softmax temperature (>1 = flatter, <1 = sharper), must be > 0
        greedy: If True, pick the most probable valid token instead of sampling
    Returns:
        next_token_id: Integer, guaranteed to satisfy constraint_mask
    """
    if not constraint_mask.any():
        raise ValueError("constraint_mask allows no tokens")

    # Step 1: Forward pass; keep the logits for the last position only
    with torch.no_grad():
        outputs = model(input_ids)
    # Upcast to float32, since half-precision logits are common
    logits = outputs.logits[0, -1, :].float()  # Shape: [vocab_size]

    # Step 2: Apply the constraint mask by setting invalid logits to -inf
    mask = constraint_mask.to(logits.device)
    masked_logits = logits.clone()
    masked_logits[~mask] = float("-inf")

    # Step 3: Temperature scaling (-inf / temperature stays -inf)
    masked_logits = masked_logits / temperature

    # Step 4: Numerically stable softmax. The max is finite because at least
    # one token is allowed; exp(-inf) = 0, so masked tokens get zero probability.
    # (torch.softmax(masked_logits, dim=-1) performs the same computation.)
    shifted_logits = masked_logits - masked_logits.max()
    exp_logits = torch.exp(shifted_logits)
    probabilities = exp_logits / exp_logits.sum()

    # Step 5: Greedy-pick or sample from valid tokens only
    if greedy:
        next_token_id = int(torch.argmax(probabilities))
    else:
        next_token_id = int(torch.multinomial(probabilities, num_samples=1))

    # Sanity check: guarantee it's valid
    assert mask[next_token_id], "Mask violation!"
    return next_token_id
This pattern repeats: after generating each token, recompute the constraint based on the updated sequence and mask the logits again.
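The per-step recomputation can be sketched end to end without a real model. In this minimal sketch, toy_logits is a hypothetical stand-in for the model's forward pass (it always prefers the letter "a"), and allowed is a hypothetical constraint: exactly three digits, then an end-of-sequence token. Greedy selection keeps the result deterministic:

```python
VOCAB = [str(d) for d in range(10)] + ["a", "<eos>"]
EOS = VOCAB.index("<eos>")

def toy_logits(prefix):
    """Hypothetical model stand-in: ignores prefix, strongly prefers "a"."""
    return [0.1 * d for d in range(10)] + [5.0, 0.0]  # "0".."9", "a", "<eos>"

def allowed(prefix):
    """Hypothetical constraint: exactly three digits, then <eos>."""
    if len(prefix) < 3:
        return [tok.isdigit() for tok in VOCAB]
    return [tid == EOS for tid in range(len(VOCAB))]

def constrained_greedy_decode(max_steps=10):
    prefix = []  # generated token ids
    for _ in range(max_steps):
        logits = toy_logits(prefix)
        mask = allowed(prefix)  # recomputed from the updated sequence at every step
        masked = [z if ok else float("-inf") for z, ok in zip(logits, mask)]
        next_id = max(range(len(VOCAB)), key=lambda i: masked[i])  # greedy pick
        prefix.append(next_id)
        if next_id == EOS:
            break
    return [VOCAB[i] for i in prefix]

print(constrained_greedy_decode())
```

Without the mask, the toy model would pick "a" at every step; with it, the output is always three digits followed by <eos>.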
Building Constraint Masks: From Grammars to Booleans
The trick is computing constraint_mask efficiently. For simple cases:
Enum constraint (finite set of valid tokens):
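# Constraint: next word must be "apple", "banana", or "cherry".
# Assumes each word (with a leading space, as most BPE tokenizers encode
# mid-sentence words) encodes to exactly one token; [0] keeps only the first.
valid_token_ids = [tokenizer.encode(w, add_special_tokens=False)[0]
                   for w in [" apple", " banana", " cherry"]]
constraint_mask = torch.zeros(model.config.vocab_size, dtype=torch.bool)
constraint_mask[valid_token_ids] = True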
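The single-token assumption breaks as soon as an allowed word spans several tokens. Each step must then allow only the next token of some value that is still consistent with what has been generated, which amounts to walking a prefix tree of token sequences. A minimal sketch, assuming each value has already been tokenized (enum_mask is a hypothetical helper and the token IDs are made up for illustration):

```python
def enum_mask(allowed_seqs, generated, vocab_size, end_id):
    """
    Boolean mask for an enum whose values may span several tokens.

    allowed_seqs: one token-ID sequence per allowed value
    generated: token IDs produced so far for this field
    end_id: token that closes the field (e.g., an end-of-sequence token)
    """
    mask = [False] * vocab_size
    n = len(generated)
    for seq in allowed_seqs:
        if seq[:n] != generated:
            continue  # this value is no longer reachable
        if n < len(seq):
            mask[seq[n]] = True  # next token of a value that is still possible
        else:
            mask[end_id] = True  # a complete value was generated; closing is allowed
    return mask

# Hypothetical token IDs: "banana" and "cherry" share their first token (20).
seqs = {"apple": [10], "banana": [20, 21], "cherry": [20, 30]}
for generated in ([], [20], [20, 21]):
    m = enum_mask(list(seqs.values()), generated, vocab_size=40, end_id=0)
    print(generated, [i for i, ok in enumerate(m) if ok])
```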
Regex constraint:
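import regex  # third-party "regex" package; the stdlib re module has no partial matching
import torch

def regex_constraint_mask(prefix_text, regex_pattern, tokenizer):
    """
    Which tokens can extend prefix_text so that the result could still
    become a full match of regex_pattern?
    Brute force: try each token and run a partial match on (prefix + token).
    """
    mask = torch.zeros(tokenizer.vocab_size, dtype=torch.bool)
    for token_id in range(tokenizer.vocab_size):
        # Caveat: decoding one token in isolation can differ from how it
        # renders in context (leading-space markers, partial UTF-8 bytes)
        token_text = tokenizer.decode([token_id])
        candidate = prefix_text + token_text
        # partial=True also accepts strings that are prefixes of a possible match
        if regex.fullmatch(regex_pattern, candidate, partial=True):
            mask[token_id] = True
    return mask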
JSON schema constraint:
import json
import torch

def json_constraint_mask(prefix_text, schema, tokenizer):
    """
    Which tokens keep (prefix + token) a valid JSON prefix matching schema?
    schema_matches() and could_complete_to_valid_json() are placeholders for
    a schema validator and an incremental (streaming) JSON checker; the
    stdlib json module only parses complete documents.
    """
    mask = torch.zeros(tokenizer.vocab_size, dtype=torch.bool)
    for token_id in range(tokenizer.vocab_size):
        token_text = tokenizer.decode([token_id])
        candidate = prefix_text + token_text
        try:
            parsed = json.loads(candidate)  # Parse once
        except json.JSONDecodeError:
            # Incomplete JSON (expected during generation):
            # keep the token if some completion could still be valid
            if could_complete_to_valid_json(candidate, schema):
                mask[token_id] = True
        else:
            # Complete JSON document; also check it against the schema
            if schema_matches(parsed, schema):
                mask[token_id] = True
    return mask
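The bottleneck is checking all 100,000+ vocabulary tokens at every step. Production systems avoid this scan by compiling the constraint ahead of time. Outlines, for example, converts a regular expression into a finite-state machine and precomputes, for each state, which tokens are allowed there, so each decoding step becomes a lookup plus a state transition. XGrammar handles context-free grammars such as JSON with a pushdown automaton and caches most per-token checks in advance.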
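A minimal sketch of the precompilation idea, assuming a toy vocabulary and a hand-built DFA for the pattern [0-9]+(\.[0-9]+)?. The names VOCAB, step, walk, build_index, INDEX and MASKS are illustrative and are not the Outlines or XGrammar API; real libraries build the automaton from the pattern automatically and handle tokenizer details this sketch ignores:

```python
# Toy vocabulary (hypothetical); a real tokenizer has 100k+ entries.
VOCAB = ["0", "1", "12", "3.", ".5", "7", "a", ".", "<eos>"]
EOS = VOCAB.index("<eos>")

# Hand-built DFA for the regular expression [0-9]+(\.[0-9]+)?
# States: 0 = start, 1 = in integer part, 2 = just read ".", 3 = in fraction part.
STATES = (0, 1, 2, 3)
ACCEPTING = {1, 3}

def step(state, ch):
    """Next DFA state after reading one character, or None if ch is rejected."""
    if ch in "0123456789":
        return {0: 1, 1: 1, 2: 3, 3: 3}[state]
    if ch == "." and state == 1:
        return 2
    return None

def walk(state, text):
    """Feed every character of a token through the DFA."""
    for ch in text:
        state = step(state, ch)
        if state is None:
            return None
    return state

def build_index(vocab):
    """Offline, once: for each state, map every allowed token id to its next state."""
    index = {}
    for s in STATES:
        allowed = {}
        for tid, text in enumerate(vocab):
            if tid == EOS:
                if s in ACCEPTING:
                    allowed[tid] = None  # stopping is allowed only after a complete match
                continue
            nxt = walk(s, text)
            if nxt is not None:
                allowed[tid] = nxt
        index[s] = allowed
    return index

INDEX = build_index(VOCAB)
# Precomputed boolean masks, so decode time never scans the vocabulary.
MASKS = {s: [tid in INDEX[s] for tid in range(len(VOCAB))] for s in STATES}

# Decode time: following "12" then ".5" is one lookup plus one transition per step.
state = 0
for tid in (VOCAB.index("12"), VOCAB.index(".5")):
    assert MASKS[state][tid]
    state = INDEX[state][tid]
print(state, MASKS[state][EOS])
```

Everything expensive happens in build_index, once per pattern and vocabulary; each decoding step then needs only MASKS[state] and INDEX[state][token_id].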
Logit Masking in Practice: Temperature and Sampling
Temperature is a hyperparameter that controls generation randomness. Dividing the logits by the temperature flattens the distribution (temperature > 1) or sharpens it (temperature < 1):
- Temperature = 1.0: Preserves the original probabilities (the default).
- Temperature = 0.5: Sharpens: high-logit tokens get even more of the probability mass, so behavior moves toward greedy decoding (the limit as temperature approaches 0).
- Temperature = 2.0: Flattens: probabilities move closer to uniform, giving more random exploration.
With constraints, temperature still applies to valid tokens only:
masked_logits = logits.clone()
masked_logits[~constraint_mask] = float('-inf') # Mask first
masked_logits = masked_logits / temperature # Then scale
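Example: Suppose 3 valid tokens have logits [0.5, 1.5, 2.5] and the temperature is 0.5:
- Scaled logits: [1.0, 3.0, 5.0]
- After softmax: about [0.02, 0.12, 0.87], compared with about [0.09, 0.24, 0.67] at temperature 1.0, so the probability mass concentrates on the third token.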
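A short script makes the effect concrete for the three valid logits in this example at temperatures 1.0, 0.5 and 2.0. This is a minimal plain-Python sketch (softmax_with_temperature is an illustrative helper, not a library function):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Divide by the temperature, then apply a max-shifted softmax."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

valid_logits = [0.5, 1.5, 2.5]
for t in (1.0, 0.5, 2.0):
    probs = softmax_with_temperature(valid_logits, t)
    print(t, [round(p, 2) for p in probs])
```

Lower temperature concentrates probability on the highest logit; higher temperature spreads it out. The ordering of the tokens never changes.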
Greedy decoding (always pick the max logit) works well with constraints if the model is confident. Sampling (multinomial) is useful for diverse outputs while respecting constraints.
Numerical Stability and Optimization
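A key challenge is numerical: a naive softmax can return NaN, either because exp overflows for large logits (giving inf / inf) or because every token is masked (the normalizing sum is 0, and with the max-shift trick −∞ − (−∞) is already NaN). Production code uses:
- Subtract the max logit before exp: exp(logits - max) cannot overflow.
- Mask invalid tokens before the softmax sum: ensure their contribution is exactly zero.
- Log-space softmax: compute log-probabilities directly when you need them, which avoids underflow for very unlikely tokens.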
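# Numerically stable masked softmax
import torch

def stable_masked_softmax(logits, mask):
    """logits: [vocab] float tensor, mask: [vocab] bool tensor (True = allowed)"""
    if not mask.any():
        raise ValueError("mask allows no tokens; softmax is undefined")
    # Finite sentinel instead of -inf (a common alternative; -inf also works
    # once at least one token is allowed)
    safe_logits = logits.clone()
    safe_logits[~mask] = torch.finfo(logits.dtype).min
    # Subtract the max (always a valid logit) so exp() cannot overflow
    safe_logits = safe_logits - safe_logits.max()
    exp_logits = torch.exp(safe_logits)
    exp_logits[~mask] = 0.0  # Ensure invalid tokens have exactly 0 probability
    probs = exp_logits / exp_logits.sum()
    return probs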
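Working in log-space, the third technique in the list above, can be sketched in plain Python. This minimal version (masked_log_softmax is an illustrative helper) computes masked log-probabilities with the log-sum-exp trick; in PyTorch, torch.log_softmax applied to -inf-masked logits does the equivalent max-shifted computation:

```python
import math

def masked_log_softmax(logits, mask):
    """Log-probabilities over allowed tokens; -inf for masked tokens."""
    allowed = [z for z, ok in zip(logits, mask) if ok]
    if not allowed:
        raise ValueError("mask allows no tokens")
    m = max(allowed)
    # log-sum-exp over allowed tokens only, shifted by the max so exp() cannot overflow
    log_norm = m + math.log(sum(math.exp(z - m) for z in allowed))
    return [z - log_norm if ok else float("-inf") for z, ok in zip(logits, mask)]

# math.exp(1000.0) alone would raise OverflowError; the shifted form is fine.
log_probs = masked_log_softmax([1000.0, 999.0, 5.0], [True, True, False])
print(log_probs)
```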
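For very large vocabularies, some implementations also use vocabulary pruning: instead of masking and normalizing over the full vocabulary, they gather only the tokens the constraint allows and apply softmax over that subset. The model's forward pass still produces every logit, so the saving is in the masking and sampling work, and it grows with how restrictive the constraint is.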
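A minimal sketch of pruning, assuming the allowed token IDs are already known (for example, from a precomputed per-state index); pruned_softmax is a hypothetical helper. The result matches a full masked softmax, but only the allowed entries are gathered and normalized:

```python
import math
import random

def pruned_softmax(logits, allowed_ids):
    """Softmax over the allowed token ids only; returns {token_id: probability}."""
    sub = [logits[i] for i in allowed_ids]  # gather: touches only the allowed entries
    m = max(sub)
    exps = [math.exp(z - m) for z in sub]
    total = sum(exps)
    return {tid: e / total for tid, e in zip(allowed_ids, exps)}

# allowed_ids would come from a precompiled constraint, not from scanning the vocabulary.
probs = pruned_softmax([2.5, 1.3, 0.8, 3.2], allowed_ids=[0, 1, 2])
next_id = random.choices(list(probs), weights=list(probs.values()), k=1)[0]
print(probs, next_id)
```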
Key Takeaways
- Logits are raw, unnormalized model scores; softmax converts them to probabilities.
- Masking is done by setting invalid logit values to negative infinity, forcing their softmax probability to exactly zero.
- Constraints are translated into boolean masks: for each token position, compute which tokens are grammatically valid.
- Brute-force masking checks all vocabulary tokens at each step; production systems pre-compile constraints into efficient state machines.
- Temperature controls randomness among valid tokens; numerical stability requires care with large vocabularies and many masked values.
Frequently Asked Questions
Can logit masking slow down generation significantly?
Yes, if done naively. Checking 100,000+ tokens at each step costs O(vocab_size) work per generated token; for a 100-token output, that is 10 million checks. Production systems precompile constraints into automata so that each step is mostly a table lookup, which typically keeps the overhead small compared with the model's forward pass. Vocabulary pruning can cut the remaining cost further.
What if my constraint rules out all tokens at some position?
This indicates a logic error in your constraint definition or a mismatch between the grammar and how the tokenizer actually splits text. For example, if you require one of three exact values but the tokenizer splits them into pieces your constraint does not anticipate (or attaches a leading space to the first piece), no token may be valid. To debug, trace the sequence of masks and find the step where the allowed set becomes empty, then widen the constraint or adjust your grammar.
Does logit masking work with sampling or only greedy decoding?
Both. Softmax converts masked logits to a valid probability distribution over allowed tokens, so multinomial sampling works. Greedy (argmax) also works: the max valid logit is always higher than negative infinity. Sampling is more diverse; greedy is more consistent.
Can I combine multiple constraints (e.g., JSON schema AND regex)?
Yes. Compute the masks separately and AND them together: combined_mask = json_mask & regex_mask. A token is allowed only if it satisfies every constraint. This is useful for hierarchical constraints (e.g., "must be valid JSON and match a specific field pattern"). If the intersection is empty, the constraints conflict at that step.
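With PyTorch boolean tensors, the & operator does this directly. For plain Python lists, a minimal sketch (combine_masks is a hypothetical helper, and the masks are toy values) can also check for an empty intersection, which surfaces the "no valid token" situation from the previous question at the exact step where it happens:

```python
def combine_masks(*masks):
    """Elementwise AND of boolean masks; fails fast if no token satisfies all of them."""
    combined = [all(bits) for bits in zip(*masks)]
    if not any(combined):
        raise ValueError("constraints are jointly unsatisfiable at this step")
    return combined

json_mask = [True, True, False, True]
regex_mask = [False, True, True, True]
print(combine_masks(json_mask, regex_mask))
```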
Why not just fine-tune the model to always output valid JSON?
Fine-tuning helps but doesn't guarantee correctness: even a well-tuned model occasionally produces output that violates the required format. Logit masking makes violations structurally impossible at every step, which is a stronger guarantee than behavioral training. (Output cut off by a maximum-token limit can still be incomplete.)