[FEATURE] Implement Token Merging #399

@sdiazlor

Description

Based on #364 (comment)


An adaptive inference algorithm that dynamically reduces the number of tokens processed in transformer-based
models (vision, language, and multimodal) by:

  1. Pruning redundant or low-importance tokens during forward passes
  2. Merging similar tokens to reduce computational overhead
  3. Adapting the pruning strategy based on input complexity

This would be a new optimization category or could extend the existing pruner to support dynamic token-level
pruning.
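
As a rough illustration of point 1, a reduction pass might score each token by the attention it receives and keep only the top fraction. The function below is a hypothetical sketch for illustration, not existing Pruna API:

import torch

def prune_low_importance_tokens(tokens: torch.Tensor,
                                attn: torch.Tensor,
                                keep_ratio: float = 0.7) -> torch.Tensor:
    """tokens: (B, N, D) features; attn: (B, H, N, N) attention weights."""
    # Importance of each token = average attention it receives,
    # across heads and query positions.
    importance = attn.mean(dim=1).mean(dim=1)  # (B, N)
    num_keep = max(1, int(tokens.shape[1] * keep_ratio))
    # Keep the top-scoring tokens, preserving their original order.
    keep_idx = importance.topk(num_keep, dim=1).indices.sort(dim=1).values
    batch_idx = torch.arange(tokens.shape[0]).unsqueeze(-1)
    return tokens[batch_idx, keep_idx]  # (B, num_keep, D)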

How it Improves Workflows

For Vision Models (ViT, DINO, etc.):

  • Reduces FLOPs by 30-50% with <1% accuracy drop
  • Especially effective for high-resolution images where many patches are redundant (backgrounds, uniform regions)

For Language Models:

  • Speeds up long-context inference by removing padding and redundant tokens
  • Reduces KV cache size in attention mechanisms
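
For instance, once low-importance positions are identified, the KV cache can be shrunk by gathering only the kept positions. A minimal sketch, assuming the common (batch, heads, seq, dim) cache layout; the scoring of positions is left to the caller:

import torch

def prune_kv_cache(past_k, past_v, importance, keep_ratio=0.7):
    """past_k, past_v: (B, H, T, D); importance: (B, T) per-token scores."""
    num_keep = max(1, int(past_k.shape[2] * keep_ratio))
    # Keep the highest-scoring positions, preserving temporal order.
    keep_idx = importance.topk(num_keep, dim=1).indices.sort(dim=1).values
    idx = keep_idx[:, None, :, None].expand(-1, past_k.shape[1], -1, past_k.shape[3])
    return past_k.gather(2, idx), past_v.gather(2, idx)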

For Diffusion Models:

  • Can prune latent tokens in U-Net architectures during denoising steps
  • Complements caching by reducing computation per cached step

Workflow Benefits

from pruna import smash, SmashConfig

# Current workflow - quantization and compilation
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq"
smash_config["compiler"] = "torch_compile"
smashed_model = smash(model=base_model, smash_config=smash_config)

# Proposed workflow - add adaptive token optimization
smash_config["token_optimizer"] = {
    "method": "tome",  # Token Merging
    "reduction_ratio": 0.3,  # Reduce tokens by 30%
    "adaptive": True  # Adjust based on input complexity
}
smashed_model = smash(model=base_model, smash_config=smash_config)
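
The "adaptive" flag could, for example, derive the ratio from a cheap complexity proxy per input. The heuristic below (feature variance) is purely an assumption for illustration and not part of any existing Pruna API:

import torch

def adaptive_reduction_ratio(tokens: torch.Tensor,
                             base_ratio: float = 0.3,
                             max_ratio: float = 0.5) -> float:
    """Low token variance (uniform inputs) allows more aggressive reduction."""
    variance = tokens.var(dim=-1).mean().item()
    complexity = min(1.0, variance)  # crude proxy for input complexity
    return max_ratio - (max_ratio - base_ratio) * complexity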

Examples & References

Key Algorithms to Implement:

  1. ToMe (Token Merging) - https://arxiv.org/abs/2210.09461
    - Merges similar tokens in ViTs using bipartite matching (see the sketch after this list)
    - Already has an open-source implementation: https://github.com/facebookresearch/ToMe
  2. DynamicViT - https://arxiv.org/abs/2106.02034
    - Learns which tokens to keep/drop per layer
    - Requires minimal fine-tuning
  3. AdaViT - https://arxiv.org/abs/2111.15668
    - Adaptive token selection with auxiliary networks
  4. ATS (Adaptive Token Sampling) for LLMs - https://arxiv.org/abs/2310.11589
    - Dynamic token dropping for decoder-only models
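
To make item 1 concrete, here is a condensed sketch of ToMe-style bipartite soft matching: split tokens into two sets, match each token in one set to its most similar counterpart in the other, and average away the r best-matched pairs. The official implementation additionally tracks token "size" for weighted averaging and handles duplicate merge targets, which this sketch omits:

import torch
import torch.nn.functional as F

def tome_merge(x: torch.Tensor, r: int) -> torch.Tensor:
    """x: (B, N, D) token features; r: number of tokens to merge away."""
    a, b = x[:, ::2], x[:, 1::2]                      # bipartite split
    # Cosine similarity between every a-token and every b-token.
    sim = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).transpose(1, 2)
    best_sim, best_match = sim.max(dim=-1)            # best b-partner per a-token
    order = best_sim.argsort(dim=-1, descending=True)
    src_idx, keep_idx = order[:, :r], order[:, r:]    # merge the r most similar
    batch = torch.arange(x.shape[0]).unsqueeze(-1)
    dst_idx = best_match[batch, src_idx]              # matched b-token per merged a-token
    # Average each merged a-token into its matched b-token.
    b = b.clone()
    b[batch, dst_idx] = (b[batch, dst_idx] + a[batch, src_idx]) / 2
    return torch.cat([a[batch, keep_idx], b], dim=1)  # (B, N - r, D)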

Implementation Strategy

class TokenPruner(BaseAlgorithm):
    def __init__(self, reduction_ratio=0.3, method="tome"):
        super().__init__()
        self.reduction_ratio = reduction_ratio
        self.method = method

    def smash(self, model):
        # Inject token merging/pruning hooks into each transformer layer.
        for layer in model.transformer_layers:
            layer.register_forward_hook(self.token_reduction_hook)
        return model

    def token_reduction_hook(self, module, input, output):
        # Apply the configured ToMe or pruning strategy. Returning a value
        # from a forward hook replaces the layer's output, so we must keep
        # the (tokens, attention_weights) structure intact.
        tokens, attention_weights = output
        important_tokens = self.select_tokens(tokens, attention_weights)
        return (important_tokens, attention_weights)

    def select_tokens(self, tokens, attention_weights):
        # Dispatch to the chosen reduction method (to be implemented).
        ...
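
Hypothetical usage, assuming a model that exposes transformer_layers as in the sketch above; real attribute names vary by architecture:

pruner = TokenPruner(reduction_ratio=0.3, method="tome")
smashed = pruner.smash(base_model)  # hooks now thin out tokens each forward pass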

Performance Expectations

Model Type         Speedup   Memory Reduction   Quality Impact
ViT-B/16           1.5-2x    20-30%             <1% accuracy drop
CLIP               1.4x      25%                <2% on retrieval
Stable Diffusion   1.3x      15%                <0.5 FID increase
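
These figures reflect the referenced papers; a quick latency check like the one below can verify the speedup column on your own hardware (add torch.cuda.synchronize() around the timed region for GPU models):

import time
import torch

def measure_latency(model, example_input, warmup=5, iters=50):
    with torch.no_grad():
        for _ in range(warmup):       # let compilation/caching settle
            model(example_input)
        start = time.perf_counter()
        for _ in range(iters):
            model(example_input)
    return (time.perf_counter() - start) / iters  # seconds per forward pass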

Why This Matters

  • Complements existing algorithms: Works alongside quantization, caching, and compilation
  • Zero/minimal training: ToMe requires no fine-tuning; others need light calibration
  • Proven effectiveness: Multiple papers demonstrate real-world gains
  • Broad applicability: Works across vision, language, and diffusion models

This would position Pruna as having one of the most comprehensive optimization toolkits, covering not just model
compression but also dynamic inference optimization.

Acceptance Criteria

  • It follows the style guidelines.
  • Tests are created and pass.
  • The algorithm integrates properly into Pruna.

And don’t forget to give us a ⭐️!


❓ Questions?

Feel free to jump into the #contributing Discord channel if you hit any roadblocks. Can’t wait to see your contribution! 🚀


