Is your feature request related to a problem? Please describe.
Current pruning algorithms in Pruna (e.g. torch_structured, torch_unstructured) focus on intra-layer compression — removing neurons, channels, or heads.
I came across a paper, *ShortGPT: Layers in Large Language Models are More Redundant Than You Expect* (Baichuan Inc.), which shows that many transformer layers contribute minimally to the model's overall function, especially in pre-norm architectures like LLaMA and Baichuan. Removing whole layers is an inter-layer form of compression.
Without layer-level pruning support, users can’t exploit this redundancy to reduce depth and latency while keeping accuracy.
Describe the solution you'd like
Add a new pruning algorithm: shortgpt_layer_pruner, implementing the ShortGPT approach for layer-level structured pruning based on Block Influence (BI).
Overview
ShortGPT identifies and removes redundant transformer layers by measuring how much each layer actually transforms its input hidden states.
Layers whose inputs and outputs are nearly identical are considered low-impact and can be safely removed without significant performance loss.
Core Algorithm
- Run a calibration forward pass on a small unlabeled dataset (e.g. PG19).
- Collect the hidden states entering and leaving each transformer block: h_0, h_1, …, h_L.
- Compute Block Influence (BI) for each layer i (a rough sketch follows this list):
  BI_i = 1 - mean(cosine_similarity(h_i, h_{i+1}))
  (Optionally support an angular-distance variant.)
- Rank layers by BI (lower = more redundant).
- Remove the lowest-BI layers according to a configured ratio or count.
- (Optional) Apply lightweight post-pruning fine-tuning, or replace removed layers with small MLP adapters, to recover minor performance loss.
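A minimal sketch of the scoring and removal steps, assuming a LLaMA-style Hugging Face model whose decoder blocks live in `model.model.layers`; the helper names (`compute_block_influence`, `prune_lowest_bi_layers`) and the calibration dataloader are illustrative, not existing Pruna API:

```python
# Sketch only: assumes a LLaMA-style HF causal LM and batches already on the
# model's device. Helper names are hypothetical, not Pruna API.
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def compute_block_influence(model, dataloader, num_batches=64):
    model.eval()
    num_layers = model.config.num_hidden_layers
    bi_sums = torch.zeros(num_layers)
    seen = 0
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        out = model(**batch, output_hidden_states=True)
        # hidden_states has num_layers + 1 entries: embeddings + one per block
        hs = out.hidden_states
        for layer in range(num_layers):
            h_in, h_out = hs[layer], hs[layer + 1]
            cos = F.cosine_similarity(h_in, h_out, dim=-1)  # (batch, seq_len)
            bi_sums[layer] += (1.0 - cos).mean().item()
        seen += 1
    return bi_sums / max(seen, 1)  # BI_i per layer; lower = more redundant

def prune_lowest_bi_layers(model, bi_scores, num_layers_to_prune=8):
    # Keep every block except the num_layers_to_prune lowest-BI ones,
    # preserving the original layer order.
    keep = sorted(bi_scores.argsort()[num_layers_to_prune:].tolist())
    model.model.layers = nn.ModuleList([model.model.layers[i] for i in keep])
    model.config.num_hidden_layers = len(keep)
    return model
```

The angular-distance variant would replace the `1.0 - cos` term with `torch.arccos(cos.clamp(-1, 1)) / math.pi`.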
Integration with Pruna
- Implement as a subclass of `PrunaPruner`, e.g. `class ShortGPTPruner(PrunaPruner): algorithm_name = "shortgpt" ...` (a rough skeleton is sketched after this list).
- Register under the `PRUNER` algorithm group.
- Compatible with quantizers like `torchao`, `half`, and `hqq` (since ShortGPT is orthogonal to quantization).
- Expected hyperparameters in `get_hyperparameters()`:
  - `num_layers_to_prune`: integer (default: 8)
  - `angular_metric`: boolean (default: False)
  - `calibration_samples`: integer (default: 64)
  - `pruning_ratio`: float (default: 0.25)
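For concreteness, a rough skeleton of how this could slot into the pruner interface described above; the base class hooks, import path, and the exact shape of `get_hyperparameters()` are assumptions here, not the actual Pruna signatures:

```python
# Illustrative skeleton only. `PrunaPruner` is the base class named in this
# issue; its real hooks and hyperparameter mechanism are assumptions.
# `compute_block_influence` / `prune_lowest_bi_layers` are the helpers from
# the sketch above.
class ShortGPTPruner(PrunaPruner):
    algorithm_name = "shortgpt"

    def get_hyperparameters(self):
        # Defaults mirror the values proposed in this issue.
        return {
            "num_layers_to_prune": 8,   # how many blocks to drop
            "angular_metric": False,    # angular distance instead of 1 - cos
            "calibration_samples": 64,  # forward-pass batches for BI scoring
            "pruning_ratio": 0.25,      # alternative to an absolute layer count
        }

    def apply(self, model, calibration_dataloader, **hp):
        # Score every block on calibration data, then drop the least influential.
        bi_scores = compute_block_influence(
            model, calibration_dataloader, num_batches=hp["calibration_samples"]
        )
        return prune_lowest_bi_layers(
            model, bi_scores, num_layers_to_prune=hp["num_layers_to_prune"]
        )
```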
Expected Benefits
- ~25% parameter and compute reduction with minimal (<10%) accuracy drop.
- No gradient or label requirements — runs purely from forward activations.
- Simpler than gradient-based or dependency-graph pruning.
- Extends easily to non-transformer models (RWKV, Mamba).
- Orthogonal to quantization, enabling compound compression (layer removal + 4-bit GPTQ).
Additional context
Paper link: https://arxiv.org/abs/2403.03853
@sharpenb Let me know your thoughts on this. If you think this could be a useful addition, assign me and I can send in my PR.