Synthetic Data Generation: Distillation Step-by-Step
Synthetic data generation is the bridge between teacher and student. Rather than training the student solely on hard labels, you generate a large corpus of (input, teacher output) pairs, capturing the teacher's reasoning across a diverse distribution of examples. This synthetic training set becomes the curriculum for student training. The quality and scale of synthetic data directly determine student quality: a larger, more diverse synthetic dataset yields a student that generalizes better and retains more teacher knowledge.
Why Synthetic Data Matters for Distillation
In traditional supervised learning, you train on labeled data produced by humans or automated annotators. In distillation, the teacher itself becomes the labeler. This inversion has profound implications: you can generate unlimited training data (not bounded by human annotation budgets), you can control the data distribution (to cover edge cases), and you can create examples that stress-test the teacher's reasoning. However, synthetic data has pitfalls: if the data distribution does not match deployment, the student will learn teacher artifacts rather than robust patterns. If the teacher is overconfident or prone to errors on certain inputs, the student will inherit those biases.
Effective synthetic data generation balances three objectives: scale (enough examples to train the student), diversity (varied prompts and domains to ensure generalization), and distribution alignment (the synthetic data matches real-world deployment inputs). A common mistake is generating synthetic data purely from random prompts, which often yields trivial or off-distribution examples. Instead, grounded synthetic data—generated from templates based on real user queries, known edge cases, or domain-specific distributions—produces better students.
Step 1: Define Your Data Distribution
Before generating data, characterize the distribution you want the student to excel at. If you are building a customer-support chatbot, your distribution should match real customer queries (frequent topics, common misspellings, informal language). If you are fine-tuning for code generation, your distribution should mirror real code tasks (function signatures, libraries in use, difficulty levels).
from collections import Counter
import numpy as np
from transformers import AutoTokenizer
# Load or define a representative sample of real-world inputs
real_inputs = [
    "How do I reset my password?",
    "Why is my account locked?",
    "Can I change my email address?",
    # ... more real queries from production logs
]
# Analyze the distribution: topics, length, complexity
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
lengths = [len(tokenizer.encode(inp)) for inp in real_inputs]
topics = [q.split()[0] for q in real_inputs]  # Naive: first word as topic
print(f"Average input length: {sum(lengths) / len(lengths):.1f} tokens")
print(f"Topic distribution: {Counter(topics).most_common(5)}")
print(f"Length percentiles: 25th={np.percentile(lengths, 25)}, "
      f"75th={np.percentile(lengths, 75)}")
Suppose this analysis shows that 70% of queries are short (10-30 tokens) and focus on account issues. Your synthetic data should then reflect the same split: roughly 70% short account queries and 30% other domains. If you ignore the measured distribution and generate purely random prompts, the student may end up well trained on rare cases (30-token code reviews) but weak on common cases (10-token account questions).
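One way to turn a measured distribution into a generation plan is to split a fixed example budget across buckets in proportion to their observed shares. The sketch below is a minimal illustration under stated assumptions: the helper `allocate_quota`, the bucket names, and the 50,000-example budget are all hypothetical and not part of the original pipeline.

```python
from typing import Dict


def allocate_quota(shares: Dict[str, float], total: int) -> Dict[str, int]:
    """Split `total` examples across buckets in proportion to `shares`."""
    weight_sum = sum(shares.values())
    if weight_sum <= 0:
        raise ValueError("shares must sum to a positive number")
    # Floor each bucket's exact share, then hand the leftover examples
    # to the buckets with the largest fractional remainders.
    exact = {name: total * w / weight_sum for name, w in shares.items()}
    quota = {name: int(value) for name, value in exact.items()}
    leftover = total - sum(quota.values())
    by_remainder = sorted(exact, key=lambda n: exact[n] - quota[n], reverse=True)
    for name in by_remainder[:leftover]:
        quota[name] += 1
    return quota


# Hypothetical shares mirroring the 70/30 split discussed above
measured_shares = {"short_account": 0.7, "other": 0.3}
quota = allocate_quota(measured_shares, total=50_000)
print(quota)
```

The largest-remainder rounding guarantees that the per-bucket counts always add up to the requested total, which simple rounding does not.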
Step 2: Generate Synthetic Examples via Teacher
You have two approaches: template-based and free-form generation. Template-based is faster and more controllable; free-form is more flexible but noisier.
Template-Based Approach:
import random
import re
from typing import Dict, List
# Define templates reflecting real-world distribution.
# A placeholder is written {name} or {name: hint}. The text after the colon is
# only a reminder for template authors; values are drawn from `substitutions`.
account_templates = [
    "How do I reset my {resource: password/PIN/security question}?",
    "My {resource} is {status: lost/forgotten/compromised}. What should I do?",
    "Can I change my {resource}?",
]
billing_templates = [
    "How much does {feature} cost?",
    "What are the terms of {plan: basic/pro/enterprise}?",
    "How do I upgrade to {plan}?",
]
# Matches both {name} and {name: hint}, capturing only the name
PLACEHOLDER_PATTERN = re.compile(r"\{(\w+)(?::[^}]*)?\}")
def expand_templates(templates: List[str],
                     substitutions: Dict[str, List[str]],
                     num_per_template: int = 10) -> List[str]:
    """Expand templates by sampling a substitution for every placeholder."""
    def fill(match: re.Match) -> str:
        # Raises KeyError if a template uses a placeholder with no substitutions
        return random.choice(substitutions[match.group(1)])
    examples = []
    for template in templates:
        for _ in range(num_per_template):
            examples.append(PLACEHOLDER_PATTERN.sub(fill, template))
    return examples
substitutions = {
    "resource": ["password", "PIN", "security question", "email address"],
    "status": ["lost", "forgotten", "compromised", "reset"],
    "feature": ["basic storage", "priority support", "API access"],
    "plan": ["basic", "pro", "enterprise"],
}
synthetic_inputs = expand_templates(
    account_templates + billing_templates,
    substitutions,
    num_per_template=20
)
print(f"Generated {len(synthetic_inputs)} template-based examples")
Template-based generation is fast and gives you direct control over which topics and phrasings are covered. The downside: templates are labor-intensive to write and may not capture natural language variation. As a rough starting point, use templates for about 50% of the synthetic data and pair them with free-form generation for variety.
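Because template expansion samples substitutions with replacement from short option lists, the same prompt is likely to appear more than once. A minimal sketch of an order-preserving deduplication pass follows; the helper name `dedupe_prompts` and the normalization rule (lowercasing and collapsing whitespace) are assumptions chosen for illustration.

```python
from typing import Iterable, List


def dedupe_prompts(prompts: Iterable[str]) -> List[str]:
    """Drop repeated prompts, keeping the first occurrence of each."""
    seen = set()
    unique = []
    for prompt in prompts:
        # Normalize case and whitespace so trivial variants count as duplicates
        key = " ".join(prompt.lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(prompt)
    return unique


# Hypothetical expanded prompts containing one near-duplicate
expanded = [
    "How do I reset my password?",
    "How do I reset my PIN?",
    "how do I reset my  password?",
]
unique_prompts = dedupe_prompts(expanded)
print(f"{len(unique_prompts)} unique prompts out of {len(expanded)}")
```

Running a pass like this before querying the teacher avoids spending generation budget on identical inputs.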
Free-Form Generation from the Teacher:
import json
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load teacher
teacher = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
def generate_synthetic_data(
    teacher,
    tokenizer,
    prompts: List[str],
    num_outputs_per_prompt: int = 5,
    max_length: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9
) -> List[dict]:
    """
    Query teacher to generate synthetic examples.
    Returns:
        List of {"input": prompt, "output": teacher_response}
    """
    teacher.eval()
    synthetic_data = []
    for prompt in prompts:
        # Tokenize once per prompt and reuse it for every sampled output
        inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
        prompt_len = inputs.input_ids.shape[1]
        for _ in range(num_outputs_per_prompt):
            with torch.no_grad():
                # Generate diverse outputs via temperature and top-p sampling.
                # Note: max_length counts prompt tokens plus generated tokens.
                output_ids = teacher.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            # Strip the prompt tokens so only the teacher's response is kept
            response = tokenizer.decode(
                output_ids[0][prompt_len:],
                skip_special_tokens=True
            )
            synthetic_data.append({
                "input": prompt,
                "output": response.strip()
            })
    return synthetic_data
# Generate synthetic data from a mix of seeds
seed_prompts = synthetic_inputs[:500]  # Use templates as seeds
synthetic_data = generate_synthetic_data(
    teacher,
    tokenizer,
    seed_prompts,
    num_outputs_per_prompt=3,
    temperature=0.8
)
print(f"Generated {len(synthetic_data)} synthetic training examples")
# Save to disk as JSON Lines (one example per line)
with open("synthetic_data.jsonl", "w") as f:
    for example in synthetic_data:
        f.write(json.dumps(example) + "\n")
Key parameters for generation:
- Temperature: 0.7-0.9 encourages diversity; <0.5 produces repetitive outputs.
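- top_p: 0.85-0.95 balances coherence and variety. A value of 1.0 disables nucleus filtering, so sampling draws from the full distribution, including its low-probability tail, which makes incoherent outputs more likely.
- max_length: Set based on your task. For summaries, 50-100; for code, 200-500. In the transformers generate call, max_length counts the prompt tokens as well as the generated ones; use max_new_tokens to bound only the response.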
- num_outputs_per_prompt: 3-10 per prompt usually suffices. More outputs yield marginal improvements and higher cost.
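To build intuition for how temperature and top_p reshape the sampling distribution, the following sketch applies both to a toy set of next-token scores. It is a simplified illustration in pure Python, not the implementation used by any particular library; the function names and the toy logits are hypothetical.

```python
import math
from typing import List


def apply_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs: List[float], top_p: float) -> List[float]:
    """Keep the most likely tokens until their mass reaches top_p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered


toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical next-token scores
for t in (0.5, 0.8, 1.0):
    probs = apply_temperature(toy_logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
nucleus = top_p_filter(apply_temperature(toy_logits, 0.8), top_p=0.9)
print("top_p=0.9:", [round(p, 3) for p in nucleus])
```

Lower temperatures concentrate probability on the top-scoring token, while top_p removes the low-probability tail before sampling.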
Step 3: Validate Synthetic Data Quality
Before training the student, audit the synthetic data. Check for:
- Coverage: Does the synthetic data cover the expected input distribution?
import numpy as np
# Analyze synthetic data statistics
synthetic_lengths = [
    len(tokenizer.encode(ex["input"])) for ex in synthetic_data
]
synthetic_output_lengths = [
    len(tokenizer.encode(ex["output"])) for ex in synthetic_data
]
print(f"Synthetic input length: mean={np.mean(synthetic_lengths):.1f}, "
      f"std={np.std(synthetic_lengths):.1f}")
print(f"Synthetic output length: mean={np.mean(synthetic_output_lengths):.1f}")
# Compare to real data distribution
real_lengths = [len(tokenizer.encode(inp)) for inp in real_inputs]
print(f"Real input length: mean={np.mean(real_lengths):.1f}, "
      f"std={np.std(real_lengths):.1f}")
# If the mean lengths differ by more than 10 tokens, regenerate with adjusted sampling
if abs(np.mean(synthetic_lengths) - np.mean(real_lengths)) > 10:
    print("WARNING: Synthetic data length distribution differs from real data")
- Coherence: Sample and manually review outputs. Do they make sense?
# Random sample for human review
import random
sample = random.sample(synthetic_data, min(100, len(synthetic_data)))
for ex in sample[:10]:
    print(f"Input: {ex['input']}")
    print(f"Output: {ex['output']}")
    print("---")
- Uniqueness: Are there many duplicates?
# Count exact-match duplicates among teacher outputs
outputs = [ex["output"] for ex in synthetic_data]
unique_outputs = len(set(outputs))
print(f"Unique outputs: {unique_outputs}/{len(outputs)} "
      f"({100*unique_outputs/len(outputs):.1f}%)")
# If duplication is high (>20%), increase temperature or reduce repeats
- Teacher Errors: Does the teacher make mistakes in synthetic data?
If the teacher is fallible (as all models are), synthetic data will inherit biases. One mitigation: train the student on a mixture of synthetic data and real labeled data (if available). Another: use an ensemble of teachers to reduce individual teacher errors.
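Some obvious failures can be removed automatically before any manual review. The sketch below is one possible heuristic filter, assuming examples are dicts with "input" and "output" keys as in the generation code above; the function name `basic_quality_filter` and the 10-character threshold are hypothetical choices, and a filter like this does not catch factual teacher errors.

```python
from typing import Dict, List


def basic_quality_filter(
    examples: List[Dict[str, str]],
    min_output_chars: int = 10,
) -> List[Dict[str, str]]:
    """Drop empty, prompt-echoing, and exactly duplicated examples."""
    kept = []
    seen_pairs = set()
    for ex in examples:
        output = ex["output"].strip()
        if len(output) < min_output_chars:
            continue  # empty or near-empty response
        if output.lower() == ex["input"].strip().lower():
            continue  # teacher merely echoed the prompt
        pair = (ex["input"], output)
        if pair in seen_pairs:
            continue  # exact duplicate of an earlier example
        seen_pairs.add(pair)
        kept.append(ex)
    return kept


# Hypothetical examples covering each rejection rule
raw_examples = [
    {"input": "How do I reset my password?", "output": "Open Settings and choose Reset Password."},
    {"input": "How do I reset my password?", "output": "Open Settings and choose Reset Password."},
    {"input": "Why is my account locked?", "output": ""},
    {"input": "Can I change my email address?", "output": "Can I change my email address?"},
]
clean_examples = basic_quality_filter(raw_examples)
print(f"Kept {len(clean_examples)} of {len(raw_examples)} examples")
```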
Scale and Efficiency
For large-scale synthetic data generation, batch query the teacher:
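def generate_synthetic_data_batched(
    teacher,
    tokenizer,
    prompts: List[str],
    batch_size: int = 32,
    num_outputs_per_prompt: int = 3
) -> List[dict]:
    """Generate synthetic data efficiently in batches."""
    # Decoder-only models need a pad token and left padding for batched
    # generation, so that every prompt ends right where generation begins.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    synthetic_data = []
    teacher.eval()
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i+batch_size]
        # Tokenize batch
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(teacher.device)
        with torch.no_grad():
            # Sample num_outputs_per_prompt sequences per prompt. The outputs for
            # prompt k occupy rows k*n through k*n + n - 1 of output_ids.
            # max_new_tokens bounds the response only, since prompts may be
            # up to 512 tokens long.
            output_ids = teacher.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=200,
                num_return_sequences=num_outputs_per_prompt,
                do_sample=True,
                temperature=0.8,
                pad_token_id=tokenizer.pad_token_id
            )
        # With left padding, every response starts at the padded prompt length
        prompt_len = inputs.input_ids.shape[1]
        # Decode and store
        for prompt_idx, prompt in enumerate(batch_prompts):
            for seq_idx in range(num_outputs_per_prompt):
                output_idx = prompt_idx * num_outputs_per_prompt + seq_idx
                response = tokenizer.decode(
                    output_ids[output_idx][prompt_len:],
                    skip_special_tokens=True
                )
                synthetic_data.append({
                    "input": prompt,
                    "output": response.strip()
                })
    return synthetic_data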
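Batching amortizes per-call overhead and keeps the GPU busy. The speedup over serial generation is often in the 5-10x range, though it depends on model size, batch size, and sequence length. As a rough guide for a typical distillation project, generating 100K synthetic examples takes on the order of 2-5 GPU hours on a single A100.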
Key Takeaways
- Synthetic data generation scales student training beyond human annotation budgets and allows you to control training distribution.
- Define a representative data distribution (based on real-world usage) and ensure synthetic data matches it.
- Combine template-based and free-form generation: templates are fast and controlled; free-form adds natural variation.
- Validate synthetic data for coverage, coherence, uniqueness, and teacher errors before training.
- Batch generation requests for throughput, and use sampling strategies (temperature, top-p) to increase data diversity.
Frequently Asked Questions
How much synthetic data should I generate?
Start with 50K-100K examples. Monitor student convergence: if validation loss plateaus before the student has seen all the examples, you have enough. If the loss is still decreasing at the end of an epoch, generate more. Most distillation projects use 100K-500K synthetic examples.
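The plateau check described above can be made concrete with a small helper. This is a minimal sketch under stated assumptions: the name `has_plateaued`, the `patience` of 3 evaluations, and the `min_delta` of 0.01 are hypothetical and should be tuned to your loss scale and evaluation frequency.

```python
from typing import Sequence


def has_plateaued(
    val_losses: Sequence[float],
    patience: int = 3,
    min_delta: float = 0.01,
) -> bool:
    """Return True if the last `patience` evaluations improved the best loss by less than `min_delta`."""
    if len(val_losses) <= patience:
        return False  # not enough history to judge
    best_before = min(val_losses[:-patience])
    best_recent = min(val_losses[-patience:])
    return best_before - best_recent < min_delta


# Hypothetical validation-loss histories
flat_history = [2.0, 1.5, 1.2, 1.199, 1.198, 1.2]
improving_history = [2.0, 1.8, 1.6, 1.4, 1.2, 1.0]
print("flat history plateaued:", has_plateaued(flat_history))
print("improving history plateaued:", has_plateaued(improving_history))
```

If the check fires before the student has consumed the full dataset, the current data volume is probably sufficient; if it never fires, generating more examples is a reasonable next step.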
Should I mix synthetic and real labeled data during student training?
Yes, if you have real labels. A mix of 70% synthetic + 30% real often outperforms synthetic-only training because real data grounds the student in actual task distribution. If synthetic data is high-quality, 100% synthetic is acceptable.
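A fixed synthetic-to-real ratio can be assembled by using every real example and sampling just enough synthetic examples to reach the target. The sketch below assumes both datasets are plain Python lists; the helper `mix_datasets` and its parameters are hypothetical names chosen for illustration.

```python
import random
from typing import List, Optional, TypeVar

T = TypeVar("T")


def mix_datasets(
    synthetic: List[T],
    real: List[T],
    real_fraction: float = 0.3,
    seed: Optional[int] = 0,
) -> List[T]:
    """Build a shuffled training set in which about `real_fraction` of examples are real."""
    if not 0.0 < real_fraction < 1.0:
        raise ValueError("real_fraction must be strictly between 0 and 1")
    rng = random.Random(seed)
    # Use all real examples, then add enough synthetic ones to hit the ratio.
    # If too few synthetic examples exist, use all of them.
    n_synth = round(len(real) * (1 - real_fraction) / real_fraction)
    n_synth = min(n_synth, len(synthetic))
    mixed = list(real) + rng.sample(synthetic, n_synth)
    rng.shuffle(mixed)
    return mixed


# Hypothetical toy datasets
synthetic_pool = [f"synthetic-{i}" for i in range(1000)]
real_pool = [f"real-{i}" for i in range(30)]
train_set = mix_datasets(synthetic_pool, real_pool, real_fraction=0.3)
print(f"{len(train_set)} training examples")
```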
What if the teacher makes mistakes in synthetic data?
The student will learn those mistakes. Mitigate by: (1) using an ensemble of teachers (for example, averaging their logits when they share a vocabulary), (2) filtering synthetic data by confidence (discard low-confidence examples), or (3) using a noise-tolerant training objective (for example, label smoothing) so the student relies less on individual noisy examples.
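Confidence filtering can be sketched using the mean per-token log-probability the teacher assigned to its own output. The example assumes each record carries a `token_logprobs` list, which is a hypothetical field name: how you obtain those values depends on your generation stack. The threshold of -1.0 is likewise an illustrative assumption, not a recommended value.

```python
from typing import Any, Dict, List


def mean_logprob(token_logprobs: List[float]) -> float:
    """Average log-probability per generated token (closer to 0 means more confident)."""
    return sum(token_logprobs) / len(token_logprobs)


def filter_by_confidence(
    examples: List[Dict[str, Any]],
    min_mean_logprob: float = -1.0,
) -> List[Dict[str, Any]]:
    """Keep examples whose mean token log-probability meets the threshold."""
    kept = []
    for ex in examples:
        logprobs = ex["token_logprobs"]  # hypothetical field
        if logprobs and mean_logprob(logprobs) >= min_mean_logprob:
            kept.append(ex)
    return kept


# Hypothetical scored examples
scored_examples = [
    {"input": "q1", "output": "confident answer", "token_logprobs": [-0.1, -0.2, -0.3]},
    {"input": "q2", "output": "uncertain answer", "token_logprobs": [-2.5, -3.0, -1.5]},
    {"input": "q3", "output": "", "token_logprobs": []},
]
confident = filter_by_confidence(scored_examples, min_mean_logprob=-1.0)
print(f"Kept {len(confident)} of {len(scored_examples)} examples")
```

Mean log-probability is only a proxy: a teacher can be confidently wrong, so this filter complements manual review rather than replacing it.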
Does the temperature for synthetic generation affect the student?
Yes. Higher temperature (around T=1.0) creates diverse but lower-quality data; lower temperature (around T=0.5) creates high-quality but repetitive data. The optimal temperature is often 0.7-0.8. You can also use different temperatures for different subsets to create a curriculum (easy examples first, hard examples later).
Can I generate synthetic data from an API-based teacher (like GPT-4)?
Yes, but it can be costly. At a rate of roughly $0.03 per 1K tokens (the GPT-4 figure assumed here; check current pricing), 100K examples at 100 output tokens each come to 10M tokens, or about $300 for outputs alone. Prompt tokens push the total higher, so $300-500 is a reasonable estimate. Weigh this against the inference cost savings from distillation. For internal models or fine-tuned teachers, synthetic generation is nearly free (just GPU cost).
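The cost arithmetic is easy to parameterize so you can plug in your own volumes and rates. In this sketch the function name, the flat $0.03 per 1K tokens for both input and output, and the 50-token average prompt length are assumptions for illustration; real API pricing varies by model and usually differs between input and output tokens.

```python
def estimate_api_cost(
    num_examples: int,
    prompt_tokens: int,
    output_tokens: int,
    input_price_per_1k: float,
    output_price_per_1k: float,
) -> float:
    """Estimate total API spend in dollars for a synthetic-data run."""
    input_cost = num_examples * prompt_tokens / 1000 * input_price_per_1k
    output_cost = num_examples * output_tokens / 1000 * output_price_per_1k
    return input_cost + output_cost


# Outputs only: 100K examples at 100 output tokens each
outputs_only = estimate_api_cost(100_000, 0, 100, 0.03, 0.03)
# Adding a hypothetical 50-token prompt per example
with_prompts = estimate_api_cost(100_000, 50, 100, 0.03, 0.03)
print(f"Outputs only: ${outputs_only:,.0f}")
print(f"With prompts: ${with_prompts:,.0f}")
```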