Track 4 · Advanced · Lesson 3

Knowledge distillation II: the KD loss

After this lesson you can read and reason about the KD loss alpha·CE + (1−alpha)·T²·KL — what temperature does, why the T² factor is there, and how alpha trades off the gold label against the teacher.

Level: advanced · Read time: ~11 min · Prerequisites: Knowledge distillation I: the teacher and capturing its logits

You have the teacher's soft targets saved. The student learns from them through a loss with two parts: the ordinary supervised loss against the gold label, and a divergence from the teacher's distribution. BrewSLM's loss is exactly:

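loss = alpha * CE(student, gold_label)  +  (1 - alpha) * T^2 * KL(teacher_T || student_T)

Here teacher_T and student_T are the two models' output distributions softened by temperature T. The KL is taken from the teacher to the student, which is the direction F.kl_div computes in the code later in this lesson.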

Term one: cross-entropy on the hard label

This is the SFT loss you already know (Track 1's cross-entropy lesson): the student is penalized for not putting probability on the correct gold token. It keeps the student anchored to the truth, even if the teacher is occasionally wrong.
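To make the prompt masking concrete, here is a small sketch with made-up logits: F.cross_entropy skips every position whose label is -100 and averages -log p(gold) over the remaining positions only. The shapes and values are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical: 4 positions over a 5-token vocab; the first two are prompt tokens.
logits = torch.randn(4, 5, generator=torch.Generator().manual_seed(0))
labels = torch.tensor([-100, -100, 3, 1])  # -100 marks positions to ignore

ce = F.cross_entropy(logits, labels, ignore_index=-100)

# Same quantity by hand: mean of -log p(gold) over the two unmasked positions.
log_probs = F.log_softmax(logits, dim=-1)
manual = -(log_probs[2, 3] + log_probs[3, 1]) / 2

print(ce.item(), manual.item())  # the two values agree
```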

Term two: KL divergence to the teacher

KL divergence measures how far the student's distribution is from the teacher's. Minimizing it pulls the student's whole output distribution toward the teacher's — transferring the dark knowledge. We compute it on temperature-softened distributions.
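In PyTorch, F.kl_div(input, target) takes the student's log-probabilities as input and the teacher's probabilities as target, and computes the sum of p_teacher * (log p_teacher - log p_student). The sketch below checks that identity on one position with hypothetical logits, and shows that the divergence is zero when the two distributions match.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a single position over a 4-token vocabulary.
student_logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
teacher_logits = torch.tensor([2.5, 0.5, 0.3, -2.0])

s_log = F.log_softmax(student_logits, dim=-1)  # log q (student)
t_prob = F.softmax(teacher_logits, dim=-1)     # p (teacher)

# kl_div(input=log q, target=p) = sum p * (log p - log q)
kl_torch = F.kl_div(s_log, t_prob, reduction="sum")
kl_manual = (t_prob * (t_prob.log() - s_log)).sum()

# Identical distributions give (numerically) zero divergence.
kl_same = F.kl_div(F.log_softmax(teacher_logits, dim=-1), t_prob, reduction="sum")

print(kl_torch.item(), kl_manual.item(), kl_same.item())
```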

Temperature: softening the distribution

Dividing logits by a temperature T before softmax flattens the distribution. At T=1 you get the normal probabilities; at T=2–4 the differences shrink, so small probabilities on the runner-up tokens become visible — and those secondary probabilities are precisely the relational signal you want to transfer. A peaky 99/1 distribution teaches almost nothing about the runner-up; softened to 80/20 it does.
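The 99/1 example can be reproduced with a two-token toy: pick logits whose T=1 softmax is exactly 99/1, then divide by increasing T. The helper name soften is ours, not part of BrewSLM.

```python
import math
import torch

def soften(logits, T):
    """Softmax of logits / T: a higher T gives a flatter distribution."""
    return torch.softmax(logits / T, dim=-1)

# Two-token toy example: logits chosen so that T=1 yields a 99/1 split.
logits = torch.tensor([math.log(99.0), 0.0])

for T in (1.0, 2.0, 3.3, 4.0):
    p = soften(logits, T)
    print(f"T={T}: top={p[0].item():.3f}  runner-up={p[1].item():.3f}")
```

At T=1 this gives 0.990 / 0.010; around T≈3.3 the split is roughly 80/20, and T=4 flattens it a little further.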

Why the T² factor?

Softening by T shrinks the gradients of the KL term by roughly 1/T² (the approximation is good once T is large). Multiplying the KL term by T² rescales them back, so the soft-target term keeps a consistent magnitude as you change temperature and stays balanced against the CE term.
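A quick numerical check of the 1/T² claim on one position with hypothetical logits. The helper kl_grad_norm is our own name; it measures the norm of the KL term's gradient with respect to the student logits, with and without the T² factor. Analytically that gradient is (softmax(student/T) - softmax(teacher/T)) / T, and for large T the difference of softmaxes itself shrinks like 1/T.

```python
import torch
import torch.nn.functional as F

def kl_grad_norm(student, teacher, T, scale_by_t2):
    """Norm of d KL(teacher_T || student_T) / d student_logits."""
    z = student.clone().requires_grad_(True)  # leaf tensor so .grad is populated
    kl = F.kl_div(F.log_softmax(z / T, dim=-1),
                  F.softmax(teacher / T, dim=-1),
                  reduction="sum")
    if scale_by_t2:
        kl = kl * T * T
    kl.backward()
    return z.grad.norm().item()

# Hypothetical logits for a 3-token vocabulary.
student = torch.tensor([1.0, 0.0, -1.0])
teacher = torch.tensor([2.0, 0.5, -1.0])

for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    raw = kl_grad_norm(student, teacher, T, scale_by_t2=False)
    scaled = kl_grad_norm(student, teacher, T, scale_by_t2=True)
    print(f"T={T:>4}: raw grad {raw:.5f}   T^2-scaled grad {scaled:.4f}")
```

On these logits the raw gradient drops by close to a factor of four each time T doubles from 4 upward, while the T²-scaled norm stays near 0.23 from T=4 on. At T=1 and T=2 the 1/T² rule is only approximate.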

Alpha: how much teacher vs how much truth

alpha in [0,1] balances the two terms. alpha → 1 leans on the gold labels (closer to plain SFT); alpha → 0 leans on the teacher's soft targets. A middle value (e.g. 0.5) is a common start: learn from the teacher's nuance while staying pinned to the ground truth.
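A toy illustration of the blend with made-up term values (ce=2.0 and kl=0.8 are hypothetical, and kl is taken as already T²-scaled). At alpha=1 only the CE term survives, which is plain SFT; at alpha=0 only the teacher term survives.

```python
def blend(ce, kl, alpha):
    """alpha * CE + (1 - alpha) * KL, where kl is already T^2-scaled."""
    return alpha * ce + (1.0 - alpha) * kl

ce, kl = 2.0, 0.8  # hypothetical per-token loss values
for alpha in (1.0, 0.5, 0.0):
    print(f"alpha={alpha}: loss={blend(ce, kl, alpha):.2f}")
```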

In code

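import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
    # Assumes logits are (batch, seq, vocab), labels are (batch, seq), and the
    # two are already aligned (shifted) so position i is scored against labels[i].
    vocab = student_logits.size(-1)
    s = student_logits.reshape(-1, vocab)   # reshape also handles non-contiguous tensors
    t = teacher_logits.reshape(-1, vocab)
    y = labels.reshape(-1)

    # term 1: hard-label cross-entropy (the SFT loss, prompt masked with -100)
    ce = F.cross_entropy(s, y, ignore_index=-100)

    # term 2: KL(teacher_T || student_T) on the same unmasked positions.
    # On a (num_tokens, vocab) input, "batchmean" gives the mean per-token KL,
    # the same scale as the per-token CE above.
    keep = y != -100
    s_log = F.log_softmax(s[keep] / T, dim=-1)
    t_prob = F.softmax(t[keep] / T, dim=-1)
    kl = F.kl_div(s_log, t_prob, reduction="batchmean") * (T * T)

    return alpha * ce + (1.0 - alpha) * kl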

In BrewSLM this lives in a pure kd_loss function used by the offline kd_trainer, selected by setting training_mode="distillation" and a recipe.kd.* recipe. The trainer reads teacher_capture.jsonl (the top-k logprobs from lesson 4.2) as the teacher term. Otherwise it's the same training loop you already know — only the loss changed. The last distillation lesson asks the real question: how much of the teacher's quality did the student actually keep?
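kd_loss takes full teacher logits, but teacher_capture.jsonl holds only top-k logprobs. How BrewSLM bridges the two is not shown in this lesson; the sketch below is one common approximation under our own assumptions. The function topk_kd_term and its tensor layout are hypothetical, and it assumes the JSONL records have already been parsed into per-position token-id and logprob tensors. It renormalizes the teacher over its k captured tokens and scores the student's full-vocabulary log-probabilities at those same ids.

```python
import torch
import torch.nn.functional as F

def topk_kd_term(student_logits, teacher_topk_ids, teacher_topk_logprobs, T=2.0):
    """Hypothetical T^2-scaled KL term when only the teacher's top-k logprobs exist.

    student_logits:        (N, vocab) student logits at N response positions
    teacher_topk_ids:      (N, k) int64 token ids of the teacher's top-k tokens
    teacher_topk_logprobs: (N, k) teacher log-probabilities for those ids
    """
    # Teacher: dividing logprobs by T and re-normalizing equals
    # softmax(teacher_logits / T) restricted to the k ids, because the
    # log-partition term is the same constant for every token.
    t_log = F.log_softmax(teacher_topk_logprobs / T, dim=-1)
    t_prob = t_log.exp()

    # Student: softened log-probs over the FULL vocab, then the same k ids.
    s_log = F.log_softmax(student_logits / T, dim=-1).gather(-1, teacher_topk_ids)

    # sum_k p_teacher * (log p_teacher - log q_student), averaged over positions
    kl = (t_prob * (t_log - s_log)).sum(dim=-1).mean()
    return kl * (T * T)
```

When k equals the vocabulary size this reduces exactly to the KL term in kd_loss; with smaller k it also pushes the student to keep its mass on the teacher's top-k tokens.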

Key idea

The KD loss is alpha·CE + (1−alpha)·T²·KL: cross-entropy anchors the student to the gold label; temperature-softened KL transfers the teacher's distribution; the T² factor keeps the soft term's magnitude stable; alpha dials teacher-vs-truth. Same loop, richer loss.

Key terms

KD loss
alpha·CE + (1−alpha)·T²·KL — the blend of hard-label cross-entropy and teacher KL.
KL divergence
A measure of how far the student's distribution is from the teacher's; minimized to transfer knowledge.
temperature (T)
A divisor on logits before softmax; higher T softens the distribution to expose secondary probabilities.
T² factor
Rescales the KL term to offset the 1/T² gradient shrink from softening, keeping the term balanced.
alpha
The [0,1] weight trading the gold-label CE term against the teacher KL term.
kd_trainer
BrewSLM's offline distillation trainer that applies kd_loss against the captured teacher logprobs.
