Model Evaluation: Measure Distillation Quality Loss

Evaluating a distilled model is more nuanced than checking top-1 accuracy. While accuracy matters, you also need to measure knowledge retention (does the student capture the teacher's reasoning?), calibration (is confidence aligned with correctness?), and task-specific metrics. This article covers a comprehensive evaluation framework that helps you quantify distillation quality loss, identify where the student lags the teacher, and decide whether the student is production-ready.

Metrics Beyond Accuracy

For classification, accuracy is necessary but insufficient. A student with 95% accuracy can still distribute its confidence very differently from the teacher: high confidence on correct predictions is desirable, but equally high confidence on incorrect predictions signals poor calibration.

Accuracy and Task-Specific Metrics

Start with standard metrics appropriate to your task:

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def evaluate_student_accuracy(student, teacher, test_loader, device='cuda'):
    """
    Compare student and teacher accuracy on a test set.

    Assumes both models return an object with a `.logits` attribute
    and that test_loader yields (inputs, labels) batches.
    """
    student.to(device)
    teacher.to(device)
    student.eval()
    teacher.eval()

    student_preds = []
    teacher_preds = []
    true_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            student_logits = student(inputs).logits
            teacher_logits = teacher(inputs).logits

            student_preds.extend(torch.argmax(student_logits, dim=1).cpu().numpy())
            teacher_preds.extend(torch.argmax(teacher_logits, dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    student_acc = accuracy_score(true_labels, student_preds)
    teacher_acc = accuracy_score(true_labels, teacher_preds)
    # Guard against a teacher that gets nothing right on this test set
    accuracy_retention = student_acc / teacher_acc if teacher_acc > 0 else float('nan')

    # Support-weighted precision/recall/F1 (an aggregate, not a per-class breakdown)
    student_prec, student_rec, student_f1, _ = precision_recall_fscore_support(
        true_labels, student_preds, average='weighted'
    )
    teacher_prec, teacher_rec, teacher_f1, _ = precision_recall_fscore_support(
        true_labels, teacher_preds, average='weighted'
    )

    print(f"Student Accuracy: {student_acc:.4f}")
    print(f"Teacher Accuracy: {teacher_acc:.4f}")
    print(f"Accuracy Retention: {100*accuracy_retention:.1f}%")
    print(f"\nStudent Precision/Recall/F1: {student_prec:.4f}/{student_rec:.4f}/{student_f1:.4f}")
    print(f"Teacher Precision/Recall/F1: {teacher_prec:.4f}/{teacher_rec:.4f}/{teacher_f1:.4f}")

    return {
        'student_accuracy': student_acc,
        'teacher_accuracy': teacher_acc,
        'accuracy_retention': accuracy_retention,
        'student_f1': student_f1,
        'teacher_f1': teacher_f1
    }

Task-specific metrics include:

  • Classification: Accuracy, F1, precision/recall, confusion matrix.
  • Ranking: NDCG@k, MAP (mean average precision).
  • Generation (NLP): BLEU, ROUGE, BERTScore, human evaluation.
  • Regression: MAE, RMSE, R^2.
  • Detection: mAP, IoU.

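As a rule of thumb, a student that retains at least 95% of the teacher's score on these metrics is a reasonable candidate for production, provided the distribution, calibration, and error checks described in the following sections also hold up.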
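Aggregate scores can hide a regression that is concentrated in one class. The following is a minimal sketch of a per-class recall comparison, assuming you already have flat lists of true labels and predictions; the function names `per_class_recall` and `per_class_recall_gap` and the toy data are hypothetical, not part of the article's framework.

```python
from collections import Counter

def per_class_recall(true_labels, preds):
    """Return {class: recall} for every class that appears in true_labels."""
    support = Counter(true_labels)
    hits = Counter(t for t, p in zip(true_labels, preds) if t == p)
    return {c: hits[c] / n for c, n in support.items()}

def per_class_recall_gap(true_labels, student_preds, teacher_preds):
    """Teacher recall minus student recall per class (positive = student lags)."""
    student_recall = per_class_recall(true_labels, student_preds)
    teacher_recall = per_class_recall(true_labels, teacher_preds)
    return {c: teacher_recall[c] - student_recall[c] for c in sorted(student_recall)}

# Toy labels and predictions, made up for illustration
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
teacher_out = [0, 0, 0, 0, 1, 1, 1, 0]
student_out = [0, 0, 0, 0, 1, 1, 0, 0]

gaps = per_class_recall_gap(y_true, student_out, teacher_out)
print(gaps)  # class 0 shows no gap; class 1 shows the student trailing
```

In this toy case overall accuracy drops by only one example, yet the whole drop falls on class 1, which is the kind of pattern a weighted average would blur.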

KL Divergence and Prediction Distribution Similarity

Beyond accuracy, measure how similar the student's and teacher's probability distributions are:

import torch
import torch.nn.functional as F

def compute_kl_divergence(student, teacher, test_loader, device='cuda'):
    """
    Compute KL divergence between student and teacher predictions.
    Measures how similar their confidence distributions are.

    KL(teacher || student) = sum(teacher_prob * log(teacher_prob / student_prob))

    Lower KL = more similar distributions.
    """
    student.to(device)
    teacher.to(device)
    student.eval()
    teacher.eval()

    total_kl = 0.0
    num_examples = 0

    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)

            student_logits = student(inputs).logits
            teacher_logits = teacher(inputs).logits

            # F.kl_div takes log-probabilities as input and probabilities as
            # target, so this is KL(teacher || student) summed over the batch.
            kl_sum = F.kl_div(
                F.log_softmax(student_logits, dim=1),
                F.softmax(teacher_logits, dim=1),
                reduction='sum'
            )
            total_kl += kl_sum.item()
            num_examples += inputs.size(0)

    # Per-example average, so a smaller final batch is not over-weighted
    avg_kl = total_kl / num_examples
    print(f"Average KL Divergence: {avg_kl:.4f}")
    print("Interpretation: Lower is better. <0.1 is excellent; <0.5 is acceptable.")

    return avg_kl

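KL divergence measures how much information is lost when the student's distribution is used to approximate the teacher's. It is asymmetric, so it is not a true distance; the code above uses the teacher as the reference. As rough rules of thumb (the exact values depend on the task and the number of classes):

  • <0.05: Student predictions are nearly identical to teacher (excellent distillation).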
  • 0.05-0.1: Student closely matches teacher (very good).
  • 0.1-0.5: Student differs noticeably but is reasonable (acceptable).
  • >0.5: Large divergence (weak distillation; investigate hyperparameters).
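To make these ranges concrete, here is a small worked example in plain Python. It is only a sketch: the three-class distributions are invented for illustration, and `kl_divergence` is a hypothetical helper that mirrors the formula in the docstring above rather than the PyTorch call.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) in nats for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

# Made-up class probabilities for a single example
teacher = [0.7, 0.2, 0.1]
close_student = [0.6, 0.3, 0.1]  # same top class, slightly flatter
far_student = [0.1, 0.2, 0.7]    # confident in a different class

kl_close = kl_divergence(teacher, close_student)
kl_far = kl_divergence(teacher, far_student)
kl_reversed = kl_divergence(close_student, teacher)

print(f"close student: {kl_close:.3f}")      # about 0.027
print(f"far student:   {kl_far:.3f}")        # about 1.168
print(f"reversed args: {kl_reversed:.3f}")   # differs from kl_close
```

The reversed call gives a different number from `kl_close`, which is why it matters that the teacher is consistently used as the reference distribution.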

Ranking Correlation: Spearman and Kendall

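For ranking tasks, measure how well the student preserves the teacher's ordering of examples. The function below ranks examples by each model's confidence (maximum class probability) and correlates the two rankings: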

from scipy.stats import spearmanr, kendalltau
import numpy as np

def compute_ranking_correlation(student, teacher, test_loader, device='cuda'):
    """
    Measure ranking correlation between student and teacher.

    For each example, compute how confident each model is.
    Then compute correlation in these confidence rankings.
    """
    student.to(device)
    teacher.to(device)
    student.eval()
    teacher.eval()

    student_confs = []
    teacher_confs = []

    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)

            student_logits = student(inputs).logits
            teacher_logits = teacher(inputs).logits

            # Confidence = max probability
            student_conf = torch.max(F.softmax(student_logits, dim=1), dim=1)[0]
            teacher_conf = torch.max(F.softmax(teacher_logits, dim=1), dim=1)[0]

            student_confs.extend(student_conf.cpu().numpy())
            teacher_confs.extend(teacher_conf.cpu().numpy())

    student_confs = np.array(student_confs)
    teacher_confs = np.array(teacher_confs)

    # Spearman correlation
    spearman_corr, spearman_pval = spearmanr(student_confs, teacher_confs)

    # Kendall tau correlation
    kendall_corr, kendall_pval = kendalltau(student_confs, teacher_confs)

    print(f"Spearman Correlation: {spearman_corr:.4f} (p={spearman_pval:.2e})")
    print(f"Kendall Tau Correlation: {kendall_corr:.4f} (p={kendall_pval:.2e})")
    print("Interpretation: >0.9 is excellent agreement; <0.7 indicates divergence.")

    return {'spearman': spearman_corr, 'kendall': kendall_corr}

Ranking correlation is particularly useful for information retrieval or ranking tasks: even if the student misclassifies some examples, as long as it ranks them similarly to the teacher, it is useful.
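One property worth keeping in mind is that rank correlation only looks at ordering, not at the confidence values themselves. The sketch below reimplements Spearman correlation in plain Python (as the Pearson correlation of average ranks) purely to illustrate this; the helper names and the confidence values are made up, and in practice `scipy.stats.spearmanr` would be the tool to use.

```python
def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def pearson(x, y):
    """Pearson correlation of two equal-length, non-constant lists."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = sum((a - mean_x) ** 2 for a in x) ** 0.5
    std_y = sum((b - mean_y) ** 2 for b in y) ** 0.5
    return cov / (std_x * std_y)

def spearman(x, y):
    """Spearman correlation = Pearson correlation of the rank vectors."""
    return pearson(average_ranks(x), average_ranks(y))

# Made-up per-example confidences
teacher_conf = [0.99, 0.80, 0.65, 0.55, 0.90]
underconfident = [0.90, 0.60, 0.50, 0.40, 0.70]  # lower values, same ordering
reordered = [0.40, 0.90, 0.70, 0.50, 0.60]       # different ordering

same_order = spearman(teacher_conf, underconfident)
different_order = spearman(teacher_conf, reordered)
print(f"{same_order:.2f} {different_order:.2f}")
```

The systematically underconfident student still scores a correlation of essentially 1.0, so a high ranking correlation says nothing about calibration; that has to be checked separately.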

Calibration Error

Calibration measures how well confidence aligns with accuracy (covered briefly in Article 3):

def expected_calibration_error(student, teacher, test_loader,
                               num_bins=10, device='cuda'):
    """
    Compute Expected Calibration Error (ECE) for student and teacher.

    Divides predictions into confidence bins and measures:
    |accuracy_in_bin - average_confidence_in_bin|
    """
    student.to(device)
    teacher.to(device)
    student.eval()
    teacher.eval()

    student_confs = []
    teacher_confs = []
    correct_student = []
    correct_teacher = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            student_logits = student(inputs).logits
            teacher_logits = teacher(inputs).logits

            student_probs = F.softmax(student_logits, dim=1)
            teacher_probs = F.softmax(teacher_logits, dim=1)

            student_conf, student_pred = torch.max(student_probs, dim=1)
            teacher_conf, teacher_pred = torch.max(teacher_probs, dim=1)

            student_confs.extend(student_conf.cpu().numpy())
            teacher_confs.extend(teacher_conf.cpu().numpy())
            correct_student.extend((student_pred == labels).cpu().numpy())
            correct_teacher.extend((teacher_pred == labels).cpu().numpy())

    student_confs = np.array(student_confs)
    teacher_confs = np.array(teacher_confs)
    correct_student = np.array(correct_student)
    correct_teacher = np.array(correct_teacher)

    # Compute ECE for each model
    def compute_ece(confs, correct):
        bins = np.linspace(0, 1, num_bins + 1)
        ece = 0.0
        for i in range(num_bins):
            # Half-open bins (lo, hi] so a confidence of exactly 1.0 is counted
            mask = (confs > bins[i]) & (confs <= bins[i+1])
            if mask.sum() > 0:
                bin_acc = correct[mask].mean()
                bin_conf = confs[mask].mean()
                # Weight each bin's gap by the fraction of examples it holds
                ece += np.abs(bin_acc - bin_conf) * mask.sum() / len(correct)
        return ece

    student_ece = compute_ece(student_confs, correct_student)
    teacher_ece = compute_ece(teacher_confs, correct_teacher)

    print(f"Student ECE: {student_ece:.4f}")
    print(f"Teacher ECE: {teacher_ece:.4f}")
    print("Interpretation: <0.05 is well-calibrated; >0.1 is poorly calibrated.")

    return {'student_ece': student_ece, 'teacher_ece': teacher_ece}

A well-calibrated student has ECE close to the teacher's. If the student is less calibrated (higher ECE), you can improve it via temperature scaling (mentioned in Article 3) or focal loss during training.
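Temperature scaling fits a single scalar T on held-out data and divides the logits by it before the softmax. The sketch below shows the idea in plain Python, under a few assumptions: the logits and labels are invented, the function names are hypothetical, and a simple grid search stands in for the gradient-based optimizer that is more commonly used to minimize the negative log-likelihood.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def mean_nll(logit_rows, labels, temperature):
    """Average negative log-likelihood of the true labels at a given temperature."""
    losses = [-math.log(softmax(row, temperature)[y])
              for row, y in zip(logit_rows, labels)]
    return sum(losses) / len(losses)

def fit_temperature(logit_rows, labels, candidates=None):
    """Pick the candidate temperature with the lowest held-out NLL (grid search)."""
    if candidates is None:
        candidates = [0.5 + 0.1 * i for i in range(46)]  # 0.5 up to 5.0
    return min(candidates, key=lambda t: mean_nll(logit_rows, labels, t))

# Made-up validation logits: three confident hits and one confident miss
val_logits = [[6.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 7.0], [8.0, 0.0, 0.0]]
val_labels = [0, 1, 2, 1]

best_t = fit_temperature(val_logits, val_labels)
print(f"fitted T = {best_t:.1f}")
print(f"NLL at T=1: {mean_nll(val_logits, val_labels, 1.0):.3f}")
print(f"NLL at fitted T: {mean_nll(val_logits, val_labels, best_t):.3f}")
```

Because dividing by a positive T does not change the argmax, accuracy is unaffected; only the confidence values move, which is what makes this a cheap post-hoc fix for an overconfident student.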

Error Analysis: Where Does the Student Diverge?

Understanding failure modes helps you tell whether the distillation process is falling short or the student simply lacks the capacity for the task:

def error_analysis(student, teacher, test_loader, device='cuda'):
    """
    Categorize errors: (1) both wrong, (2) only student wrong,
    (3) only teacher wrong, (4) both correct.
    """
    student.to(device)
    teacher.to(device)
    student.eval()
    teacher.eval()

    both_correct = 0
    both_wrong = 0
    only_student_wrong = 0
    only_teacher_wrong = 0

    student_errors = []
    teacher_errors = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            student_pred = torch.argmax(student(inputs).logits, dim=1)
            teacher_pred = torch.argmax(teacher(inputs).logits, dim=1)

            student_correct = (student_pred == labels)
            teacher_correct = (teacher_pred == labels)

            both_correct += (student_correct & teacher_correct).sum().item()
            both_wrong += (~student_correct & ~teacher_correct).sum().item()
            only_student_wrong += (~student_correct & teacher_correct).sum().item()
            only_teacher_wrong += (student_correct & ~teacher_correct).sum().item()

            # Collect the inputs each model gets wrong on its own, for inspection
            student_errors.extend(inputs[~student_correct & teacher_correct].cpu())
            teacher_errors.extend(inputs[student_correct & ~teacher_correct].cpu())

    total = sum([both_correct, both_wrong, only_student_wrong, only_teacher_wrong])
    if total == 0:
        raise ValueError("test_loader yielded no examples")

    print(f"Both correct: {100*both_correct/total:.1f}%")
    print(f"Both wrong: {100*both_wrong/total:.1f}%")
    print(f"Only student wrong: {100*only_student_wrong/total:.1f}%")
    print(f"Only teacher wrong: {100*only_teacher_wrong/total:.1f}%")

    if only_student_wrong / total > 0.05:
        print(f"\nWARNING: {100*only_student_wrong/total:.1f}% of test examples are student-only errors.")
        print("This suggests weak distillation; increase training data or temperature.")

    return {
        'both_correct': both_correct,
        'both_wrong': both_wrong,
        'only_student_wrong': only_student_wrong,
        'only_teacher_wrong': only_teacher_wrong,
        # Raw inputs behind the two single-model error categories
        'student_only_error_inputs': student_errors,
        'teacher_only_error_inputs': teacher_errors
    }

If "only student wrong" is high (>5%), the student is struggling to learn from the teacher. Consider:

  • Increasing synthetic data volume.
  • Raising temperature (softer targets).
  • Using a larger student.
  • Longer training (more epochs).

Comprehensive Evaluation Report

Here is a template for a complete evaluation report:

def full_evaluation_report(student, teacher, test_loader, device='cuda'):
    """
    Generate comprehensive evaluation comparing student to teacher.
    """
    print("=" * 60)
    print("STUDENT vs. TEACHER EVALUATION REPORT")
    print("=" * 60)

    # 1. Accuracy
    print("\n1. ACCURACY")
    acc_metrics = evaluate_student_accuracy(student, teacher, test_loader, device=device)
    print(f" Retention: {100*acc_metrics['accuracy_retention']:.1f}%")

    # 2. Prediction Distribution Similarity
    print("\n2. KL DIVERGENCE (Distribution Similarity)")
    kl = compute_kl_divergence(student, teacher, test_loader, device=device)
    if kl < 0.1:
        print(" Status: EXCELLENT (distributions nearly identical)")
    elif kl < 0.5:
        print(" Status: GOOD (acceptable divergence)")
    else:
        print(" Status: WEAK (investigate hyperparameters)")

    # 3. Ranking Correlation
    print("\n3. RANKING CORRELATION")
    ranking = compute_ranking_correlation(student, teacher, test_loader, device=device)
    print(f" Status: {'EXCELLENT' if ranking['spearman'] > 0.9 else 'GOOD' if ranking['spearman'] > 0.8 else 'ACCEPTABLE'}")

    # 4. Calibration
    # device must be passed by keyword: the fourth positional parameter of
    # expected_calibration_error is num_bins, not device.
    print("\n4. CALIBRATION")
    calibration = expected_calibration_error(student, teacher, test_loader, device=device)

    # 5. Error Analysis
    print("\n5. ERROR ANALYSIS")
    errors = error_analysis(student, teacher, test_loader, device=device)

    # 6. Summary
    print("\n" + "=" * 60)
    print("OVERALL ASSESSMENT")
    print("=" * 60)
    if acc_metrics['accuracy_retention'] > 0.95 and kl < 0.1:
        print("STATUS: EXCELLENT - Student is a high-fidelity compression.")
        print("RECOMMENDATION: Ready for production deployment.")
    elif acc_metrics['accuracy_retention'] > 0.92 and kl < 0.3:
        print("STATUS: GOOD - Student captures most of teacher's knowledge.")
        print("RECOMMENDATION: Production-ready with monitoring.")
    elif acc_metrics['accuracy_retention'] > 0.90:
        print("STATUS: ACCEPTABLE - Some knowledge loss, but usable.")
        print("RECOMMENDATION: Acceptable for cost-sensitive deployments.")
    else:
        print("STATUS: WEAK - Significant knowledge loss detected.")
        print("RECOMMENDATION: Improve distillation (more data, tuning, larger student).")

    return {
        'accuracy': acc_metrics,
        'kl_divergence': kl,
        'ranking': ranking,
        'calibration': calibration,
        'errors': errors
    }

Key Takeaways

  • Accuracy is necessary but not sufficient; measure KL divergence, ranking correlation, and calibration.
  • KL divergence <0.1 indicates excellent knowledge transfer; <0.5 is acceptable.
  • Error analysis reveals if the student is structurally weak or just needs more training.
  • Calibration (ECE) should be <0.05 for production; temperature scaling can improve it.
  • A comprehensive evaluation report guides whether the student is ready for deployment.

Frequently Asked Questions

What accuracy retention should I target?

For most tasks, 95%+ is excellent; 92-95% is acceptable; below 90% requires investigation. The acceptable range depends on use case: safety-critical applications (medical, finance) should target 98%+; cost-sensitive applications (low-latency serving) can accept 90-92%.

Is KL divergence the best similarity metric?

KL divergence is good for measuring prediction distribution similarity, but it is not the only metric. Cosine distance between logits, Wasserstein distance, or ranking correlation can also be useful depending on your task. Use multiple metrics for a complete picture.
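A quick illustration of why a single similarity metric can mislead: in the sketch below, the student's logits are exactly half the teacher's, so cosine similarity between the logit vectors is perfect even though the resulting probability distributions differ. The logits and helper names are made up for this example.

```python
import math

def softmax(logits):
    """Numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) in nats for strictly positive discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Made-up logits: the student is the teacher scaled by 0.5
teacher_logits = [4.0, 1.0, -2.0]
student_logits = [2.0, 0.5, -1.0]

cos = cosine_similarity(teacher_logits, student_logits)
kl = kl_divergence(softmax(teacher_logits), softmax(student_logits))
print(f"cosine similarity: {cos:.3f}")  # scale-invariant, so effectively 1.0
print(f"KL(teacher || student): {kl:.3f}")  # about 0.11, the student is flatter
```

Cosine similarity is blind to the scale of the logits, while KL divergence picks up the flatter student distribution, so the two metrics answer different questions.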

How do I compare students trained with different temperatures?

Plot accuracy vs. KL divergence for each temperature. The optimal temperature usually yields the best balance: high accuracy and low KL divergence. A Pareto frontier of (accuracy, KL) can help you choose the best trade-off.
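Selecting the Pareto frontier can be sketched in a few lines. The example below assumes each run is summarized by its accuracy (higher is better) and its KL divergence (lower is better), measured the same way for every run; the `pareto_frontier` helper and the run results are invented for illustration.

```python
def pareto_frontier(runs):
    """Keep runs that no other run beats on both accuracy and KL.

    A run is dominated if some other run is at least as good on both
    metrics and strictly better on at least one.
    """
    frontier = []
    for run in runs:
        dominated = any(
            other['accuracy'] >= run['accuracy']
            and other['kl'] <= run['kl']
            and (other['accuracy'] > run['accuracy'] or other['kl'] < run['kl'])
            for other in runs
        )
        if not dominated:
            frontier.append(run)
    return frontier

# Made-up results for students trained at different temperatures
runs = [
    {'temperature': 1.0, 'accuracy': 0.90, 'kl': 0.40},
    {'temperature': 2.0, 'accuracy': 0.93, 'kl': 0.20},
    {'temperature': 4.0, 'accuracy': 0.94, 'kl': 0.12},
    {'temperature': 8.0, 'accuracy': 0.92, 'kl': 0.08},
    {'temperature': 16.0, 'accuracy': 0.88, 'kl': 0.15},
]

frontier = pareto_frontier(runs)
for run in frontier:
    print(run)
```

With these numbers only the T=4 and T=8 runs survive: one has the best accuracy, the other the lowest KL, and choosing between them is the trade-off the frontier exposes.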

Should I prioritize accuracy or KL divergence?

Prioritize accuracy for your specific downstream task. KL divergence is a proxy for knowledge transfer; it helps diagnose why accuracy might be low. If accuracy is low while KL is also low, the student is reproducing the teacher faithfully on the data where KL was measured, so either the teacher itself is weak on the test data or the synthetic data does not carry enough signal about the test distribution (increase volume or improve distribution). If KL is high, the student is diverging from the teacher (adjust hyperparameters).

How often should I run evaluation during training?

Run full evaluation every 2-3 epochs, or after every 10% of training. Early in training, focus on validation accuracy and KL divergence to catch issues (divergence, overfitting). At the end, run the comprehensive report to assess production readiness.

Further Reading