Knowledge distillation II: the KD loss
After this lesson you can read and reason about the KD loss alpha·CE + (1−alpha)·T²·KL — what temperature does, why the T² factor is there, and how alpha trades off the gold label against the teacher.
You have the teacher's soft targets saved. The student learns from them through a loss with two parts: the ordinary supervised loss against the gold label, and a divergence from the teacher's distribution. BrewSLM's loss is exactly:
loss = alpha * CE(student, gold_label) + (1 - alpha) * T^2 * KL(teacher_T || student_T)
Term one: cross-entropy on the hard label
This is the SFT loss you already know (Track 1's cross-entropy lesson): the student is penalized for not putting probability on the correct gold token. It keeps the student anchored to the truth, even if the teacher is occasionally wrong.
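As a quick refresher, the sketch below uses made-up toy tensors to check that PyTorch's `F.cross_entropy` with `ignore_index=-100` is the mean of −log p(gold) over the unmasked positions only. This is the "penalized for not putting probability on the gold token" behaviour described above, with prompt tokens excluded.

```python
import torch
import torch.nn.functional as F

# Toy logits: 1 sequence, 4 positions, 5-token vocabulary (made-up values).
torch.manual_seed(0)
logits = torch.randn(1, 4, 5)
# Gold token ids; the first two positions are prompt tokens, masked with -100.
labels = torch.tensor([[-100, -100, 3, 1]])

flat_logits = logits.reshape(-1, 5)
flat_labels = labels.reshape(-1)
ce = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

# Same value by hand: mean of -log p(gold) over the unmasked positions only.
keep = flat_labels != -100
log_p = F.log_softmax(flat_logits[keep], dim=-1)
manual = -log_p.gather(1, flat_labels[keep].unsqueeze(1)).mean()

print(ce.item(), manual.item())
```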
Term two: KL divergence to the teacher
KL divergence measures how far the student's distribution is from the teacher's. Minimizing it pulls the student's whole output distribution toward the teacher's — transferring the dark knowledge. We compute it on temperature-softened distributions.
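KL divergence is asymmetric, so it matters which distribution goes where. A minimal check with made-up four-token distributions: PyTorch's `F.kl_div(input, target)` takes the student's log-probabilities as `input` and the teacher's probabilities as `target`, and returns Σ p·(log p − log q) with p the teacher and q the student, i.e. KL(teacher ‖ student).

```python
import torch
import torch.nn.functional as F

# Toy next-token distributions over a 4-token vocabulary (made-up values).
teacher_p = torch.tensor([[0.70, 0.20, 0.05, 0.05]])
student_q = torch.tensor([[0.40, 0.40, 0.10, 0.10]])

# input = student log-probs, target = teacher probs -> KL(teacher || student).
kl = F.kl_div(student_q.log(), teacher_p, reduction="batchmean")

# The same quantity written out: sum over tokens of p * (log p - log q).
manual = (teacher_p * (teacher_p.log() - student_q.log())).sum()

print(kl.item(), manual.item())
```

Because the teacher's entropy does not depend on the student, minimizing this KL with respect to the student is equivalent to minimizing the cross-entropy between the teacher's soft targets and the student's predictions.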
Temperature: softening the distribution
Dividing logits by a temperature T before softmax flattens the distribution. At T=1 you get the normal probabilities; at T=2–4 the differences shrink, so small probabilities on the runner-up tokens become visible — and those secondary probabilities are precisely the relational signal you want to transfer. A peaky 99/1 distribution teaches almost nothing about the runner-up; softened to 80/20 it does.
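To see the softening numerically, here is a small sketch with a hypothetical helper `softmax_with_temperature`, applied to two logits chosen so that T=1 gives exactly 99/1 (their gap is ln 99).

```python
import math
import torch

def softmax_with_temperature(logits, T):
    # Divide logits by T before softmax: T > 1 flattens, T < 1 sharpens.
    return torch.softmax(logits / T, dim=-1)

# Two logits whose gap is ln(99), so the T=1 softmax is exactly 99/1.
logits = torch.tensor([math.log(99.0), 0.0])

for T in (1.0, 2.0, 3.0, 4.0):
    p = softmax_with_temperature(logits, T)
    print(f"T={T}: top={p[0].item():.3f}, runner-up={p[1].item():.3f}")
```

By hand: dividing by T shrinks the gap to ln 99 / T, so the top probability is 99^(1/T) / (99^(1/T) + 1). That gives about 0.91 at T=2, 0.82 at T=3 and 0.76 at T=4, so the 80/20 split in the text sits between T=3 and T=4.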
Why the T² factor?
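Softening by T shrinks the gradients of the KL term by roughly 1/T². The gradient with respect to each student logit is (q_i − p_i)/T, where q and p are the softened student and teacher probabilities, and when T is large relative to the logits the gap q_i − p_i itself shrinks roughly in proportion to 1/T. Multiplying the KL term by T² cancels this shrink, so the soft-target gradients keep a roughly constant magnitude as you change temperature and stay balanced against the CE term, which does not depend on T.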
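The 1/T² claim can be checked numerically. This sketch uses random toy logits and a hypothetical helper `kl_grad_norm` to measure the norm of the KL gradient with respect to the student logits at several temperatures, with and without the T² factor.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 8
teacher_logits = torch.randn(1, vocab)   # toy teacher logits
student_init = torch.randn(1, vocab)     # toy student logits

def kl_grad_norm(T, scale_by_T2):
    """Norm of d KL(teacher_T || student_T) / d student_logits at temperature T."""
    z = student_init.clone().requires_grad_(True)  # fresh leaf tensor per call
    s_log = F.log_softmax(z / T, dim=-1)
    t_prob = F.softmax(teacher_logits / T, dim=-1)
    loss = F.kl_div(s_log, t_prob, reduction="batchmean")
    if scale_by_T2:
        loss = loss * (T * T)
    loss.backward()
    return z.grad.norm().item()

for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    raw = kl_grad_norm(T, scale_by_T2=False)
    scaled = kl_grad_norm(T, scale_by_T2=True)
    print(f"T={T:>4}: raw grad norm={raw:.5f}   T^2-scaled={scaled:.5f}")
```

Once T is large relative to the spread of the logits, expect the raw gradient norm to fall by roughly 4× each time T doubles while the T²-scaled norm levels off. At small T (here T=1) the approximation is loose, which is why the text says "roughly".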
Alpha: how much teacher vs how much truth
alpha in [0,1] balances the two terms. At alpha = 1 the KL term drops out and the loss is exactly plain SFT; at alpha = 0 the student trains on the teacher's soft targets alone. A middle value such as 0.5 is a common starting point: the student learns from the teacher's nuance while staying pinned to the ground truth.
In code
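```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
    # Flatten (batch, seq, vocab) -> (N, vocab); labels are assumed already
    # aligned with the logits, with prompt/padding positions set to -100.
    vocab = student_logits.size(-1)
    s = student_logits.reshape(-1, vocab)
    t = teacher_logits.reshape(-1, vocab)
    y = labels.reshape(-1)
    # term 1: hard-label cross-entropy (the SFT loss, prompt masked with -100)
    ce = F.cross_entropy(s, y, ignore_index=-100)
    # term 2: KL(teacher_T || student_T) on temperature-softened distributions,
    # averaged over the same unmasked positions as the CE term
    keep = y != -100
    s_log = F.log_softmax(s[keep] / T, dim=-1)
    t_prob = F.softmax(t[keep] / T, dim=-1)
    kl = F.kl_div(s_log, t_prob, reduction="batchmean") * (T * T)
    return alpha * ce + (1.0 - alpha) * kl
```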
In BrewSLM this lives in a pure kd_loss function used by the offline kd_trainer, which you select by setting training_mode="distillation" and supplying the recipe.kd.* settings. The trainer uses teacher_capture.jsonl (the top-k logprobs captured in lesson 4.2) as the source of the teacher distribution in the KL term. Everything else is the same training loop you already know; only the loss changed. The last distillation lesson asks the real question: how much of the teacher's quality did the student actually keep?
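The kd_loss code above takes the teacher's full logits, but teacher_capture.jsonl stores only the top-k logprobs per position. This lesson does not show how BrewSLM bridges that gap. One possible approach, sketched here under stated assumptions, is to renormalize the tempered teacher logprobs over the k captured tokens and compare them with the student's tempered log-probabilities at the same token ids. `topk_kd_term` is a hypothetical name, and the sketch assumes the capture has already been parsed into `(N, k)` tensors of token ids and logprobs.

```python
import torch
import torch.nn.functional as F

def topk_kd_term(student_logits, topk_ids, topk_logprobs, T=2.0):
    """Hypothetical KL-style term against a top-k teacher capture (not BrewSLM's code).

    student_logits: (N, V) student logits at N supervised positions
    topk_ids:       (N, k) teacher's top-k token ids per position (int64)
    topk_logprobs:  (N, k) teacher's log-probabilities for those ids
    """
    # Teacher: temper the captured logprobs, then renormalize over the k kept tokens.
    t_log = F.log_softmax(topk_logprobs / T, dim=-1)
    t_prob = t_log.exp()
    # Student: full-vocabulary tempered log-probs, read off at the teacher's k ids.
    s_log = F.log_softmax(student_logits / T, dim=-1).gather(-1, topk_ids)
    # sum_k p_k * (log p_k - log q_k), averaged over positions, rescaled by T^2.
    kl = (t_prob * (t_log - s_log)).sum(dim=-1).mean()
    return kl * (T * T)

# Example: 3 positions, vocabulary of 6, teacher kept its top-2 tokens (toy data).
torch.manual_seed(0)
student_logits = torch.randn(3, 6)
teacher_logprobs_full = F.log_softmax(torch.randn(3, 6), dim=-1)
topk_logprobs, topk_ids = torch.topk(teacher_logprobs_full, k=2, dim=-1)
print(topk_kd_term(student_logits, topk_ids, topk_logprobs, T=2.0))
```

Because the student's probabilities at the k ids sum to at most 1, this term is never negative, and with k equal to the full vocabulary it reduces to the dense T²·KL term.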
Key idea
The KD loss is alpha·CE + (1−alpha)·T²·KL: cross-entropy anchors the student to the gold label; temperature-softened KL transfers the teacher's distribution; T² keeps the soft term's magnitude stable; alpha dials teacher-vs-truth. Same loop, richer loss.
Key terms
- KD loss
- alpha·CE + (1−alpha)·T²·KL — the blend of hard-label cross-entropy and teacher KL.
- KL divergence
- A measure of how far the student's distribution is from the teacher's; minimized to transfer knowledge.
- temperature (T)
- A divisor on logits before softmax; higher T softens the distribution to expose secondary probabilities.
- T² factor
- Rescales the KL term to offset the 1/T² gradient shrink from softening, keeping the term balanced.
- alpha
- The [0,1] weight trading the gold-label CE term against the teacher KL term.
- kd_trainer
- BrewSLM's offline distillation trainer that applies kd_loss against the captured teacher logprobs.