SFT with TRL's SFTTrainer (the 20-line version)
After this lesson you can run the same LoRA fine-tune from Lesson 2.5 with TRL's SFTTrainer in about 20 lines, and you know what the wrapper does for you (chat-template, loss mask, PEFT) and what it hides.
In Lesson 2.5 you built a LoRA fine-tune on raw Trainer, with a manual tokenize function, a hand-built loss mask, and an explicit DataCollatorForSeq2Seq. That was deliberate: you should know what happens when you click Train. For a normal SFT run, though, the standard tool is TRL's SFTTrainer, which wraps that whole loop. Same training recipe, equivalent results, far less code.
Why SFTTrainer exists
Look at what Lesson 2.5 made you write by hand: apply the chat template; tokenize each example; build input_ids, attention_mask, and labels with the prompt masked to -100; choose a collator; wrap the model with a get_peft_model call. None of that is task-specific. SFTTrainer absorbs all of it and lets you describe what you have (a dataset of chat messages) and what you want (a LoRA fine-tune) in a handful of lines.
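As a reminder of what the hand-built masking step amounts to, here is a minimal sketch in plain Python. The helper name build_example and the token ids are hypothetical, chosen only for illustration; the real Lesson 2.5 code works on tokenizer output rather than hand-typed lists.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy skips by default


def build_example(prompt_ids, completion_ids, max_len=None):
    """Concatenate prompt and completion ids and mask the prompt in labels."""
    input_ids = list(prompt_ids) + list(completion_ids)
    # Prompt positions get IGNORE_INDEX, so only completion tokens carry loss.
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    if max_len is not None:
        input_ids, labels = input_ids[:max_len], labels[:max_len]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


# Hypothetical token ids, for illustration only.
example = build_example(prompt_ids=[11, 12, 13, 14], completion_ids=[21, 22])
print(example["labels"])  # [-100, -100, -100, -100, 21, 22]
```

This is the bookkeeping that SFTTrainer takes over: same three fields, same -100 convention, but built for you from the dataset.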
Same data, same model, same LoRA
We reuse the same model, tokenizer, and LoRA config from Lessons 2.2 and 2.5. The only thing that changes is the dataset shape: SFTTrainer prefers the chat messages format (Lesson 1.20), a list of {role, content} dicts per row.
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
tok = AutoTokenizer.from_pretrained(MODEL_ID)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16)

# data: chat messages per row
raw = [
    {"prompt": "Classify the sentiment as positive or negative: I loved it.", "completion": "positive"},
    {"prompt": "Classify the sentiment as positive or negative: Broken on arrival.", "completion": "negative"},
    # ...
]

def to_chat(row):
    # Turn one prompt/completion pair into a two-turn conversation.
    return {"messages": [
        {"role": "user", "content": row["prompt"]},
        {"role": "assistant", "content": row["completion"]},
    ]}

ds = Dataset.from_list(raw).map(to_chat, remove_columns=["prompt", "completion"])

# LoRA: same config as Lesson 2.5
lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM",
)
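To get a feel for how small this adapter is, the following sketch counts the trainable parameters that r=16 on q_proj and v_proj would add. The layer dimensions below are assumptions for illustration, not values taken from this lesson; read the real ones from model.config before trusting the total.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer.

    LoRA learns A (r x d_in) and B (d_out x r); the base weight stays frozen.
    """
    return r * (d_in + d_out)


# ASSUMED dimensions, for illustration only; check model.config for real values.
HIDDEN = 576       # assumed hidden size (q_proj maps HIDDEN -> HIDDEN)
KV_DIM = 192       # assumed v_proj output width (grouped-query attention)
N_LAYERS = 30      # assumed number of decoder layers
R, ALPHA = 16, 32  # from the LoraConfig above

per_layer = lora_param_count(HIDDEN, HIDDEN, R) + lora_param_count(HIDDEN, KV_DIM, R)
total = per_layer * N_LAYERS
scaling = ALPHA / R  # factor applied to the low-rank update in standard LoRA

print(per_layer, total, scaling)  # 30720 921600 2.0
```

Under those assumed dimensions the adapter is under a million parameters, which is why the saved adapter directory is tiny compared with the base model.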
The 20-line version
This is the SFT loop you wrote by hand, in twenty lines. Everything from Lesson 2.5 is in here — just expressed declaratively.
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="sft-out",
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,  # effective batch 16
        learning_rate=2e-4,
        num_train_epochs=3,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,
        logging_steps=5,
        save_strategy="epoch",
        max_seq_length=256,
        completion_only_loss=True,  # build the loss mask for us
        report_to=[],
    ),
    train_dataset=ds,
    processing_class=tok,  # SFTTrainer applies the chat template
    peft_config=lora,      # attach LoRA; no get_peft_model() call needed
)
trainer.train()
trainer.save_model("sft-out/adapter")
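The batch and warmup settings above translate into a concrete number of optimizer steps. The sketch below is a rough estimate only: the helper estimate_schedule and the dataset size of 1000 rows are hypothetical, and the Trainer's exact rounding may differ between versions.

```python
import math


def estimate_schedule(n_examples, per_device_batch, grad_accum, epochs,
                      warmup_ratio, n_devices=1):
    """Rough step counts implied by the batch and warmup settings."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


# Hypothetical dataset of 1000 rows with the config values used above.
plan = estimate_schedule(n_examples=1000, per_device_batch=8, grad_accum=2,
                         epochs=3, warmup_ratio=0.03)
print(plan)
```

With only a handful of rows, as in the toy dataset above, an epoch is a single step, so logging_steps=5 would log very little; the estimate makes that kind of mismatch easy to spot before training.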
What SFTTrainer does for you
- Chat template application. Given messages, it calls tok.apply_chat_template on every row; you don't render special tokens by hand.
- Loss mask. completion_only_loss=True sets non-assistant tokens to -100, so cross-entropy is computed only on the assistant turn. Same effect as Lesson 2.4's manual masking, with one knob.
- PEFT integration. Passing peft_config wraps the model in LoRA internally; no separate get_peft_model call.
- Padding + collation. A sensible default collator handles padding and labels-as-input-ids without you choosing one.
- Optional sequence packing. Set packing=True in SFTConfig and short examples are concatenated to fill max_seq_length, saving compute on padding (Lesson 1.12 territory).
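To see why packing saves compute, the sketch below compares how many token slots a padded run and a packed run would process for the same examples. It uses a simple first-fit heuristic, which is an illustrative stand-in and not TRL's actual packing implementation; the token counts are hypothetical.

```python
def padded_slots(lengths, batch_size):
    """Total token slots when each batch is padded to its longest example."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        total += max(batch) * len(batch)
    return total


def pack_first_fit(lengths, max_len):
    """Place each example in the first bin that still has room for it."""
    bins = []
    for n in lengths:
        n = min(n, max_len)  # over-long examples are truncated to max_len
        for b in bins:
            if sum(b) + n <= max_len:
                b.append(n)
                break
        else:
            bins.append([n])  # no bin had room, so open a new one
    return bins


lengths = [40, 200, 30, 120, 60, 250, 20, 100]  # hypothetical token counts
real_tokens = sum(lengths)
padded = padded_slots(lengths, batch_size=4)
bins = pack_first_fit(lengths, max_len=256)
packed = len(bins) * 256

print(real_tokens, padded, packed)  # 820 1800 1024
```

In this toy case padding processes 1800 slots for 820 real tokens, while packing into 256-token sequences processes 1024. The gap grows when example lengths vary widely.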
From Track 1 + Lesson 2.5
Everything you learned about loss masks (Lesson 1.3), chat templates (Lesson 1.4), and the training loop (Lessons 1.9, 2.5) is still happening — just inside the wrapper. The reason we did the hand-built version first was so when the wrapper does something surprising, you know where to look.
When raw Trainer beats SFTTrainer
SFTTrainer is the right tool ~90% of the time. The cases where you go back to raw Trainer:
- Custom loss. Distillation (Track 4, Lesson 4.3) needs an α·CE + (1−α)·T²·KL loss against captured teacher logprobs. SFTTrainer doesn't support it; the kd_trainer in Track 4 wraps raw Trainer.
- Non-chat data shapes. If your task is a pure classifier head with no chat wrapping, the SFT framing is wrong; use raw Trainer with your own preprocessing.
- Custom masking. Loss only on specific tokens, or a curriculum that masks different things in different phases: you want raw Trainer.
- Debugging exactly what gets masked. When something's not learning and you suspect the mask, going back to the manual version (Lesson 2.4) gives you full visibility.
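For a sense of what the custom-loss case involves, here is a minimal single-position sketch of the α·CE + (1−α)·T²·KL loss in plain Python. It is not the Track 4 kd_trainer; the function names are hypothetical, and it assumes the KL term is KL(teacher ‖ student) with both distributions softened by the same temperature T.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def kd_loss(student_logits, teacher_logprobs, target, alpha=0.5, T=2.0):
    """alpha * CE + (1 - alpha) * T^2 * KL(teacher || student) for one token."""
    # Hard-label cross-entropy at temperature 1.
    ce = -log_softmax(student_logits)[target]
    # Soften both distributions. Dividing log-probs by T and renormalizing
    # gives the same result as softmax(logits / T), because log-probs differ
    # from logits only by a constant shift.
    s_logp = log_softmax([x / T for x in student_logits])
    t_logp = log_softmax([x / T for x in teacher_logprobs])
    kl = sum(math.exp(t) * (t - s) for t, s in zip(t_logp, s_logp))
    return alpha * ce + (1 - alpha) * T * T * kl


# Hypothetical values for a 3-token vocabulary.
student_logits = [2.0, 0.5, -1.0]
teacher_logprobs = log_softmax([1.0, 1.0, -2.0])
loss = kd_loss(student_logits, teacher_logprobs, target=0)
print(round(loss, 4))
```

The extra input (teacher log-probs per position) is what a stock SFT wrapper has no slot for, which is why this case falls back to raw Trainer with a custom loss computation.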
Honest beat — the wrapper has assumptions
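completion_only_loss depends on your model's chat template having a clean assistant-turn boundary that apply_chat_template emits consistently. On most instruct models this is fine; on a few obscure or custom-chat-template models it isn't, and the mask silently includes prompt tokens or excludes some assistant tokens. Field names and supported dataset shapes have also shifted between TRL releases: newer versions document completion_only_loss for prompt/completion datasets, add a separate assistant_only_loss field for messages datasets, and rename max_seq_length to max_length, so check the SFTConfig reference for the version you have installed. The check: train a few steps, dump one batch's labels, and confirm the non-(-100) positions are exactly the assistant tokens. Same check we did in Lesson 2.4: still cheap, still worth it.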
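The label check described above can be sketched as a small comparison between the positions that carry loss and the positions you expect to be assistant tokens. The helper names and the example label rows are hypothetical; in a real run the labels would come from one collated batch.

```python
IGNORE_INDEX = -100


def supervised_positions(labels):
    """Indices whose label is not IGNORE_INDEX, i.e. positions that carry loss."""
    return [i for i, label in enumerate(labels) if label != IGNORE_INDEX]


def mask_mismatch(labels, assistant_positions):
    """Compare the mask against where the assistant tokens are expected to be."""
    supervised = set(supervised_positions(labels))
    expected = set(assistant_positions)
    return {
        "leaked_prompt": sorted(supervised - expected),      # loss on non-assistant tokens
        "dropped_assistant": sorted(expected - supervised),  # assistant tokens with no loss
    }


# Hypothetical label rows: one clean, one with a shifted mask.
good = [-100, -100, -100, 7, 8, 9]
bad = [-100, -100, 5, 7, 8, -100]

print(mask_mismatch(good, assistant_positions=[3, 4, 5]))
print(mask_mismatch(bad, assistant_positions=[3, 4, 5]))
```

In practice one way to get a batch is next(iter(trainer.get_train_dataloader())), after which decoding only the input_ids whose label is not -100 with tok.decode should print the assistant reply and nothing else. Treat that as a suggestion to adapt, since batch contents depend on your TRL version and settings.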
Key idea
SFTTrainer is the SFT loop you wrote by hand, packaged. Chat template, loss mask, LoRA wiring, padding — all handled when you declare the dataset shape and the training config. Reach for raw Trainer only when you need a loss or a mask SFTTrainer doesn't model; for ordinary fine-tunes, the 20-line version is the right default.
The next lesson keeps the SFTTrainer pattern and adds memory: QLoRA via bitsandbytes lets you fine-tune a much larger model in the same VRAM you've been using.
Key terms
- TRL: Hugging Face's preference- and SFT-training library. Provides SFTTrainer, DPOTrainer, ORPOTrainer.
- SFTTrainer: A wrapper around HF Trainer that handles chat-template, loss mask, padding, and PEFT for SFT.
- SFTConfig: SFTTrainer's training-args object; superset of TrainingArguments with SFT-specific fields (max_seq_length, completion_only_loss, packing).
- completion_only_loss: SFTConfig field; when True, builds the loss mask so cross-entropy ignores everything except assistant tokens.
- peft_config: SFTTrainer argument; pass a LoraConfig and SFTTrainer attaches LoRA internally (no get_peft_model).
- packing: SFTConfig field; when True, concatenates short examples to max_seq_length to save compute on padding.