
[FEATURE] ShortGPT pruning algorithm #418

@sky-2002

Description


Is your feature request related to a problem? Please describe.

Current pruning algorithms in Pruna (e.g. torch_structured, torch_unstructured) focus on intra-layer compression — removing neurons, channels, or heads.

I came across a paper (ShortGPT: Layers in Large Language Models are More Redundant Than You Expect, from Baichuan Inc.) which shows that many transformer layers contribute minimally to the model's overall function, especially in pre-norm architectures like LLaMA or Baichuan. Removing whole layers is a form of inter-layer compression.

Without layer-level pruning support, users can’t exploit this redundancy to reduce depth and latency while keeping accuracy.

Describe the solution you'd like

Add a new pruning algorithm: shortgpt_layer_pruner, implementing the ShortGPT approach for layer-level structured pruning based on Block Influence (BI).

Overview

ShortGPT identifies and removes redundant transformer layers by measuring how much each layer actually transforms its input hidden states.
Layers whose inputs and outputs are nearly identical are considered low-impact and can be safely removed without significant performance loss.

Core Algorithm

  1. Run a calibration forward pass on a small unlabeled dataset (e.g. PG19).

  2. Collect hidden states for each transformer block: h₀, h₁, …, h_L.

  3. Compute Block Influence (BI) for each layer i as one minus the cosine similarity between the layer's input and output hidden states, averaged over tokens and calibration samples:

    BI_i = 1 - mean(cosine_similarity(h_i, h_{i+1}))

    (Optionally support the angular-distance variant. A Python sketch of steps 1–4 follows this list.)

  4. Rank layers by BI (lower = more redundant).

  5. Remove lowest-BI layers according to a configured ratio or count.

  6. (Optional) Apply lightweight post-pruning fine-tuning, or replace the removed layers with small MLP adapters, to recover the minor performance loss.
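
Below is a minimal sketch of steps 1–4 (including the optional angular variant), assuming a Hugging Face causal LM that exposes per-layer hidden states via output_hidden_states=True. The function name compute_block_influence and its signature are illustrative, not existing Pruna API.

    import torch

    @torch.no_grad()
    def compute_block_influence(model, tokenizer, texts, angular=False, device="cuda"):
        """Return a (num_layers,) tensor of BI scores; lower = more redundant."""
        model.eval().to(device)
        num_layers = model.config.num_hidden_layers
        bi_sum = torch.zeros(num_layers)
        token_count = 0

        for text in texts:  # small unlabeled calibration set, e.g. PG19 excerpts
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
            # hidden_states is a tuple of (L+1) tensors of shape (1, seq, dim)
            hidden = model(**inputs, output_hidden_states=True).hidden_states

            for i in range(num_layers):
                # per-token cosine similarity between block i's input and output
                cos = torch.nn.functional.cosine_similarity(hidden[i], hidden[i + 1], dim=-1)
                if angular:
                    # angular-distance variant: arccos of the similarity, normalized to [0, 1]
                    sim = 1.0 - torch.arccos(cos.clamp(-1.0, 1.0)) / torch.pi
                else:
                    sim = cos
                bi_sum[i] += (1.0 - sim).sum().item()
            token_count += inputs["input_ids"].numel()

        return bi_sum / token_count

    # Rank layers: lowest BI = most redundant = pruned first
    # bi = compute_block_influence(model, tokenizer, calibration_texts)
    # prune_order = bi.argsort().tolist()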

Integration with Pruna

  • Implement as a subclass of PrunaPruner, e.g.

    class ShortGPTPruner(PrunaPruner):
        algorithm_name = "shortgpt"
        ...
  • Register under the PRUNER algorithm group.

  • Compatible with quantizers like torchao, half, and hqq (since ShortGPT is orthogonal to quantization).

  • Expected hyperparameters in get_hyperparameters() (a rough class skeleton follows this list):

    • num_layers_to_prune: integer (default: 8)
    • angular_metric: boolean (default: False)
    • calibration_samples: integer (default: 64)
    • pruning_ratio: float (default: 0.25)
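
A rough skeleton of how this could slot into Pruna is sketched below. It assumes a PrunaPruner base class exposing get_hyperparameters() and an apply-style hook, and a LLaMA-style layer container at model.model.layers; the exact base-class interface and hyperparameter objects (e.g. ConfigSpace types) would need to follow Pruna's conventions. compute_block_influence is the sketch from the Core Algorithm section.

    import torch.nn as nn

    class ShortGPTPruner:  # would subclass PrunaPruner in the actual PR
        algorithm_name = "shortgpt"

        def get_hyperparameters(self):
            # defaults mirror the list above; concrete hyperparameter objects
            # depend on Pruna's conventions (a plain dict is used here for brevity)
            return {
                "num_layers_to_prune": 8,
                "angular_metric": False,
                "calibration_samples": 64,
                "pruning_ratio": 0.25,
            }

        def apply(self, model, tokenizer, calibration_texts,
                  num_layers_to_prune=8, angular_metric=False):
            # 1. score every transformer block by Block Influence
            bi = compute_block_influence(model, tokenizer, calibration_texts, angular=angular_metric)
            # 2. drop the lowest-BI blocks (assumes a LLaMA-style nn.ModuleList at model.model.layers)
            to_remove = set(bi.argsort()[:num_layers_to_prune].tolist())
            kept = [layer for i, layer in enumerate(model.model.layers) if i not in to_remove]
            model.model.layers = nn.ModuleList(kept)
            model.config.num_hidden_layers = len(kept)
            # note: depending on the transformers version, the remaining layers'
            # self_attn.layer_idx attributes may need re-indexing for the KV cache
            return model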

Expected Benefits

  • ~25% parameter and compute reduction with a small (<10%) accuracy drop.
  • No gradient or label requirements — runs purely from forward activations.
  • Simpler than gradient-based or dependency-graph pruning.
  • Extends easily to non-transformer models (RWKV, Mamba).
  • Orthogonal to quantization, enabling compound compression (layer removal + 4-bit GPTQ); see the usage sketch below.
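
As a usage illustration, compound compression could then look roughly like this, assuming Pruna's SmashConfig/smash entry points and a "pruner" config key for the PRUNER group; "shortgpt" is the new algorithm proposed here, and the "hqq" quantizer key is an assumption about the existing registry.

    from pruna import SmashConfig, smash

    # hypothetical compound compression: layer removal + quantization
    smash_config = SmashConfig()
    smash_config["pruner"] = "shortgpt"   # proposed algorithm
    smash_config["quantizer"] = "hqq"     # existing quantizer, per the compatibility note above
    smashed_model = smash(model=model, smash_config=smash_config)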

Additional context

Paper link: https://arxiv.org/abs/2403.03853

@sharpenb Let me know your thoughts on this. If you think this could be a useful addition, assign me and I can send in my PR.

Metadata

Labels

enhancement (New feature or request)