Description
Based on #364 (comment)
An adaptive inference algorithm that dynamically reduces the number of tokens processed in transformer-based
models (vision, language, and multimodal) by:
- Pruning redundant or low-importance tokens during forward passes
- Merging similar tokens to reduce computational overhead
- Adapting the pruning strategy based on input complexity
This would be a new optimization category or could extend the existing pruner to support dynamic token-level
pruning.
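As a concrete illustration of the pruning path, here is a minimal sketch of attention-importance token pruning; the function name, tensor shapes, and keep-ratio heuristic are illustrative assumptions, not a proposed Pruna API:
```python
import torch

def prune_tokens(tokens, attn, keep_ratio=0.7):
    # tokens: (batch, n_tokens, dim); attn: (batch, heads, n_tokens, n_tokens)
    # Score each token by the attention it receives, averaged over heads
    # and query positions; low-scoring tokens are treated as redundant.
    scores = attn.mean(dim=1).mean(dim=1)            # (batch, n_tokens)
    n_keep = max(1, int(tokens.shape[1] * keep_ratio))
    idx = scores.topk(n_keep, dim=-1).indices
    idx = idx.sort(dim=-1).values                    # preserve token order
    batch_idx = torch.arange(tokens.shape[0]).unsqueeze(-1)
    return tokens[batch_idx, idx]                    # (batch, n_keep, dim)
```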
How it Improves Workflows
For Vision Models (ViT, DINO, etc.):
- Reduces FLOPs by 30-50% with <1% accuracy drop
- Especially effective for high-resolution images where many patches are redundant (backgrounds, uniform regions)
For Language Models:
- Speeds up long-context inference by removing padding and redundant tokens
- Reduces KV cache size in attention mechanisms (see the cache-pruning sketch after this list)
For Diffusion Models:
- Can prune latent tokens in U-Net architectures during denoising steps
- Complements caching by reducing computation per cached step
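To make the KV-cache point concrete, here is a hypothetical sketch (not an existing Pruna or Transformers API) of dropping the least-attended entries from a cached key/value pair; the `attn_to_past` importance signal is an assumption, and other scores would work too:
```python
import torch

def prune_kv_cache(keys, values, attn_to_past, keep_ratio=0.5):
    # keys/values: (batch, heads, past_len, head_dim)
    # attn_to_past: (batch, heads, past_len) attention mass each cached
    # position received from recent queries
    n_keep = max(1, int(keys.shape[2] * keep_ratio))
    scores = attn_to_past.mean(dim=1)                          # (batch, past_len)
    idx = scores.topk(n_keep, dim=-1).indices.sort(-1).values  # keep original order
    idx = idx[:, None, :, None].expand(-1, keys.shape[1], -1, keys.shape[3])
    return keys.gather(2, idx), values.gather(2, idx)
```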
Workflow Benefits
```python
from pruna import smash, SmashConfig

# Current workflow
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq"
smash_config["compiler"] = "torch_compile"
smashed_model = smash(model=base_model, smash_config=smash_config)

# Proposed workflow - add adaptive token optimization
smash_config["token_optimizer"] = {
    "method": "tome",        # Token Merging
    "reduction_ratio": 0.3,  # Reduce tokens by 30%
    "adaptive": True,        # Adjust based on input complexity
}
smashed_model = smash(model=base_model, smash_config=smash_config)
```
Examples & References
Key Algorithms to Implement:
- ToMe (Token Merging) - https://arxiv.org/abs/2210.09461
  - Merges similar tokens in ViTs using bipartite matching (sketched after this list)
  - Already has an open-source implementation: https://github.com/facebookresearch/ToMe
- DynamicViT - https://arxiv.org/abs/2106.02034
  - Learns which tokens to keep/drop per layer
  - Requires minimal fine-tuning
- AdaViT - https://arxiv.org/abs/2111.15668
  - Adaptive token selection with auxiliary networks
- ATS (Adaptive Token Sampling) for LLMs - https://arxiv.org/abs/2310.11589
  - Dynamic token dropping for decoder-only models
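For reference, a simplified single-step version of ToMe's bipartite soft matching, following the paper's description; it assumes an even token count and skips the paper's token-size tracking, so treat it as a sketch rather than the reference implementation:
```python
import torch
import torch.nn.functional as F

def tome_merge(x, r):
    # x: (batch, n, dim) with n even; merge the r most similar A->B pairs
    a, b = x[:, ::2], x[:, 1::2]                       # alternating bipartite split
    sim = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).transpose(-1, -2)
    best_sim, best_b = sim.max(dim=-1)                 # closest B token per A token
    order = best_sim.argsort(dim=-1, descending=True)  # most redundant A tokens first
    merged_a, kept_a = order[:, :r], order[:, r:]

    batch = torch.arange(x.shape[0]).unsqueeze(-1)
    dst = best_b.gather(-1, merged_a)                  # B targets of merged A tokens
    b = b.clone()
    # Average each merged A token into its matched B token (duplicate targets
    # would need scatter_reduce in a full implementation).
    b[batch, dst] = (b[batch, dst] + a[batch, merged_a]) / 2
    return torch.cat([a[batch, kept_a], b], dim=1)     # (batch, n - r, dim)
```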
Implementation Strategy
```python
class TokenPruner(BaseAlgorithm):  # BaseAlgorithm: Pruna's algorithm base class
    def __init__(self, reduction_ratio=0.3, method="tome"):
        self.reduction_ratio = reduction_ratio
        self.method = method

    def smash(self, model):
        # Inject token merging/pruning into the attention layers via hooks
        for layer in model.transformer_layers:
            layer.register_forward_hook(self.token_reduction_hook)
        return model

    def token_reduction_hook(self, module, input, output):
        # Apply the ToMe or pruning strategy; returning a value from a
        # forward hook replaces the layer's output
        tokens, attention_weights = output
        # select_tokens: the method-specific reduction (ToMe merge or top-k prune)
        important_tokens = self.select_tokens(tokens, attention_weights)
        return important_tokens, attention_weights
```
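The `adaptive: True` option could map input complexity to a per-call reduction ratio. One cheap proxy, offered as an assumption rather than a settled design, is mean pairwise token similarity: highly redundant inputs (uniform backgrounds, padding-heavy sequences) tolerate more pruning:
```python
import torch
import torch.nn.functional as F

def adaptive_ratio(tokens, base_ratio=0.3, max_ratio=0.6):
    # tokens: (batch, n_tokens, dim); returns a reduction ratio in
    # [base_ratio, max_ratio] that grows with token redundancy
    x = F.normalize(tokens, dim=-1)
    sim = (x @ x.transpose(-1, -2)).mean()   # mean pairwise cosine similarity
    redundancy = sim.clamp(0, 1).item()
    return base_ratio + (max_ratio - base_ratio) * redundancy
```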
Performance Expectations
| Model Type | Speedup | Memory Reduction | Quality Impact |
|---|---|---|---|
| ViT-B/16 | 1.5-2x | 20-30% | <1% accuracy |
| CLIP | 1.4x | 25% | <2% on retrieval |
| Stable Diffusion | 1.3x | 15% | <0.5 FID increase |
Why This Matters
- Complements existing algorithms: Works alongside quantization, caching, and compilation
- Zero/minimal training: ToMe requires no fine-tuning; others need light calibration
- Proven effectiveness: Multiple papers demonstrate real-world gains
- Broad applicability: Works across vision, language, and diffusion models
This would position Pruna as having one of the most comprehensive optimization toolkits, covering not just model
compression but also dynamic inference optimization.
✅ Acceptance Criteria
- It follows the style guidelines.
- Tests are created and pass.
- The algorithm integrates properly into Pruna.
And don’t forget to give us a ⭐️!
❓ Questions?
Feel free to jump into the #contributing Discord channel if you hit any roadblocks. Can’t wait to see your contribution! 🚀