Track 2 · Hands-on · Lesson 14

Multi-turn chat SFT

After this lesson you can fine-tune on multi-turn conversations with the loss mask applied to every assistant turn (not just the last), verify the mask, and frame honestly what multi-turn training does and doesn't change about the model's long-conversation behaviour.

Level: intermediate · Read time: ~9 min · Prerequisites: SFT with TRL's SFTTrainer

Everything in Track 2 so far has trained on single-turn pairs: one user prompt, one assistant reply. Real chat data isn't shaped like that. A support conversation has a system message, several user turns, and several assistant turns — and you want the model to learn how to handle every assistant turn in context, not just the last. Multi-turn SFT is exactly the same training loop with a different loss mask: every assistant turn is supervised, every user / system span is masked.

The data shape: chat-messages, but longer

Recall from Lesson 1.20 that the chat-messages format is a list of {role, content} dicts. The only difference here is more entries.

from datasets import Dataset

raw = [
  {"messages": [
    {"role": "system",    "content": "You are a polite support agent. Be concise."},
    {"role": "user",      "content": "I want a refund."},
    {"role": "assistant", "content": "I can help — what's the order number?"},
    {"role": "user",      "content": "It's ORD-1234."},
    {"role": "assistant", "content": "Thanks. I see the order. Refund will process in 3-5 days."}
  ]},
  # ... more conversations
]
ds = Dataset.from_list(raw)
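Longer conversations give more places for the structure to go wrong: a dropped turn, two user messages in a row, a transcript that ends on the user (whose trailing turn contributes no supervised tokens). A cheap structural check before training catches these. This is a minimal sketch: validate_conversation is a hypothetical helper, not part of datasets or TRL, and it assumes the simple shape shown above (an optional leading system message, then strictly alternating user/assistant turns ending on assistant, with no tool roles).

```python
VALID_ROLES = {"system", "user", "assistant"}  # assumption: no tool/function roles


def validate_conversation(messages):
    """Hypothetical helper: return a list of problems (empty list = looks OK)."""
    if not messages:
        return ["empty conversation"]
    problems = []
    for i, m in enumerate(messages):
        if m.get("role") not in VALID_ROLES:
            problems.append(f"message {i}: unknown role {m.get('role')!r}")
        content = m.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i}: empty or non-string content")
    # Allow one optional system message, and only in position 0.
    start = 1 if messages[0].get("role") == "system" else 0
    turns = messages[start:]
    # The remaining turns should alternate user / assistant, starting with user.
    for i, m in enumerate(turns):
        expected = "user" if i % 2 == 0 else "assistant"
        if m.get("role") != expected:
            problems.append(
                f"message {start + i}: expected {expected}, got {m.get('role')!r}"
            )
    # A trailing user turn is fully masked, so it only wastes sequence length.
    if not turns or turns[-1].get("role") != "assistant":
        problems.append("conversation does not end with an assistant turn")
    return problems


example = [
    {"role": "system", "content": "You are a polite support agent. Be concise."},
    {"role": "user", "content": "I want a refund."},
    {"role": "assistant", "content": "I can help. What's the order number?"},
    {"role": "user", "content": "It's ORD-1234."},
    {"role": "assistant", "content": "Thanks. Refund will process in 3-5 days."},
]
print(validate_conversation(example))  # → []
```

To drop malformed examples, something like ds.filter(lambda ex: not validate_conversation(ex["messages"])) keeps only conversations with no reported problems.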

Two messages per conversation become four, six, ten. Sequence lengths grow, so you'll usually want a larger maximum sequence length than the single-turn lessons used.

Loss on every assistant turn

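For conversational ("messages") data, recent TRL releases expose this as assistant_only_loss=True in SFTConfig. SFTTrainer renders each conversation with the tokenizer's chat template and gives label -100 to every token outside an assistant turn: the system message, user turns, chat-template role headers, and pad tokens. What remains unmasked is each assistant turn's content, including the assistant's earlier replies in the same conversation, not just the final one. This is the right default for chat fine-tuning: the model learns how to handle every turn position it'll see in production.

Two caveats. completion_only_loss=True is the analogous flag for prompt-completion datasets; it is documented for that format, so on messages data it won't give you the per-turn mask. And assistant_only_loss depends on a chat template that marks assistant spans with {% generation %} / {% endgeneration %} tags; many stock templates lack them, so check yours before relying on the flag. Argument names have shifted across TRL versions, so confirm them against the version you have installed.

from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

# assumes `model`, its tokenizer `tok`, and the `ds` built above

lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="multi-turn-out",
        per_device_train_batch_size=4,            # smaller batch: longer sequences
        gradient_accumulation_steps=4,            # effective batch 16 per device
        learning_rate=2e-4,
        num_train_epochs=3,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,
        logging_steps=5,
        save_strategy="epoch",
        max_length=2048,                          # multi-turn needs room; max_seq_length in older TRL
        assistant_only_loss=True,                 # supervise every assistant turn, mask the rest
        report_to=[],
    ),
    train_dataset=ds,
    processing_class=tok,
    peft_config=lora,
)
trainer.train()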
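To see what supervising every assistant turn means at the token level, here is a toy pure-Python sketch. It is not TRL's implementation: build_labels is a hypothetical helper, the token ids are made up, and real chat templates add role-header tokens (which would also be masked). The only_last switch shows the single-turn-style mask for contrast. -100 is the default ignore_index of PyTorch's CrossEntropyLoss, which Hugging Face causal-LM losses use.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def build_labels(turns, only_last=False):
    """Hypothetical helper. turns: list of (role, token_ids) in conversation order.

    Returns (input_ids, labels). Non-assistant tokens get IGNORE_INDEX.
    With only_last=True, earlier assistant turns are masked too (the
    last-turn-only mask this lesson argues against for chat data).
    """
    last_assistant = max(
        (i for i, (role, _) in enumerate(turns) if role == "assistant"),
        default=-1,
    )
    input_ids, labels = [], []
    for i, (role, ids) in enumerate(turns):
        input_ids.extend(ids)
        supervised = role == "assistant" and (not only_last or i == last_assistant)
        labels.extend(ids if supervised else [IGNORE_INDEX] * len(ids))
    return input_ids, labels


# Toy conversation with made-up token ids (no real tokenizer involved).
turns = [
    ("system", [1, 2, 3]),
    ("user", [4, 5]),
    ("assistant", [6, 7]),
    ("user", [8]),
    ("assistant", [9, 10, 11]),
]
_, every_turn = build_labels(turns)
_, last_only = build_labels(turns, only_last=True)
print(every_turn)  # [-100, -100, -100, -100, -100, 6, 7, -100, 9, 10, 11]
print(last_only)   # [-100, -100, -100, -100, -100, -100, -100, -100, 9, 10, 11]
```

With the per-turn mask, the first assistant reply (6, 7) contributes to the loss; with the last-only mask it is treated like context.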

Verify the mask — the cheap check that catches the loud bug

Same check as in Lessons 2.4 and 2.10, now more important because each example has several supervised spans instead of one: pull a batch, decode the unmasked positions, and confirm they correspond to assistant turns only.

# after building the trainer, before .train()
batch = next(iter(trainer.get_train_dataloader()))
ids    = batch["input_ids"][0].tolist()
labels = batch["labels"][0].tolist()

# the tokens the model is supervised on (-100 = ignored by the loss)
unmasked = [i for i, l in zip(ids, labels) if l != -100]
print(tok.decode(unmasked))
# Should be every assistant turn's content, run together.
# If you see user or system text, or the chat template's user/system
# role headers, the mask is wrong. An empty string means nothing is
# supervised at all.
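Decoding everything at once runs the turns together, which makes a missing or extra span easy to overlook. Splitting the supervised positions into contiguous runs gives one span per assistant turn, so you can count them against the conversation. A minimal sketch with a hypothetical supervised_spans helper and toy ids; it assumes at least one masked token (such as a role header) separates consecutive assistant turns, otherwise two turns would merge into one span.

```python
def supervised_spans(input_ids, labels, ignore_index=-100):
    """Hypothetical helper: group supervised positions into contiguous runs."""
    spans, current = [], []
    for token, label in zip(input_ids, labels):
        if label != ignore_index:
            current.append(token)      # still inside a supervised run
        elif current:
            spans.append(current)      # a masked token closes the run
            current = []
    if current:
        spans.append(current)          # run that reaches the end of the sequence
    return spans


# Toy example: two assistant turns (6, 7) and (9, 10, 11), everything else masked.
ids = list(range(1, 12))
labels = [-100] * 5 + [6, 7] + [-100] + [9, 10, 11]
print(supervised_spans(ids, labels))  # [[6, 7], [9, 10, 11]]
```

On a real batch, pass the ids and labels from the check above and print tok.decode(span) for each span: the number of spans should equal the number of assistant turns in that conversation, and zero spans means the assistant mask was never produced.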

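Pitfalls specific to multi-turn

Role bleed. If the mask leaks into user turns, the model is supervised on user-voice text and starts mimicking it. The decode check above catches this.

Context-window pressure. Long conversations approach the maximum sequence length, and truncation cuts from the end, so the final assistant turn (often the one carrying the resolution) is lost first.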
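Context-window pressure can be checked before training. The sketch below, using a hypothetical assistant_turns_kept helper, counts how many assistant turns survive intact, assuming truncation drops tokens from the end of the sequence. The per-turn token counts are made up; in practice you'd measure them with your tokenizer, and real chat templates add a few tokens of scaffolding per turn.

```python
def assistant_turns_kept(turn_lengths, max_length):
    """Hypothetical helper. turn_lengths: list of (role, n_tokens) in order.

    Assumes truncation keeps the first max_length tokens and drops the rest.
    Returns (assistant turns fully inside the window, total assistant turns).
    """
    end = kept = total = 0
    for role, n_tokens in turn_lengths:
        end += n_tokens                 # position where this turn ends
        if role == "assistant":
            total += 1
            if end <= max_length:       # the whole turn fits in the window
                kept += 1
    return kept, total


# Made-up token counts for a five-message conversation.
conversation = [
    ("system", 30),
    ("user", 20),
    ("assistant", 40),
    ("user", 25),
    ("assistant", 60),
]
print(assistant_turns_kept(conversation, max_length=128))   # (1, 2)
print(assistant_turns_kept(conversation, max_length=2048))  # (2, 2)
```

Any conversation where the first number is smaller than the second is losing supervision on its later turns; raise the limit or shorten the conversation before training.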

Honest beat — what multi-turn training doesn't do

Training on multi-turn data does not make the model "better at long conversations" in some general sense. It makes the model better at your turn structure on conversations resembling the ones in your data. Long-range coherence (remembering what was said 20 turns ago, keeping a stable persona over time) is mostly a property of the base model and its context window — fine-tuning on a few hundred multi-turn examples won't move it. If long-conversation behaviour is your goal, also evaluate broadly on out-of-domain conversations (Lesson 1.21's broad eval); the gains and the regressions both live there.

Key idea

Multi-turn SFT is single-turn SFT with a richer mask: every assistant turn is supervised, every user / system span is -100. SFTTrainer applies that mask for you on chat-messages data once assistant-only loss masking is enabled in the config; the only other changes from single-turn are a larger maximum sequence length and a smaller batch. The mask-verification check is still the cheapest catch for the loudest bug.

You can now train on every shape of SFT data you'll meet. The last Track 2 lesson is the project gallery: six different SLM tasks, each expressed as a recipe you can pick up and run with the techniques you've learned.

Key terms

Multi-turn SFT
Fine-tuning on conversations with multiple turns, supervising every assistant turn.
Per-turn loss mask
Loss applied to every assistant token in the conversation; user / system / chat-template tokens set to -100.
Mask verification
Decoding the unmasked positions of a training batch to confirm they correspond to assistant turns only; cheapest catch for a silent mask bug.
Role bleed
Mask leaking into user turns so the model is supervised on (and starts mimicking) user-voice text.
Context-window pressure
Long multi-turn conversations approaching the maximum sequence length; truncation cuts from the end, so the final assistant turn goes first, eating the resolution.
