The training loop, step by step
After this lesson you can trace one step of an SFT training loop end to end, name what the optimizer and scheduler do, and explain checkpoints and validation passes.
You now have clean, formatted, masked, tokenized data. Time to train. The SFT training loop is exactly the gradient-descent loop from Track 0, now operating on batches of tokenized examples through a Transformer. Let's walk one full step and the loop around it.
One training step
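```python
import torch.nn.functional as F

# Assumes model(...) returns (batch, seq, vocab) logits and batch.labels = input_ids with prompt/padding set to -100.
for batch in dataloader:                               # a minibatch of examples
    logits = model(batch.input_ids,                    # forward pass
                   attention_mask=batch.mask)
    loss = F.cross_entropy(                            # loss on completion tokens only
        logits[:, :-1].reshape(-1, logits.size(-1)),   # position t predicts token t+1
        batch.labels[:, 1:].reshape(-1),
        ignore_index=-100)                             # skip prompt and padding labels
    loss.backward()                                    # backprop: gradients for every parameter
    optimizer.step()                                   # nudge parameters
    scheduler.step()                                   # adjust the learning rate
    optimizer.zero_grad()                              # reset gradients for the next step
```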
The familiar moves: run the batch forward to get next-token logits, compute the loss only on the completion tokens (the loss mask, encoded in labels), backpropagate to get the gradient for every parameter, take an optimizer step, and let the scheduler adjust the learning rate. Then zero the gradients and repeat. That's the entire engine.
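To see what "loss on completion tokens only" means numerically, here is a minimal plain-Python sketch of a masked cross-entropy. It assumes the labels are already aligned with the positions that predict them and that masked positions carry the label -100 (the default `ignore_index` of PyTorch's `cross_entropy`); the function name `masked_cross_entropy` is hypothetical. PyTorch's own `cross_entropy` with its default mean reduction likewise averages only over the non-ignored targets.

```python
import math

IGNORE = -100  # assumed "don't score this position" label (PyTorch's default ignore_index)

def masked_cross_entropy(logits, labels):
    """Mean next-token cross-entropy over positions whose label is not IGNORE.

    logits: one list of scores per position (one score per vocabulary token)
    labels: target token id per position, IGNORE for prompt/padding positions
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, labels):
        if target == IGNORE:
            continue  # prompt or padding: contributes nothing to the loss
        # log of the softmax normalizer via log-sum-exp, shifted by the max for stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[target]  # -log p(target)
        count += 1
    if count == 0:
        raise ValueError("no completion tokens to score")
    return total / count

# Four positions, vocabulary of 3; the first two positions are prompt tokens.
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
labels = [IGNORE, IGNORE, 1, 2]
print(masked_cross_entropy(logits, labels))  # equals log(3): both scored positions are uniform
```

Changing the scores at the two prompt positions leaves the loss unchanged, which is exactly what the loss mask promises: only completion tokens produce gradient.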
The optimizer (AdamW)
Plain gradient descent works, but in practice we use a smarter optimizer, almost always AdamW. For every parameter it keeps running averages of the recent gradient and of the squared gradient, and scales that parameter's step by them; the "W" refers to decoupled weight decay, which shrinks weights directly rather than through the gradient. This makes training faster and more stable than vanilla SGD. It also costs memory: those two running averages are two extra numbers per parameter (we'll account for that in the memory lesson). The scheduler changes the learning rate over time; it is the subject of an upcoming lesson.
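The per-parameter update AdamW applies can be written out directly. The sketch below is plain Python over a list of scalar parameters; `adamw_step` is a hypothetical name, and the default hyperparameters mirror those of PyTorch's `torch.optim.AdamW`. The lists `m` and `v` are the two extra numbers per parameter mentioned above.

```python
import math

def adamw_step(params, grads, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update, in place. m and v hold the two extra numbers kept per parameter.

    t is the 1-based step count, used for bias correction.
    """
    for i, g in enumerate(grads):
        m[i] = beta1 * m[i] + (1 - beta1) * g        # running average of the gradient
        v[i] = beta2 * v[i] + (1 - beta2) * g * g    # running average of the squared gradient
        m_hat = m[i] / (1 - beta1 ** t)              # bias correction: both averages start at 0
        v_hat = v[i] / (1 - beta2 ** t)
        params[i] -= lr * weight_decay * params[i]   # decoupled weight decay (the "W")
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)  # per-parameter adaptive step

params = [1.0, -2.0]
m, v = [0.0, 0.0], [0.0, 0.0]          # optimizer state: same length as params
adamw_step(params, [0.5, -0.1], m, v, t=1, weight_decay=0.0)
print(params)  # each parameter moved by about lr = 1e-3, despite a 5x difference in gradient size
```

On the first step, `m_hat / sqrt(v_hat)` reduces to `g / |g|`, so every parameter moves by roughly `lr` regardless of its gradient's magnitude; that normalization is what "adapts the step per parameter" means. The decay line shrinks a weight even when its gradient is zero, which is the sense in which the decay is decoupled from the gradient.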
Checkpoints
Periodically the loop saves a checkpoint — the current parameters (or, for LoRA, the adapter) written to disk. Checkpoints let you resume after a crash, and — more importantly — let you keep the best version rather than the last. Training doesn't always end at its best point, so you save several and select by validation performance.
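Resuming after a crash needs more than the weights: the optimizer state (AdamW's two running averages per parameter) and the step count must come back too, otherwise training after a restart does not continue where it left off. Below is a minimal sketch using JSON files, assuming parameters and optimizer state fit in plain lists; the file layout and the function names `save_checkpoint`, `latest_checkpoint` and `load_checkpoint` are hypothetical. In PyTorch the usual analogue is `torch.save` / `torch.load` on a dict holding `model.state_dict()`, `optimizer.state_dict()` and `scheduler.state_dict()`.

```python
import json
import os

def save_checkpoint(ckpt_dir, step, params, opt_state):
    """Write everything needed to resume: parameters, optimizer state, and step count."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step-{step:06d}.json")  # zero-padded so names sort by step
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"step": step, "params": params, "opt_state": opt_state}, f)
    os.replace(tmp, path)  # swap the finished file into place; a crash mid-write leaves only a .tmp
    return path

def latest_checkpoint(ckpt_dir):
    """Path of the most recent finished checkpoint, or None if there is none yet."""
    if not os.path.isdir(ckpt_dir):
        return None
    names = sorted(n for n in os.listdir(ckpt_dir)
                   if n.startswith("step-") and n.endswith(".json"))  # ignores .tmp files
    return os.path.join(ckpt_dir, names[-1]) if names else None

def load_checkpoint(path):
    """Return (step, params, opt_state) so the loop can pick up where it stopped."""
    with open(path) as f:
        ckpt = json.load(f)
    return ckpt["step"], ckpt["params"], ckpt["opt_state"]
```

Zero-padding the step number makes lexicographic order match numeric order (up to a million steps), and writing to a temporary file before `os.replace` means an interrupted save leaves the earlier checkpoints usable.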
Validation passes
Every so often the loop runs the model over the validation set (no gradient updates) and records the validation loss or metric. This is your early-warning system: training loss almost always keeps dropping, but when validation loss flattens and then starts to rise, the model is beginning to memorize the training data rather than generalize. You pick the checkpoint where validation was best, not the one where training loss was lowest.
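The validation pass and the selection rule fit in a few lines. This is a minimal plain-Python sketch, assuming a hypothetical `loss_fn(params, example)` and a history of `(step, train_loss, val_loss)` records logged at each checkpoint; both names and the record layout are illustrative. In PyTorch the validation loop would also run under `torch.no_grad()` with the model in `model.eval()` mode.

```python
def validation_pass(loss_fn, params, val_set):
    """Mean loss over the validation set. Read-only: params are never updated here."""
    losses = [loss_fn(params, example) for example in val_set]
    return sum(losses) / len(losses)

def best_checkpoint(history):
    """history: list of (step, train_loss, val_loss). Choose by validation loss, not training loss."""
    return min(history, key=lambda rec: rec[2])[0]

# Hypothetical numbers for illustration: training loss keeps falling, validation turns up after step 200.
history = [(100, 1.20, 1.31), (200, 0.90, 1.12), (300, 0.60, 1.18), (400, 0.35, 1.40)]
print(best_checkpoint(history))  # 200 (lowest training loss would have picked 400)
```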
Key idea
Training loss measures fit to the data you trained on; validation loss estimates generalization. Trust validation for "is it actually better?", and save checkpoints so you can keep the best one.
Everything else in this track tunes this loop: the loss (next lesson), the learning rate and schedule, the batch size, and which parameters you train (full vs LoRA). The loop itself never changes.
Key terms
- Training step: one iteration of forward → loss → backward → optimizer step → scheduler step → zero gradients.
- Optimizer / AdamW: the rule that updates parameters from gradients; AdamW adapts the step per parameter and is the standard choice.
- Scheduler: adjusts the learning rate over the course of training.
- Checkpoint: saved parameters (or LoRA adapter) on disk, for resuming and keeping the best version.
- Validation pass: running the model on the validation set without updates to estimate generalization.