Conversation


@tscholak commented Oct 12, 2025

Summary

Implements a stochastic mixer layer for supernet training, enabling random sampling from multiple mixer options (e.g., attention vs. Mamba) during training. Includes checkpoint conversion support and a hierarchical beam search tool for finding optimal mixer placement post-training.

Implementation Details

Stochastic Mixer (fast_llm/layers/decoder/stochastic_mixer.py)

  • Training mode: Randomly samples from the configured mixers using distributed RNG (ensuring consistency across TP/PP ranks); see the sketch after this list
  • Eval/inference mode: Uses main_mixer_index for deterministic behavior
  • Sampling strategies: Uniform or weighted sampling
  • Preprocessing: Runs preprocessing for all mixers since we don't know which will be selected
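
In rough sketch form, the selection logic is as follows. Class, parameter, and kwargs names below are illustrative placeholders, not the actual Fast-LLM API:

```python
import torch
import torch.nn as nn

class StochasticMixerSketch(nn.Module):
    """Illustrative only: one mixer is selected per forward pass."""

    def __init__(self, mixers: list[nn.Module], sampling_probs: list[float], main_mixer_index: int):
        super().__init__()
        self.mixers = nn.ModuleList(mixers)        # e.g. [attention, mamba]
        # Uniform or weighted sampling distribution (must sum to 1.0).
        self.register_buffer("sampling_probs", torch.tensor(sampling_probs))
        self.main_mixer_index = main_mixer_index   # used outside training

    def forward(self, hidden_states, kwargs: dict):
        if self.training:
            # Sample with the distributed RNG so all TP/PP ranks pick the same mixer.
            generator = kwargs.get("generator")  # hypothetical hand-off of the shared RNG
            index = torch.multinomial(self.sampling_probs, num_samples=1, generator=generator).item()
        else:
            # Deterministic behavior for eval/inference.
            index = self.main_mixer_index
        return self.mixers[index](hidden_states, kwargs)
```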

Configuration (fast_llm/layers/decoder/config.py)

  • StochasticMixerConfig: List-based mixer configuration with a sampling strategy (see the config sketch after this list)
  • main_mixer_index: Specifies which mixer to use during inference and which receives pretrained weights during checkpoint conversion
  • Validation ensures sampling weights sum to 1.0 and all indices are valid
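
For illustration, the configuration amounts to something like the following hypothetical fragment. Field names follow the description above; the actual schema in fast_llm/layers/decoder/config.py may differ:

```python
# Hypothetical config fragment, shown as a plain dict; field names follow the
# description above, not necessarily the actual StochasticMixerConfig schema.
stochastic_mixer_config = {
    "type": "stochastic",
    "sampling_strategy": "weighted",  # or "uniform"
    "main_mixer_index": 0,            # mixer used for inference and checkpoint conversion
    "mixers": [
        {"type": "attention", "sampling_weight": 0.75},
        {"type": "mamba", "sampling_weight": 0.25},  # weights must sum to 1.0
    ],
}
```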

Checkpoint Conversion (fast_llm/models/gpt/conversion/apriel.py)

  • AprielStochasticMixerConverter: Handles conversion between Fast-LLM and Apriel formats
  • Only main_mixer_index weights are exported/imported (other mixers randomly initialized during supernet training)
  • Follows existing converter patterns (minimal, no verbose comments)

Beam Search Tool (tools/supernet_beam_search.py)

  • Hierarchical algorithm: Finds optimal placement for mixers at each quality/cost level (sketched after this list)
    • Phase 1: Find best N layers for primary mixer (e.g., full attention)
    • Phase 2: Find best M layers for secondary mixer (e.g., sliding window attention)
    • Remaining layers use tertiary mixer (e.g., linear attention)
  • Efficient evaluation: Loads checkpoint once, modifies main_mixer_index in-place for each candidate
  • Fast-LLM integration: Uses Fast-LLM's evaluation system directly (no subprocess or checkpoint reconversion)
  • Features: Pre-scoring, beam growth, early stopping, configurable score direction
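
A condensed sketch of the two-phase search; the evaluate callback and function names are simplified placeholders for what the tool actually does (set main_mixer_index per layer, run Fast-LLM evaluation, score):

```python
# Simplified sketch; `evaluate(primary=..., secondary=...)` stands in for the
# tool's real scoring (configure the candidate placement, run evaluation).
def beam_search_phase(num_layers, budget, fixed, evaluate, beam_width):
    """Grow sets of layer indices one layer at a time, keeping the best `beam_width`."""
    beam = [frozenset()]
    for _ in range(budget):
        candidates = {
            selection | {layer}
            for selection in beam
            for layer in range(num_layers)
            if layer not in selection and layer not in fixed
        }
        beam = sorted(candidates, key=evaluate, reverse=True)[:beam_width]
    return beam[0]  # best placement found for this budget


def hierarchical_search(num_layers, primary_budget, secondary_budget, evaluate, beam_width=12):
    # Phase 1: place the primary mixer (e.g. full attention).
    primary = beam_search_phase(
        num_layers, primary_budget, fixed=frozenset(),
        evaluate=lambda sel: evaluate(primary=sel, secondary=frozenset()),
        beam_width=beam_width,
    )
    # Phase 2: place the secondary mixer (e.g. sliding window attention);
    # any layer left over falls back to the tertiary mixer (e.g. linear attention).
    secondary = beam_search_phase(
        num_layers, secondary_budget, fixed=primary,
        evaluate=lambda sel: evaluate(primary=primary, secondary=sel),
        beam_width=beam_width,
    )
    return primary, secondary
```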

Tests (tests/utils/model_configs.py)

  • Added stochastic_mixer test configuration with FA/Mamba mixers
  • Enabled checkpoint conversion testing via AprielHybridSSMCheckpointFormat

Use Case

Supernet Training: Train a model where each layer can be either full attention or Mamba, with random sampling at each step. After training, use beam search to find which specific layers benefit most from full attention vs. Mamba, given a budget constraint (e.g., "I can afford 4 FA layers").

Testing

Run the stochastic mixer tests:

pytest tests/models/test_checkpoint.py::test_checkpoint_and_eval tests/models/test_checkpoint.py::test_conversion -k "stochastic_mixer" -v

Example beam search usage:

fast-llm tools/supernet_beam_search.py \
  training_config=path/to/supernet_config.yaml \
  budgets=[4,8] \
  beam_width=12 \
  score_metric="lm_eval/accuracy" \
  output_path=results.json

🤖 Generated with Claude Code

Co-Authored-By: Claude [email protected]

Implements a stochastic mixer layer that randomly samples from multiple
mixer options during training, enabling supernet training where different
architecture variants (e.g., attention vs. Mamba) are trained with
different data subsets.

Key components:
- StochasticMixerConfig: Configuration for stochastic sampling strategy
  (uniform or weighted) with configurable main_mixer_index for inference
- StochasticMixer: Layer implementation with distributed RNG support
- Checkpoint conversion: Apriel converter handles stochastic mixers
- Beam search tool: Hierarchical beam search for optimal mixer placement

The beam search tool finds which layers benefit most from expensive mixers
(e.g., full attention) vs. efficient mixers (e.g., linear attention) by
evaluating different configurations using Fast-LLM's evaluation system.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
tscholak and others added 2 commits October 14, 2025 05:07
- Fix Assert.gt_len AttributeError by moving validation to _validate() method
- Add AttentionConfig import to models/auto.py for proper registration
- Mark all mixer parameters with allow_no_grad=True since only one mixer is active per forward pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Fixed nested config structure bug in AprielStochasticMixerConverter.import_config
that was causing validation errors when loading Apriel checkpoints.

The converter was returning the entire block config (with mixer, mlp, and
normalization keys) instead of just the mixer config, causing these fields
to be incorrectly nested under the mixer field during import.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

@jlamypoirier left a comment

Looks good, some minor comments


with set_generator(generator):
    # Sample from categorical distribution
    idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
@jlamypoirier: This requires a costly CUDA sync. How about we sample for all layers at once during preprocessing?

@tscholak: Now done during preprocessing.
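
For reference, batching the draw looks roughly like this; the function and argument names are illustrative, not the actual implementation:

```python
import torch

def sample_mixer_indices(sampling_probs: torch.Tensor, num_layers: int,
                         generator: torch.Generator) -> list[int]:
    """Draw one mixer index per stochastic-mixer layer in a single call.

    `sampling_probs` is the shared (num_mixers,) categorical distribution and
    `generator` is the distributed RNG, so all TP/PP ranks draw the same indices.
    """
    samples = torch.multinomial(
        sampling_probs.expand(num_layers, -1).contiguous(),
        num_samples=1,
        generator=generator,
    )
    # One device->host transfer for the whole model instead of one per layer.
    return samples.squeeze(1).tolist()
```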

mixer_idx = self._sample_mixer_index()

if self._debug.enabled:
    logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
@jlamypoirier: Ambiguous if multiple mixers share the same type. Use named mixers instead?

@tscholak: Now using named mixers. We retrieve mixer_name from kwargs (line 151) and use it for logging (line 160) and for accessing the correct mixer (line 163).

we need to preprocess for all of them. This includes things like
attention masks, rotary embeddings, etc.
"""
for mixer in self.mixers:
@jlamypoirier: There could be name conflicts. Consider a namespace?

@tscholak: Now namespaced. See lines 214-216, where we prefix with f"{mixer_name}/{loss_def.name}".
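
In sketch form, with a stand-in LossDef (the real one lives in Fast-LLM):

```python
import dataclasses

@dataclasses.dataclass
class LossDef:  # stand-in for Fast-LLM's LossDef, only the fields used here
    name: str
    count: int = 1

def namespaced_loss_definitions(mixers: dict, count: int = 1) -> list[LossDef]:
    """Prefix each mixer's loss names with the mixer name so that, e.g.,
    "attention/aux_loss" and "mamba/aux_loss" can never collide."""
    definitions = []
    for mixer_name, mixer in mixers.items():
        for loss_def in mixer.get_loss_definitions(count):
            definitions.append(dataclasses.replace(loss_def, name=f"{mixer_name}/{loss_def.name}"))
    return definitions
```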


return int(expected_usage)

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
@jlamypoirier: This is a bit dangerous: there could be name conflicts, and counts will be wrong for averaging. Not sure how to fix, though.

@tscholak: Acknowledged. The current approach ensures we allocate space for all possible losses, but you're right that counts won't match actual usage since only one mixer runs per forward pass. We could track which mixer was used and only record its losses, but that adds complexity. I think what we have is good enough for now.

return converter_class.mixer_converter_class.export_config(inference_mixer)

@classmethod
def get_converters(
@jlamypoirier: How about import? I don't think it will work.

@tscholak: Import doesn't work as usual for stochastic mixers. We use drop_on_export=True for non-main mixers, and HF checkpoints only contain the main mixer. I think the correct way to handle this is to either support stochastic mixers in HF (out of scope) or initialize all other mixers randomly while importing only the main mixer.

mixer_converter_class.get_converters(
    mixer,
    f"{fast_llm_prefix}.mixers.{mixer_index}",
    hf_prefix if is_main_mixer else None,
@jlamypoirier: Just hf_prefix; drop_on_export handles the rest.

@tscholak: Now uses just hf_prefix, without the mixer name prefix.

f"{hf_prefix}.{block_index}",
drop_on_export,
)
match config:
@jlamypoirier: I don't think match is warranted here, since it involves a (slow) initialization of configs.

@tscholak: Uses if/else with an instance type check now.
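
Roughly along these lines; the config and converter classes below are stand-ins for the real ones in apriel.py:

```python
class AttentionConfig: ...            # stand-ins for the real config classes
class MambaConfig: ...
class AprielAttentionConverter: ...   # stand-ins for the real converter classes
class AprielMambaConverter: ...

def get_mixer_converter_class(config):
    # Plain isinstance dispatch instead of `match config: case ...:`.
    if isinstance(config, AttentionConfig):
        return AprielAttentionConverter
    elif isinstance(config, MambaConfig):
        return AprielMambaConverter
    else:
        raise NotImplementedError(f"No converter for mixer config {type(config).__name__}")
```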

ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
@jlamypoirier: Let's leave this as unimportant. All this tests is the consistency of stochastic sampling, and I don't think that warrants the overhead of testing every time.

@tscholak: Set to ModelTestingGroupAction.unimportant now.

- Add _is_lossy_hf_conversion() utility to detect when HF conversion drops weights
- Skip incompatible tests (test_converted_round_trip, test_load_pretrained) for lossy conversions
- Check converters for IgnoreExportWeightConverter instances
- Factor out config loading into _load_config_from_test_dir() and _load_config_from_checkpoint()
- Export main_mixer_type in stochastic mixer config for HF compatibility
# Conflicts:
#	fast_llm/models/gpt/conversion/apriel.py