Track 4 · Advanced · Lesson 13

Structured pruning: removing heads, layers, and channels

After this lesson you can decide whether structured pruning is worth attempting on your SLM, pick the right granularity (head / layer / channel), score importance on a calibration set, run the prune-then-recover workflow, and recognise when 4-bit quantization has already eaten the budget pruning would have given you.

Level: advanced · Read time: ~11 min · Prerequisites: Quantization & compression

Structured pruning is the version of pruning that actually shows up in wall-clock benchmarks. You pick a granularity — a whole attention head, a whole transformer layer, or a whole MLP channel — score every unit at that granularity for importance on a calibration set, drop the lowest-scoring ones until you hit a target size, then run a short SFT or distillation pass to recover the quality you just damaged. Done well, you get a model with smaller matrices, fewer FLOPs, lower memory, and a believable latency win. Done badly — especially without the recovery pass — you get a smaller model that's also dramatically worse, and you'll wish you'd just shipped the quantized original.

What structured pruning actually does

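"Structured" means the unit you remove is a whole tensor-aligned block, not an individual scalar weight. There are three common granularities, listed coarsest to finest:

  1. Layer pruning. Delete whole transformer blocks. This is the coarsest cut, the easiest to implement, and gives the biggest absolute speed-up per unit pruned.
  2. Head pruning. Remove whole attention heads from each block. The Q/K/V/O projections shrink consistently and the per-block attention compute drops.
  3. Channel / MLP-width pruning. Shrink the intermediate dimension of the MLP block. up_proj and gate_proj lose output rows and down_proj loses input columns.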

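All three end with less dense compute: layer pruning removes whole matrices, and head and channel pruning shrink the ones that remain. Under channel pruning, for example, nn.Linear(4096, 11008) becomes nn.Linear(4096, 8192). PyTorch sees a smaller tensor; cuBLAS does a smaller matmul; the inference engine runs faster. This is the win that matters.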
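To make the channel case concrete, here is a minimal NumPy sketch of slicing a LLaMA-style gated MLP. The toy sizes, the helper names (gated_mlp, prune_channels) and the "keep every second channel" choice are illustrative assumptions, not part of any library. It checks that physically removing channels gives the same output as masking them to zero in the full-size model, only with smaller matrices.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def gated_mlp(x, w_gate, w_up, w_down):
    """LLaMA-style MLP; weights use the nn.Linear layout [out_features, in_features]."""
    return (silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

def prune_channels(w_gate, w_up, w_down, keep):
    """Keep only the intermediate channels listed in `keep` (hypothetical helper)."""
    keep = np.sort(np.asarray(keep))
    # gate_proj / up_proj lose output rows; down_proj loses input columns.
    return w_gate[keep, :], w_up[keep, :], w_down[:, keep]

rng = np.random.default_rng(0)
hidden, inter = 8, 32                     # toy stand-ins for 4096 and 11008
w_gate = rng.normal(size=(inter, hidden))
w_up = rng.normal(size=(inter, hidden))
w_down = rng.normal(size=(hidden, inter))
x = rng.normal(size=(4, hidden))

keep = np.arange(0, inter, 2)             # illustrative: keep every second channel
pg, pu, pd = prune_channels(w_gate, w_up, w_down, keep)

# Reference: the full-size model with the dropped channels masked to zero.
mask = np.zeros(inter)
mask[keep] = 1.0
masked_out = ((silu(x @ w_gate.T) * (x @ w_up.T)) * mask) @ w_down.T
pruned_out = gated_mlp(x, pg, pu, pd)

print(pg.shape, pu.shape, pd.shape)         # (16, 8) (16, 8) (8, 16)
print(np.allclose(masked_out, pruned_out))  # True
```

In a real model the same slicing has to be applied to every block, and the config's intermediate size updated to match, which is the bookkeeping that makes hand-rolled pruning error-prone.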

Why unstructured pruning rarely pays off in practice

Unstructured pruning zeros out individual weights inside a matrix, leaving the matrix the same shape. On paper a 70% sparse weight matrix needs 70% fewer multiplies. In practice, on commodity GPUs and on Apple Silicon, your matmul is still being dispatched to a dense kernel — cuBLAS, cuDNN, MPS — which happily multiplies by zero billions of times and reports identical wall-clock latency to the un-pruned version. Sparse-aware kernels exist (NVIDIA's 2:4 structured sparsity on Ampere+, some research kernels for higher sparsity), but coverage across the stack you actually use — your tokenizer wrapper, your inference server, your quantization library — is patchy at best.
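The difference can be sketched with a toy weight matrix. This assumes NumPy; dense_macs is a hypothetical helper that counts the multiply-accumulates a dense kernel would issue, and the 70% figure simply mirrors the example above. Zeroing weights leaves the shape, and therefore the dense work, unchanged, while dropping whole rows shrinks both.

```python
import numpy as np

def dense_macs(weight, batch):
    """Multiply-accumulates a dense kernel performs for batch @ weight.T."""
    out_features, in_features = weight.shape
    return batch * out_features * in_features

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32))

# Unstructured: zero the ~70% of weights with the smallest magnitude.
threshold = np.quantile(np.abs(w), 0.70)
w_unstructured = np.where(np.abs(w) < threshold, 0.0, w)

# Structured: drop the 70% of output rows with the smallest L2 norm.
row_norms = np.linalg.norm(w, axis=1)
n_keep = w.shape[0] - int(0.70 * w.shape[0])
keep_rows = np.sort(np.argsort(row_norms)[-n_keep:])
w_structured = w[keep_rows]

print(w_unstructured.shape, dense_macs(w_unstructured, batch=1))  # (64, 32) 2048
print(w_structured.shape, dense_macs(w_structured, batch=1))      # (20, 32) 640
```

The count is a proxy for what a dense kernel does, not a latency measurement; actual speed-ups depend on the hardware and the kernels your stack dispatches to.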

For an SLM you intend to ship on a real device, treat unstructured pruning as a research technique and structured pruning as the production one. The rest of this lesson is about the production version.

Importance scoring: how do you decide what to cut?

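Once you've picked a granularity, you need a score per unit. Three families dominate, in order of increasing cost and accuracy:

  1. Magnitude. Score each unit by the norm of its weights. No data or gradients are needed, so it is the cheapest and the crudest.
  2. Gradient × weight. A first-order estimate of how much the loss changes if the unit is removed, computed from forward and backward passes over the calibration set.
  3. Hessian-based. A second-order estimate that also accounts for curvature. The most accurate of the three and the most expensive to compute.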

For most SLM pruning work, gradient × weight on a calibration set of 256–2,048 examples drawn from your real distribution is the right starting point. The calibration set matters more than the scoring formula — score on data that looks like deployment and you'll keep the units that matter for deployment.
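As a rough illustration of gradient × weight scoring at channel granularity, the sketch below uses a tiny ReLU MLP with hand-derived gradients in NumPy. The model, the squared-error loss, the random "calibration" data and the channel_scores helper are all stand-in assumptions; on a real SLM you would get the gradients from autograd on the language-modelling loss. Each channel's score is the magnitude of the sum of weight × gradient over its down-projection column.

```python
import numpy as np

def channel_scores(x, target, w_up, w_down):
    """First-order (gradient x weight) importance per hidden channel."""
    n = x.shape[0]
    h = np.maximum(x @ w_up.T, 0.0)      # hidden activations, [n, inter]
    y = h @ w_down.T                     # outputs, [n, out]
    grad_y = (y - target) / n            # dL/dy for L = 0.5 * sum((y - t)**2) / n
    grad_w_down = grad_y.T @ h           # dL/dW_down, [out, inter]
    # Sum weight * gradient over each channel's column, then take the magnitude.
    return np.abs((w_down * grad_w_down).sum(axis=0))

rng = np.random.default_rng(0)
n, d_in, inter, d_out = 256, 8, 16, 4   # 256 = low end of the calibration range
x = rng.normal(size=(n, d_in))
target = rng.normal(size=(n, d_out))
w_up = rng.normal(size=(inter, d_in))
w_down = rng.normal(size=(d_out, inter))
w_down[:, 3] = 0.0                       # channel 3 contributes nothing to the output

scores = channel_scores(x, target, w_up, w_down)
order = np.argsort(scores)               # least important channels first
print(order[0], scores[3])               # 3 0.0
```

The channel that cannot affect the output gets a score of exactly zero and is ranked first for removal. Swapping in calibration data from a different distribution changes the activations and gradients, and therefore the ranking, which is why the data choice matters so much.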

The prune-to-target, then recover workflow

The workflow has four steps:

  1. Pick a target. Either a parameter count (e.g. 7B → 4.5B), a latency budget (e.g. p50 first-token under 200 ms on your target hardware), or a memory ceiling. Concrete targets — not "make it smaller" — keep the rest of the process honest.
  2. Score importance on the calibration set at the chosen granularity.
  3. Drop the lowest-scoring heads / layers / channels until the target is hit. Re-shape the parameter tensors so the model actually has fewer parameters; this is the step where the matmuls get smaller.
  4. Recover. Run a short SFT pass (or, better, a teacher-distilled pass using the pre-pruning model as the teacher — the offline KD recipe in Lesson 4.4 plugs in directly) on a small, high-quality dataset. The model regains most of the quality the pruning step cost. Skip this step and you're shipping a damaged model.

# Recovery pass after structured pruning: short SFT with a small learning rate
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Load the already-pruned checkpoint (tensors re-shaped in step 3) and its tokenizer
model = AutoModelForCausalLM.from_pretrained("checkpoints/pruned-4_5B")
tok   = AutoTokenizer.from_pretrained("checkpoints/pruned-4_5B")
ds    = load_from_disk("data/recovery_sft")            # ~5–50k high-quality rows

trainer = SFTTrainer(
    model=model, processing_class=tok, train_dataset=ds,
    args=SFTConfig(output_dir="checkpoints/pruned-4_5B-recovered",
        per_device_train_batch_size=4, gradient_accumulation_steps=8,
        learning_rate=1e-5, num_train_epochs=1, lr_scheduler_type="cosine",
        warmup_ratio=0.03, bf16=True,
        max_seq_length=2048,   # renamed to max_length in newer TRL releases
        report_to=[]),
)
trainer.train()                                          # one low-LR epoch over the recovery set
trainer.save_model("checkpoints/pruned-4_5B-recovered")

Two notes on the recovery pass. First, the learning rate is small: you're nudging the surviving weights back into a good basin, not retraining from scratch; 1e-5 to 5e-5 is the usual band. Second, for a narrow task the dataset is small but clean. General-purpose recovery costs far more: Sheared-LLaMA's published recipe uses on the order of tens of billions of tokens to recover its 2.7B and 1.3B models after pruning them from a 7B base. An SFT recovery on a narrower task can be much smaller, but the floor is "enough to see the loss curve flatten and the eval metrics rebound."
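Steps 1 to 3 of the workflow reduce to a small selection loop once every unit has a score and a parameter cost. The sketch below is one simple way to do it, a single global ranking with greedy removal; the Unit class, the unit names and the toy numbers are hypothetical, and published recipes often re-score between rounds or enforce per-layer quotas instead.

```python
from dataclasses import dataclass

@dataclass
class Unit:
    name: str      # illustrative naming, e.g. "block1.ch_group0"
    score: float   # importance from the calibration pass
    params: int    # parameters freed if this unit is removed

def select_units_to_drop(units, total_params, target_params):
    """Drop lowest-scoring units until total_params <= target_params."""
    dropped, remaining = [], total_params
    for unit in sorted(units, key=lambda u: u.score):
        if remaining <= target_params:
            break
        dropped.append(unit.name)
        remaining -= unit.params
    return dropped, remaining

units = [
    Unit("block0.ch_group0", 0.90, 100),
    Unit("block0.ch_group1", 0.10, 100),
    Unit("block1.ch_group0", 0.40, 100),
    Unit("block1.ch_group1", 0.05, 100),
]
dropped, remaining = select_units_to_drop(units, total_params=1000, target_params=800)
print(dropped, remaining)   # ['block1.ch_group1', 'block0.ch_group1'] 800
```

If the loop runs out of units before reaching the target, the returned count stays above it, which is a signal that the chosen granularity alone cannot hit the budget.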

When it's worth it for SLMs

Rarely. The honest framing: SLMs are already small. The first compression knobs to turn are quantization (Lesson 4.6) and distillation (Lesson 4.4), both of which usually deliver larger wall-clock wins with smaller quality cost than pruning does. Pruning is the third lever, and it earns its place only when the first two have been pushed and the budget still isn't met.
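The memory side of that argument is simple arithmetic. The sketch below counts weight storage only and deliberately ignores quantization scales and zero-points, the KV cache and activations, so treat the numbers as lower bounds; weight_gb is a hypothetical helper, and the 7B and 4.5B sizes are the ones used as examples in this lesson.

```python
def weight_gb(params_billion, bits_per_weight):
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return params_billion * bits_per_weight / 8

print(weight_gb(7.0, 16))   # 14.0  original 7B in bf16
print(weight_gb(7.0, 4))    # 3.5   same model, 4-bit weights
print(weight_gb(4.5, 16))   # 9.0   pruned to 4.5B, still bf16
print(weight_gb(4.5, 4))    # 2.25  pruned and quantized
```

Quantizing the un-pruned model already saves more memory than pruning alone, with no recovery training; pruning adds a further reduction only on top of that.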

Honest beat — pruning without recovery training tanks quality fast

Aggressive structured pruning (say, dropping 25–40% of heads, layers, or channels) typically costs 10–30 points on standard benchmarks (MMLU, HellaSwag, the usual suspects) before the recovery pass. The recovery pass is what makes the technique competitive; it is also where most of the compute and data cost lives. If you cannot afford the recovery training (the GPU hours, the curated dataset, the eval loop to confirm the rebound), do not prune. Ship the quantized un-pruned model and revisit when you have the budget. A correctly quantized 7B will usually beat a badly recovered pruned 4.5B.

Tooling pointers

You will almost never write structured pruning from scratch. The bookkeeping (re-shaping every dependent tensor consistently, re-registering the model's config, keeping the tokenizer and chat template valid) is painful and error-prone. Start from a published recipe, such as the Sheared-LLaMA one mentioned above, rather than rolling your own.

Key idea

Structured pruning is a three-step bargain: shrink the matrices (heads / layers / channels), pay a quality tax, then pay a recovery training bill to claw most of it back. For SLMs it is the third compression lever, not the first — reach for it only after quantization (Lesson 4.6) and distillation (Lesson 4.4) have been pushed and you still need to be smaller or faster. If 4-bit quantization already hits your latency budget, the right move is to skip pruning entirely.

Key terms

Structured pruning
Removing whole tensor-aligned units — heads, layers, or channels — so the model's parameter matrices physically shrink and matmuls become smaller.
Unstructured pruning
Zeroing out individual weights inside a matrix; only translates to wall-clock speed-ups with sparse-aware kernels, which are not standard on commodity GPUs.
Head pruning
Removing whole attention heads from each block; the Q/K/V/O projections shrink consistently and the per-block attention compute drops.
Layer pruning
Deleting whole transformer blocks; coarsest cut, easiest to implement, biggest absolute speed-up per unit pruned.
Channel / MLP-width pruning
Shrinking the intermediate dimension of the MLP block; up_proj and gate_proj lose output rows and down_proj loses input columns.
Recovery training
A short SFT or teacher-distilled pass run after pruning to claw back most of the quality lost in the prune step; not optional.
Importance scoring
Per-unit score (magnitude, gradient × weight, or Hessian-based) used to decide which heads / layers / channels to drop.
Calibration set (pruning sense)
A small dataset — typically 256–2,048 examples drawn from your real deployment distribution — used to compute importance scores under realistic input statistics.
Sparse matmul kernels
Specialised matmul implementations that skip zero weights for real wall-clock gain; rare in the commodity-GPU stack, which is why unstructured pruning usually fails to translate FLOP savings into latency savings.
