Track 1 · SFT fundamentals · Lesson 22

Catastrophic forgetting

After this lesson you can explain catastrophic forgetting, predict when SFT will trigger it, and apply the standard mitigations (prefer LoRA, mix in a slice of broad data, stop early, evaluate broadly) before your fine-tune erases the abilities you wanted to keep.

Level: intermediate · Read time: ~8 min · Prerequisites: Full fine-tuning vs LoRA

You fine-tune a chat model into a perfect JSON extractor. The eval on your task is great. Then you ask the model an unrelated question and it tries to emit JSON. You taught it the new skill — and it forgot the old ones. That failure mode has a name: catastrophic forgetting. It bites SFT especially hard, and once you know to watch for it, the mitigations are cheap.

What it is

Training is a directed walk through parameter space: at every step, the optimizer adjusts every trainable parameter to reduce the loss on the current batch. When the batch is narrow (only your task, only one format), the optimizer drifts the parameters toward what works for that batch, and the broader behaviours encoded in the original weights get pushed aside. Nothing in the loss function rewards keeping them, so they fade.

This is not a model bug. It's exactly what gradient descent is asked to do. Catastrophic forgetting is the consequence of optimising for narrow data without telling the optimizer to preserve anything.
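A toy illustration of that drift, not a real fine-tune: the sketch below fits a two-weight linear model on a "broad" set, then keeps training on a single "narrow" example. The datasets, learning rate and epoch counts are all invented for illustration.

```python
# Toy illustration only: a 2-weight linear model y = w[0]*x[0] + w[1]*x[1].
# "broad" stands in for pretrained behaviour, "narrow" for a one-shape SFT set.

def mse(w, data):
    """Mean squared error of the linear model on (x, y) pairs."""
    return sum((w[0] * x[0] + w[1] * x[1] - y) ** 2 for x, y in data) / len(data)

def sgd_epoch(w, data, lr=0.1):
    """One pass of plain per-example gradient descent; returns new weights."""
    w = list(w)
    for x, y in data:
        err = w[0] * x[0] + w[1] * x[1] - y
        w[0] -= lr * 2 * err * x[0]
        w[1] -= lr * 2 * err * x[1]
    return w

broad = [((1.0, 1.0), 2.0), ((1.0, -1.0), 0.0)]   # solved exactly by w = (1, 1)
narrow = [((1.0, 0.0), 3.0)]                      # wants w[0] = 3, says nothing about broad

w = [0.0, 0.0]
for _ in range(200):
    w = sgd_epoch(w, broad)
broad_before = mse(w, broad)

for _ in range(200):
    w = sgd_epoch(w, narrow)   # the loss only ever sees the narrow example
broad_after = mse(w, broad)
narrow_after = mse(w, narrow)

print(f"broad loss before narrow training: {broad_before:.4f}")
print(f"broad loss after narrow training:  {broad_after:.4f}")
print(f"narrow loss after narrow training: {narrow_after:.4f}")
```

In this toy the narrow loss falls to roughly 0 while the broad loss climbs from roughly 0 to roughly 4. The optimizer did its job; it was simply never asked to keep the broad set working.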

Why SFT is especially susceptible

Pretraining used trillions of tokens covering everything; SFT typically uses hundreds to thousands of examples covering one task. The optimizer's signal during SFT is therefore overwhelmingly about your task, with no counter-pressure from the rest of language. A few epochs of that and the model is biased hard toward your shape. The narrower your data and the longer you train, the harder the bias becomes.

When it bites hardest

Mitigation 1: prefer LoRA

LoRA (Lessons 1.13 and 1.14) freezes the base model and trains a small low-rank adapter on top. The base weights, which encode the broad skills, are untouched. The adapter learns the task; detach it and you get the base model back exactly. Merging writes the adapter into the weights, so keep an unmerged copy of the base if you want that escape hatch. This is the single biggest forgetting mitigation in your toolbox and it costs you almost nothing.
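To see why a frozen base stays recoverable, here is a minimal sketch of the low-rank idea in plain Python with tiny hand-written matrices. It illustrates the arithmetic only, not how any particular library implements LoRA; the names and values are made up, and the usual scaling factor is omitted.

```python
# Illustration only: a 2x2 "base" weight W plus a rank-1 adapter B @ A.
# Real LoRA applies this per layer and includes a scaling factor, omitted here.

def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

W = [[1.0, 0.0], [0.0, 1.0]]     # frozen base weights, never updated
B = [[0.5], [-0.5]]              # trainable, shape 2 x 1
A = [[1.0, 2.0]]                 # trainable, shape 1 x 2

def forward(x, use_adapter=True):
    """Base output, plus the low-rank correction when the adapter is active."""
    out = matvec(W, x)
    if use_adapter:
        delta = matvec(matmul(B, A), x)
        out = [o + d for o, d in zip(out, delta)]
    return out

x = [1.0, 1.0]
base_out = forward(x, use_adapter=False)     # what the untouched base computes
adapted_out = forward(x, use_adapter=True)   # base plus adapter
print(base_out, adapted_out)
```

Switching the adapter off returns the base output exactly, because W was never modified. Only B and A would receive gradient updates during training.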

Mitigation 2: mix in a slice of original data

Even with LoRA, if your fine-tune data is one shape, the adapter will be aggressive about producing that one shape. The fix: mix in a small fraction (5–10%) of general instruction data from a public set. Each batch then contains both your task and a reminder that other tasks exist, and the adapter learns "do this task on these prompts; do normal-chat on those prompts." Cheap and effective.

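# crude mix in code
# assumes both datasets share the same columns/features; concatenation fails otherwise
from datasets import load_dataset, concatenate_datasets

narrow_ds = load_dataset("path/to/your/task", split="train")        # 1,000 rows
broad_ds  = load_dataset("HuggingFaceH4/no_robots", split="train")  # general instruction data; split name assumed
broad_sample = broad_ds.shuffle(seed=42).select(range(100))         # 100 rows = 10% of narrow
mixed = concatenate_datasets([narrow_ds, broad_sample]).shuffle(seed=42)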
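
The sample size above is a fraction of the narrow set, not of the final mix. If you would rather target a share of the mixed dataset, a small helper does the arithmetic. The function name is hypothetical and the row counts simply reuse the figures from the example.

```python
def broad_rows_needed(n_narrow: int, broad_share: float) -> int:
    """Rows of broad data so that they make up `broad_share` of the final mix.

    Solves n_broad / (n_narrow + n_broad) = broad_share for n_broad.
    """
    if not 0.0 <= broad_share < 1.0:
        raise ValueError("broad_share must be in [0, 1)")
    return round(n_narrow * broad_share / (1.0 - broad_share))

# 100 broad rows on top of 1,000 narrow rows is 10% of narrow,
# but 100 / 1100, about 9.1%, of the mixed set.
share_of_mix = 100 / (1000 + 100)

# To make broad data a true 10% of the mix you need slightly more rows.
n_for_ten_percent = broad_rows_needed(1000, 0.10)
print(n_for_ten_percent)
```

Either convention is fine at this scale; what matters is that you know which one your 5–10% refers to when you compare runs.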

Mitigation 3: stop training early

Forgetting tends to compound with epochs. Lesson 1.17's loss-curve reading already told you to keep the checkpoint with the lowest eval_loss, not the final one. That policy doubles as a forgetting mitigation, because the final checkpoint has had the most narrow training and has usually forgotten the most. If your task hits its quality gate at epoch 2 of 5, ship epoch 2.
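
One way to turn "ship epoch 2" into a rule: given per-epoch eval results, take the earliest checkpoint that clears your task quality gate, and fall back to the lowest eval_loss if none does. The record fields, numbers and threshold below are hypothetical.

```python
from typing import Dict, List

def pick_checkpoint(history: List[Dict], task_gate: float) -> Dict:
    """Return the earliest checkpoint whose task score meets the gate.

    Falls back to the checkpoint with the lowest eval_loss if none does.
    """
    for record in sorted(history, key=lambda r: r["epoch"]):
        if record["task_score"] >= task_gate:
            return record
    return min(history, key=lambda r: r["eval_loss"])

# Made-up numbers for a 5-epoch run.
history = [
    {"epoch": 1, "eval_loss": 0.92, "task_score": 0.81},
    {"epoch": 2, "eval_loss": 0.71, "task_score": 0.93},
    {"epoch": 3, "eval_loss": 0.69, "task_score": 0.94},
    {"epoch": 4, "eval_loss": 0.74, "task_score": 0.94},
    {"epoch": 5, "eval_loss": 0.83, "task_score": 0.95},
]

chosen = pick_checkpoint(history, task_gate=0.90)
print(chosen["epoch"])
```

With these numbers the rule picks epoch 2: later epochs add a point or two of task score, and that is the trade this mitigation declines.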

Mitigation 4: evaluate broadly, not just on your task

The fastest way to discover forgetting is to look for it. Build a tiny "broad eval" — 10–30 prompts that test general instruction-following, polite refusals, basic reasoning, format-switching — and run it on every candidate checkpoint alongside your task eval. A run that aces the task eval but tanks the broad eval is forgetting; pick the earlier checkpoint or revisit the mix.
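
A broad eval can be as small as a list of prompts with one cheap check each. The sketch below assumes a `generate(prompt) -> str` callable that you supply for each checkpoint; the prompts, checks and stub model are placeholders, not a recommended benchmark, and a real set would have the 10–30 prompts described above.

```python
import json
from typing import Callable, List, Tuple

def looks_like_json(text: str) -> bool:
    """True if the whole reply parses as a JSON object or array."""
    try:
        return isinstance(json.loads(text), (dict, list))
    except ValueError:
        return False

# (prompt, check) pairs; each check takes the reply and returns True on pass.
BROAD_EVAL: List[Tuple[str, Callable[[str], bool]]] = [
    ("What is the capital of France?",
     lambda r: "paris" in r.lower() and not looks_like_json(r)),
    ("Reply with exactly one word: yes or no. Is water wet?",
     lambda r: r.strip().lower().rstrip(".") in {"yes", "no"}),
    ("What is 2 + 3? Answer in a sentence.",
     lambda r: "5" in r and not looks_like_json(r)),
]

def broad_pass_rate(generate: Callable[[str], str]) -> float:
    """Fraction of broad prompts whose reply passes its check."""
    passed = sum(1 for prompt, check in BROAD_EVAL if check(generate(prompt)))
    return passed / len(BROAD_EVAL)

# Stub standing in for a checkpoint that answers everything with JSON.
def forgetful_model(prompt: str) -> str:
    return json.dumps({"answer": prompt})

print(broad_pass_rate(forgetful_model))
```

Run the same function on every candidate checkpoint and record the pass rate next to the task score, so a drop in one is visible against a gain in the other.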

Honest beat — small models forget faster

The 135M default of this Academy is more sensitive to catastrophic forgetting than a 7B model is. Less spare capacity, fewer weights to spread the change across. Be more conservative with epochs on small bases; the right epoch count for a 7B isn't the right epoch count for a 135M. The cost of running the broad eval is one extra coffee. Run it.

Key idea

Catastrophic forgetting is what happens when you optimise for narrow data without protecting the rest. Use LoRA, mix in a slice of broad data, stop early, and evaluate broadly — all cheap. The cost of skipping them is a model that nails your task and breaks everywhere else.

That ends the SFT fundamentals track. You now have the concepts: the loss mask, chat templates, task shapes, tokenisation, data quality including gold sets, the training loop, cross-entropy, LR + schedule, batch + accumulation, full vs LoRA, LoRA knobs, GPU memory, OOM recovery, reading loss curves, evaluation, decoding controls, dataset formats, and catastrophic forgetting. Next: Track 2, where you do all of this in code by hand.

Key terms

Catastrophic forgetting
Degradation of broader skills when a model is fine-tuned on narrow data — the optimizer has no signal to keep them.
Capability retention
The degree to which broad abilities survive a fine-tune; what mitigations are trying to preserve.
Data mixing (mix-back)
Including a small fraction (5–10%) of general instruction data alongside narrow task data, so the optimizer sees both shapes.
Broad eval
A small set of general prompts (instruction-following, refusals, format-switching) evaluated alongside the task eval to detect forgetting.
Early stopping (forgetting view)
Stopping training once the task gate is met, because additional epochs trade task gains for broader losses.
