Structured pruning: removing heads, layers, and channels
After this lesson you can decide whether structured pruning is worth attempting on your SLM, pick the right granularity (head / layer / channel), score importance on a calibration set, run the prune-then-recover workflow, and recognise when 4-bit quantization has already eaten the budget pruning would have given you.
Structured pruning is the version of pruning that actually shows up in wall-clock benchmarks. You pick a granularity — a whole attention head, a whole transformer layer, or a whole MLP channel — score every unit at that granularity for importance on a calibration set, drop the lowest-scoring ones until you hit a target size, then run a short SFT or distillation pass to recover the quality you just damaged. Done well, you get a model with smaller matrices, fewer FLOPs, lower memory, and a believable latency win. Done badly — especially without the recovery pass — you get a smaller model that's also dramatically worse, and you'll wish you'd just shipped the quantized original.
What structured pruning actually does
"Structured" means the unit you remove is a whole tensor-aligned block, not an individual scalar weight. There are three common granularities, listed coarsest to finest:
- Layer pruning. Drop entire transformer blocks. The model goes from N layers to N − k. This is the coarsest cut, with the biggest absolute speed-up per unit pruned, and it is the easiest to implement: you literally delete blocks from model.model.layers. ShortGPT and similar layer-importance studies show that, in many trained LLMs, several middle layers can be removed with surprisingly little quality loss before recovery.
- Head pruning. Remove whole attention heads inside each block. With H heads of dimension d_head, the attention width shrinks from H · d_head to (H − k) · d_head. q_proj, k_proj, and v_proj lose the output rows belonging to the removed heads, and o_proj loses the matching input columns. (In grouped-query attention models, each K/V head is shared by several query heads, so heads are removed group by group.) Real wins, finer-grained than layer pruning.
- Channel / MLP-width pruning. Inside each MLP block, the intermediate dimension d_ff shrinks: up_proj and gate_proj lose output rows, and down_proj loses the matching input columns. Because MLPs hold the bulk of an SLM's parameters (often 60–70% in modern Llama-style architectures), MLP-width pruning is where most of the parameter savings come from.
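To see why MLP width is where the parameters are, here is a back-of-envelope parameter count for a Llama-style decoder. It is a sketch under stated assumptions: Llama-7B-like dimensions (32 layers, d_model 4096, 32 heads of 128, d_ff 11008, 32k vocabulary), full multi-head attention rather than grouped-query attention, bias-free projections, untied input embedding and LM head, and norm weights ignored. The cut sizes in `variants` are arbitrary illustrations, not recommendations.

```python
# Back-of-envelope parameter counts for a Llama-style decoder.
# Assumptions: full MHA (no GQA), bias-free linears, untied embeddings, norms ignored.

def attn_params(d_model: int, n_heads: int, d_head: int) -> int:
    # q_proj, k_proj, v_proj: d_model -> n_heads * d_head; o_proj: n_heads * d_head -> d_model
    return 4 * d_model * n_heads * d_head

def mlp_params(d_model: int, d_ff: int) -> int:
    # gate_proj and up_proj: d_model -> d_ff; down_proj: d_ff -> d_model
    return 3 * d_model * d_ff

def model_params(n_layers: int, d_model: int, n_heads: int, d_head: int, d_ff: int, vocab: int) -> int:
    block = attn_params(d_model, n_heads, d_head) + mlp_params(d_model, d_ff)
    embeddings = 2 * vocab * d_model  # input embedding + LM head (untied)
    return n_layers * block + embeddings

base = dict(n_layers=32, d_model=4096, n_heads=32, d_head=128, d_ff=11008, vocab=32000)
total = model_params(**base)
mlp_share = base["n_layers"] * mlp_params(base["d_model"], base["d_ff"]) / total
print(f"MLP share of parameters: {mlp_share:.0%}")  # 64%

# One cut per granularity, each applied on its own.
variants = {
    "baseline": base,
    "layers 32 -> 24": {**base, "n_layers": 24},
    "heads 32 -> 24": {**base, "n_heads": 24},
    "d_ff 11008 -> 8192": {**base, "d_ff": 8192},
}
for name, cfg in variants.items():
    print(f"{name:>20}: {model_params(**cfg) / 1e9:.2f}B params")
# baseline 6.74B, layers 5.12B, heads 6.20B, d_ff 5.63B
```

Per unit removed, a layer is the biggest cut, but the MLP is the biggest pool: it holds about two thirds of the parameters here.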
All three cases end with smaller matrices. nn.Linear(4096, 11008) becomes nn.Linear(4096, 8192). PyTorch sees a smaller tensor; cuBLAS does a smaller matmul; the inference engine runs faster. This is the win that matters.
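The MLP-width case can be sketched in plain Python to make the bookkeeping concrete. Weights are stored as lists of rows in the (out_features, in_features) layout that nn.Linear.weight uses, the MLP is the Llama-style SwiGLU form down(silu(gate x) · up x), and the `keep` indices are hypothetical stand-ins for whatever survived importance scoring. The check at the end confirms that slicing gate/up rows and down columns gives the same output as the original MLP with those channels masked out, only with smaller matrices.

```python
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]  # row-major, shaped (out_features, in_features) like nn.Linear.weight

def keep_rows(w: Matrix, idx: Sequence[int]) -> Matrix:
    return [list(w[i]) for i in idx]

def keep_cols(w: Matrix, idx: Sequence[int]) -> Matrix:
    return [[row[j] for j in idx] for row in w]

def prune_mlp_width(gate: Matrix, up: Matrix, down: Matrix, keep: Sequence[int]):
    """Keep only the intermediate channels in `keep`.
    gate/up are (d_ff, d_model): drop output rows. down is (d_model, d_ff): drop input columns."""
    keep = sorted(keep)
    return keep_rows(gate, keep), keep_rows(up, keep), keep_cols(down, keep)

def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))

def swiglu_mlp(gate: Matrix, up: Matrix, down: Matrix, x: Sequence[float]) -> List[float]:
    # Llama-style MLP: down(silu(gate x) * up x)
    h = [silu(g) * u for g, u in zip(matvec(gate, x), matvec(up, x))]
    return matvec(down, h)

random.seed(0)
d_model, d_ff = 4, 6

def rand(r: int, c: int) -> Matrix:
    return [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]

gate, up, down = rand(d_ff, d_model), rand(d_ff, d_model), rand(d_model, d_ff)
keep = [0, 2, 3, 5]  # hypothetical: channels that survived importance scoring
g2, u2, d2 = prune_mlp_width(gate, up, down, keep)
x = [0.5, -1.0, 0.25, 2.0]

# Reference: the unpruned MLP with the dropped channels' down_proj columns zeroed.
masked_down = [[w if j in keep else 0.0 for j, w in enumerate(row)] for row in down]
y_pruned = swiglu_mlp(g2, u2, d2, x)
y_masked = swiglu_mlp(gate, up, masked_down, x)
print(len(g2), len(g2[0]), len(d2), len(d2[0]))  # 4 4 4 4
```

The masked model does the same function with 6-wide intermediate matrices; the pruned one does it with 4-wide ones, which is where the smaller matmul comes from.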
Why unstructured pruning rarely pays off in practice
Unstructured pruning zeros out individual weights inside a matrix, leaving the matrix the same shape. On paper a 70% sparse weight matrix needs 70% fewer multiplies. In practice, on commodity GPUs and on Apple Silicon, your matmul is still dispatched to a dense kernel (cuBLAS, cuDNN, MPS), which happily multiplies by zero billions of times and reports essentially the same wall-clock latency as the un-pruned version. Sparse-aware kernels exist (NVIDIA's 2:4 structured sparsity on Ampere and later, some research kernels for higher sparsity), but coverage across the stack you actually use, from your inference server to your quantization library, is patchy at best.
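A toy way to see the point: a dense kernel's work is set by the matrix shape, not by how many entries happen to be zero. The naive matvec below counts multiply-adds; real BLAS kernels are tiled and vectorised, but the same principle holds. The matrix sizes and the 70% sparsity level are arbitrary.

```python
import random
from typing import List, Sequence, Tuple

def dense_matvec_with_count(w: List[List[float]], x: Sequence[float]) -> Tuple[List[float], int]:
    """Naive dense matvec. Like a dense BLAS kernel, it never checks for zeros,
    so the multiply-add count depends only on the matrix shape."""
    out, macs = [], 0
    for row in w:
        acc = 0.0
        for wij, xj in zip(row, x):
            acc += wij * xj
            macs += 1
        out.append(acc)
    return out, macs

random.seed(0)
rows, cols = 8, 16
w = [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]
x = [1.0] * cols

# Unstructured: zero roughly 70% of individual weights; the shape is unchanged.
w_unstructured = [[v if random.random() > 0.7 else 0.0 for v in row] for row in w]
# Structured: physically remove half of the output rows (e.g. whole channels).
w_structured = w[: rows // 2]

_, macs_dense = dense_matvec_with_count(w, x)
_, macs_unstructured = dense_matvec_with_count(w_unstructured, x)
_, macs_structured = dense_matvec_with_count(w_structured, x)
print(macs_dense, macs_unstructured, macs_structured)  # 128 128 64
```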
For an SLM you intend to ship on a real device, treat unstructured pruning as a research technique and structured pruning as the production one. The rest of this lesson is about the production version.
Importance scoring: how do you decide what to cut?
Once you've picked a granularity, you need a score per unit. Three families dominate, in order of cost and accuracy:
- Magnitude. Score a head, layer, or channel by the L2 (or L1) norm of its weights. Cheap, requires no data, and often misleading. Large-magnitude weights can sit in dead pathways that the rest of the network has learned to route around; small-magnitude weights can be the load-bearing ones for a specific behaviour. Magnitude is a reasonable first pass for a baseline; it is not a reliable production criterion.
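For reference, a magnitude baseline for head pruning takes only a few lines. This sketch assumes standard multi-head attention, with head h occupying rows h·d_head to (h+1)·d_head − 1 of q_proj, k_proj, and v_proj and the matching columns of o_proj; it scores each head by the L2 norm of all four slices together. The tiny hand-built weights are illustrative only, and, as the text warns, a score like this is a baseline rather than a production criterion.

```python
import math
from typing import List

Matrix = List[List[float]]  # (out_features, in_features), like nn.Linear.weight

def head_magnitude_scores(q: Matrix, k: Matrix, v: Matrix, o: Matrix,
                          n_heads: int, d_head: int) -> List[float]:
    """L2 norm of each head's slice of q/k/v (output rows) and o_proj (input columns)."""
    scores = []
    for h in range(n_heads):
        idx = range(h * d_head, (h + 1) * d_head)
        sq = sum(val * val for w in (q, k, v) for r in idx for val in w[r])
        sq += sum(row[c] * row[c] for row in o for c in idx)
        scores.append(math.sqrt(sq))
    return scores

d_model, n_heads, d_head = 4, 2, 2

def proj(per_head_value: List[float]) -> Matrix:
    # q/k/v weight of shape (n_heads * d_head, d_model), constant within each head's rows
    return [[per_head_value[r // d_head]] * d_model for r in range(n_heads * d_head)]

q = k = v = proj([1.0, 0.1])  # head 0 large, head 1 small (read-only, so aliasing is fine)
o = [[1.0 if c // d_head == 0 else 0.1 for c in range(n_heads * d_head)] for _ in range(d_model)]
scores = head_magnitude_scores(q, k, v, o, n_heads, d_head)
print(scores)  # head 0 ~5.657, head 1 ~0.566: head 1 would be pruned first
```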
- Gradient × weight (Taylor / SNIP-style). Run a calibration set through the model and compute, per unit, the first-order Taylor estimate of the loss change if that unit were set to zero, usually |w · ∂L/∂w| aggregated over the unit. Requires a few hundred to a few thousand examples representative of the deployment distribution. LLM-Pruner uses a variant of this for head and channel scoring. It is the practical sweet spot: significantly better signal than magnitude, far cheaper than second-order methods, and tractable on a single GPU.
- Hessian-based / OBD / WoodFisher. Use second-order curvature information (Optimal Brain Damage / Optimal Brain Surgeon style, or the WoodFisher Fisher-matrix approximation). Most accurate signal per unit, considerably more expensive to compute, and finicky to implement correctly. Worth it when you're squeezing the last few percent of quality out of a heavily-pruned model.
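The first-order estimate follows from a Taylor expansion: zeroing a unit moves its weights by −w, so ΔL ≈ −Σ_{i∈unit} w_i · ∂L/∂w_i. The sketch below uses a toy linear model with an analytic gradient as a stand-in for `loss.backward()` on a real network, and accumulates |w_i · ∂L/∂w_i| per weight and per calibration example. Summing absolute per-example terms is one common aggregation choice (an assumption here, not necessarily LLM-Pruner's exact formula); it stops gradients that cancel across examples at a well-trained minimum from hiding important units. The weights, units, and calibration data are hypothetical.

```python
from typing import List, Sequence, Tuple

Example = Tuple[List[float], float]  # (features, target)

def per_example_grad(w: Sequence[float], x: Sequence[float], y: float) -> List[float]:
    # Toy stand-in for loss.backward(): gradient of 0.5 * (w·x - y)^2 with respect to w
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [err * xi for xi in x]

def taylor_scores(w: Sequence[float], calib: List[Example], units: List[List[int]]) -> List[float]:
    """Per-unit importance: sum over examples and over the unit's weights of |w_i * dL/dw_i|."""
    scores = [0.0] * len(units)
    for x, y in calib:
        g = per_example_grad(w, x, y)
        for u, idx in enumerate(units):
            scores[u] += sum(abs(w[i] * g[i]) for i in idx)
    return scores

w = [1.0, 1.0, 0.5, 0.5]
units = [[0, 1], [2, 3]]  # hypothetical: two "heads" of two weights each
calib = [([1.0, 1.0, 0.0, 0.0], 0.0), ([0.0, 0.0, 1.0, 1.0], 0.0)]
scores = taylor_scores(w, calib, units)
prune_order = sorted(range(len(units)), key=lambda u: scores[u])  # least important first
print(scores, prune_order)  # [4.0, 1.0] [1, 0]
```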
For most SLM pruning work, gradient × weight on a calibration set of 256–2,048 examples drawn from your real distribution is the right starting point. The calibration set matters more than the scoring formula — score on data that looks like deployment and you'll keep the units that matter for deployment.
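One way to make "drawn from your real distribution" concrete is to sample stratified by task, so the calibration mix matches the traffic mix. The helper below is hypothetical: it assumes you have deployment-like records tagged with a task label, and it allocates the sample proportionally per tag with a fixed seed so the scores are reproducible. Per-tag rounding means the total can be off by a few examples; n would normally sit in the 256–2,048 range, and the tiny n here just keeps the example readable.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

def sample_calibration(records: List[Tuple[str, str]], n: int, seed: int = 0) -> List[str]:
    """Hypothetical helper: proportional stratified sample of (task_tag, text) records."""
    by_tag: Dict[str, List[str]] = defaultdict(list)
    for tag, text in records:
        by_tag[tag].append(text)
    rng = random.Random(seed)  # fixed seed: same calibration set on every run
    sample: List[str] = []
    for tag in sorted(by_tag):
        texts = by_tag[tag]
        k = min(len(texts), round(n * len(texts) / len(records)))  # share of traffic
        sample.extend(rng.sample(texts, k))
    rng.shuffle(sample)
    return sample

# Toy "traffic": 75% chat, 25% code.
records = [("chat", f"chat-{i}") for i in range(300)] + [("code", f"code-{i}") for i in range(100)]
calib = sample_calibration(records, n=8)
print(sum(t.startswith("chat") for t in calib), sum(t.startswith("code") for t in calib))  # 6 2
```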
The prune-to-target, then recover workflow
The workflow has four steps:
- Pick a target. Either a parameter count (e.g. 7B → 4.5B), a latency budget (e.g. p50 first-token under 200 ms on your target hardware), or a memory ceiling. Concrete targets — not "make it smaller" — keep the rest of the process honest.
- Score importance on the calibration set at the chosen granularity.
- Drop the lowest-scoring heads / layers / channels until the target is hit. Re-shape the parameter tensors so the model actually has fewer parameters; this is the step where the matmuls get smaller.
- Recover. Run a short SFT pass (or, better, a teacher-distilled pass using the pre-pruning model as the teacher — the offline KD recipe in Lesson 4.4 plugs in directly) on a small, high-quality dataset. The model regains most of the quality the pruning step cost. Skip this step and you're shipping a damaged model.
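Step 3 is a greedy loop: sort units by importance and drop from the bottom until the parameter target is met. The sketch below assumes scores are comparable across the units you pass in (same granularity, or normalised). It uses Llama-7B-like sizes as standalone constants (202,375,168 parameters per block with full multi-head attention, 262,144,000 for untied embeddings plus LM head, norms ignored), and the layer scores are hypothetical, chosen so that middle layers look least important.

```python
from typing import List, Tuple

Unit = Tuple[str, float, int]  # (name, importance score, parameters removed if dropped)

def select_units_to_drop(units: List[Unit], total_params: int, target_params: int) -> Tuple[List[str], int]:
    """Drop the lowest-scoring units until the model is at or below target_params."""
    dropped, remaining = [], total_params
    for name, _score, cost in sorted(units, key=lambda u: u[1]):  # least important first
        if remaining <= target_params:
            break
        dropped.append(name)
        remaining -= cost
    return dropped, remaining

BLOCK = 202_375_168   # params per decoder block (assumed Llama-7B-like dims, full MHA)
EMBED = 262_144_000   # untied input embedding + LM head
total = 32 * BLOCK + EMBED  # 6,738,149,376

# Hypothetical layer-importance scores: distance from the middle of the stack.
layers = [(f"layer{i}", abs(i - 15.5), BLOCK) for i in range(32)]
dropped, remaining = select_units_to_drop(layers, total, target_params=4_500_000_000)
print(len(dropped), f"{remaining / 1e9:.2f}B")  # 12 4.31B
```

With these numbers, reaching 4.5B by layer dropping alone takes 12 of 32 blocks, a steep cut; Sheared-LLaMA, for example, prunes several granularities together. Note that the greedy loop can overshoot the target by up to one unit's worth of parameters.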
```python
# Recovery pass after structured pruning: short SFT with a small LR
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_from_disk
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("checkpoints/pruned-4_5B")
tok = AutoTokenizer.from_pretrained("checkpoints/pruned-4_5B")
if tok.pad_token is None:  # Llama-style tokenizers often ship without a pad token
    tok.pad_token = tok.eos_token
ds = load_from_disk("data/recovery_sft")  # ~5–50k high-quality rows ("text" or "messages" column)

trainer = SFTTrainer(
    model=model, processing_class=tok, train_dataset=ds,
    args=SFTConfig(output_dir="checkpoints/pruned-4_5B-recovered",
                   per_device_train_batch_size=4, gradient_accumulation_steps=8,
                   learning_rate=1e-5, num_train_epochs=1, lr_scheduler_type="cosine",
                   warmup_ratio=0.03, bf16=True, report_to=[],
                   max_length=2048),  # called max_seq_length in older TRL releases
)
trainer.train()
trainer.save_model("checkpoints/pruned-4_5B-recovered")
```
Two notes on the recovery pass. First, the learning rate is small: you're nudging the surviving weights back into a good basin, not retraining from scratch, and 1e-5 to 5e-5 is the usual band. Second, for a task-specific SFT recovery the dataset can be small, but it must be clean. Recovery budgets vary enormously: Sheared-LLaMA's published recipe uses on the order of tens of billions of tokens of continued training to recover its 2.7B and 1.3B models, while an SFT recovery on a narrower task can be much smaller. The floor is "enough to see the loss curve flatten and the eval metrics rebound."
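The "loss curve flattens" floor can be checked mechanically. The helper below is hypothetical, with an arbitrary window and tolerance: it compares the mean training loss over the last `window` steps with the window before it and reports a plateau once the relative improvement drops below `rel_tol`. It says nothing about the eval rebound, which still needs its own before/after comparison.

```python
from typing import Sequence

def has_flattened(losses: Sequence[float], window: int = 50, rel_tol: float = 0.01) -> bool:
    """True when the mean loss over the last `window` steps improved by less than
    `rel_tol` (relative) compared with the preceding `window` steps."""
    if len(losses) < 2 * window:
        return False  # not enough history to judge
    prev = sum(losses[-2 * window:-window]) / window
    last = sum(losses[-window:]) / window
    return (prev - last) / prev < rel_tol

still_falling = [2.0 - 0.01 * i for i in range(100)]  # toy loss curve, still dropping
plateaued = still_falling + [1.0] * 100               # same curve, then flat
print(has_flattened(still_falling), has_flattened(plateaued))  # False True
```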
When it's worth it for SLMs
Rarely. The honest framing: SLMs are already small. The first compression knobs to turn are quantization (Lesson 4.6) and LoRA distillation (Lesson 4.4), both of which usually deliver larger wall-clock wins with smaller quality cost than pruning does. Pruning is the third lever, and it earns its place only when the first two have been pushed and the budget still isn't met.
- 7B → 4.5B with structured pruning + recovery is a believable target, and a milder cut than the 7B → 2.7B / 1.3B prunes Sheared-LLaMA published. You'll lose a measurable amount of quality even with a careful recovery pass (usually a few points on aggregate benchmarks), but the speed and memory wins are real.
- Pruning a 135M model is almost never worth it. The remaining capacity is too small to absorb the damage; the recovery pass struggles; the wall-clock saving is a millisecond no one will notice. Quantize it to 4-bit and move on.
- If 4-bit quantization already meets your latency budget, skip pruning entirely. The compounded engineering cost — score, prune, recover, re-evaluate, re-quantize — is rarely justified by the marginal gains over a clean quantized model.
Honest beat — pruning without recovery training tanks quality fast
Aggressive structured pruning (say, dropping 25–40% of heads, layers, or channels) typically costs 10–30 points on standard benchmarks (MMLU, HellaSwag, the usual suspects) before the recovery pass. The recovery pass is what makes the technique competitive; it is also where most of the compute and data cost lives. If you cannot afford the recovery training (the GPU hours, the curated dataset, the eval loop to confirm the rebound), do not prune. Ship the quantized un-pruned model and revisit when you have the budget. A correctly quantized 7B will beat a badly recovered pruned 4.5B every time.
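The decision in this section and the one above reduces to a small rule: among candidates that meet the latency budget, ship the one with the best eval score on your own suite. The sketch below encodes that rule; the candidate names, latencies, and scores are made-up placeholders, not measurements.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Candidate:
    name: str
    p50_latency_ms: float  # measured on the target hardware
    eval_score: float      # same eval suite for every candidate

def pick_candidate(candidates: List[Candidate], latency_budget_ms: float) -> Optional[Candidate]:
    """Best-scoring candidate that meets the latency budget, or None if none does."""
    feasible = [c for c in candidates if c.p50_latency_ms <= latency_budget_ms]
    return max(feasible, key=lambda c: c.eval_score) if feasible else None

# Placeholder numbers for illustration only.
candidates = [
    Candidate("7B-4bit", p50_latency_ms=180.0, eval_score=62.0),
    Candidate("4.5B-pruned-recovered-4bit", p50_latency_ms=130.0, eval_score=59.5),
    Candidate("4.5B-pruned-no-recovery-4bit", p50_latency_ms=125.0, eval_score=45.0),
]
print(pick_candidate(candidates, 200.0).name)  # 7B-4bit: the quantized original fits, so skip pruning
print(pick_candidate(candidates, 150.0).name)  # 4.5B-pruned-recovered-4bit
```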
Tooling pointers
You will almost never write structured pruning from scratch — the bookkeeping (re-shaping every dependent tensor consistently, re-registering the model's config, keeping the tokenizer and chat template valid) is painful and error-prone. Use a published recipe:
- Sheared-LLaMA (Xia et al., 2023) — the canonical recipe for structured pruning a 7B Llama down to 2.7B / 1.3B with a multi-stage recovery. Published code, dataset mix, and learning-rate schedule. Start here if you want a working baseline.
- ShortGPT (Men et al., 2024) — layer-importance scoring with a "Block Influence" metric. The cleanest read on how few layers some models actually need.
- LLM-Pruner (Ma et al., 2023) — gradient × weight scoring with a CLI for head and channel pruning across common architectures. Good for SLM-scale experiments.
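- torch.nn.utils.prune — PyTorch's built-in pruning utilities. Useful for the mask-and-remove mechanics: prune.ln_structured can zero whole rows or columns by norm, and prune.remove makes the mask permanent. The tensors keep their original shape, though, so you still slice the weights into smaller matrices yourself, and you still bring the importance scoring and the recovery pass.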
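ShortGPT's Block Influence score from the list above is simple enough to sketch: for each block, one minus the average cosine similarity between the hidden state entering the block and the hidden state leaving it, averaged over calibration tokens. Blocks that barely change their input score near 0 and are the first candidates for removal. This pure-Python version assumes you already have per-token hidden states; in Hugging Face models, `output_hidden_states=True` returns them per layer, though you should check how your architecture treats the last entry (a final norm may already be applied). The toy vectors are illustrative.

```python
import math
from typing import List, Sequence

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def block_influence(h_in: List[List[float]], h_out: List[List[float]]) -> float:
    """1 - mean cosine similarity between each token's hidden state before and after a block."""
    sims = [cosine(a, b) for a, b in zip(h_in, h_out)]
    return 1.0 - sum(sims) / len(sims)

# A block that only rescales its input: direction unchanged, influence ~0.
h_in = [[1.0, 2.0, 3.0], [-1.0, 0.5, 2.0]]
near_identity = block_influence(h_in, [[1.1 * v for v in row] for row in h_in])
# A block that rotates each token's state by 90 degrees: influence 1.
rotating = block_influence([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [-1.0, 0.0]])
print(f"{near_identity:.6f} {rotating:.6f}")  # 0.000000 1.000000
```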
Key idea
Structured pruning is a three-step bargain: shrink the matrices (heads / layers / channels), pay a quality tax, then pay a recovery training bill to claw most of it back. For SLMs it is the third compression lever, not the first — reach for it only after quantization (Lesson 4.6) and distillation (Lesson 4.4) have been pushed and you still need to be smaller or faster. If 4-bit quantization already hits your latency budget, the right move is to skip pruning entirely.
Key terms
- Structured pruning
- Removing whole tensor-aligned units — heads, layers, or channels — so the model's parameter matrices physically shrink and matmuls become smaller.
- Unstructured pruning
- Zeroing out individual weights inside a matrix; only translates to wall-clock speed-ups with sparse-aware kernels, which are not standard on commodity GPUs.
- Head pruning
- Removing whole attention heads from each block; the Q/K/V/O projections shrink consistently and the per-block attention compute drops.
- Layer pruning
- Deleting whole transformer blocks; coarsest cut, easiest to implement, biggest absolute speed-up per unit pruned.
- Channel / MLP-width pruning
- Shrinking the intermediate dimension of the MLP block; up_proj and gate_proj lose output rows and down_proj loses input columns.
- Recovery training
- A short SFT or teacher-distilled pass run after pruning to claw back most of the quality lost in the prune step; not optional.
- Importance scoring
- Per-unit score (magnitude, gradient × weight, or Hessian-based) used to decide which heads / layers / channels to drop.
- Calibration set (pruning sense)
- A small dataset — typically 256–2,048 examples drawn from your real deployment distribution — used to compute importance scores under realistic input statistics.
- Sparse matmul kernels
- Specialised matmul implementations that skip zero weights for real wall-clock gain; rare in the commodity-GPU stack, which is why unstructured pruning usually fails to translate FLOP savings into latency savings.