Conversation

@STWMichael (Collaborator) commented Nov 12, 2025

Description of changes:
This PR adds support for a Gumbel-Max sampling kernel on all architectures. The original kernel code comes from flashinfer.

In order to support the sampling in MPK, we:

  1. Change BLOCK_THREADS from 1024 to 128/256.
  2. Fix the grid dimension at 1; the kernel instead takes a batch_size parameter and iterates over all batches sequentially.
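As a rough illustration of point 2 (a hedged Python sketch, not the PR's CUDA code; all names here are hypothetical): with the grid fixed at a single block, one kernel instance walks all batch rows in a loop instead of mapping one block per row.

```python
import math
import random

def sample_all_rows(logits_batch, rng):
    """Hypothetical single-instance analogue of the MPK design: with the
    grid dimension fixed at 1, one loop walks batch_size rows sequentially,
    Gumbel-Max-sampling a token index from each row of logits."""
    out = []
    for row in logits_batch:  # sequential over batch_size, not one block per row
        # Perturb each logit with Gumbel(0, 1) noise and take the argmax.
        perturbed = [l - math.log(-math.log(rng.random() or 1e-12)) for l in row]
        out.append(max(range(len(row)), key=perturbed.__getitem__))
    return out

if __name__ == "__main__":
    rng = random.Random(0)
    batch = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]]
    print(sample_all_rows(batch, rng))
```

In the actual kernel the per-row argmax would still be parallelized across the 128/256 threads of the single block; only the batch dimension becomes sequential.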

This PR also adds unit tests to verify the correctness of the kernel, and provides a demo_sampling.py script that shows the sampling effect during inference.
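For intuition on what such a correctness test can check (a CPU sketch in Python of the underlying Gumbel-Max trick, not the PR's actual unit tests): taking argmax_i(logit_i + g_i) with i.i.d. Gumbel(0, 1) noise g_i selects index i with probability softmax(logits)_i, so empirical sampling frequencies should converge to the softmax distribution.

```python
import math
import random

def gumbel_max_sample(logits, rng):
    """Sample an index i with probability softmax(logits)[i] via the
    Gumbel-Max trick: argmax_i(logits[i] + g_i), g_i ~ Gumbel(0, 1)."""
    best, best_val = 0, float("-inf")
    for i, logit in enumerate(logits):
        g = -math.log(-math.log(rng.random() or 1e-12))  # Gumbel(0, 1) draw
        if logit + g > best_val:
            best, best_val = i, logit + g
    return best

if __name__ == "__main__":
    rng = random.Random(0)
    logits = [2.0, 1.0, 0.0]
    n = 200_000
    counts = [0] * len(logits)
    for _ in range(n):
        counts[gumbel_max_sample(logits, rng)] += 1
    # Empirical frequencies should match softmax(logits) up to sampling noise.
    z = sum(math.exp(l) for l in logits)
    for i, l in enumerate(logits):
        print(f"index {i}: empirical={counts[i] / n:.3f}, softmax={math.exp(l) / z:.3f}")
```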

Related Issues:

#519

Linked Issues:

  • Issue #

Issues closed by this PR:

  • Closes #

@STWMichael STWMichael self-assigned this Nov 12, 2025
@STWMichael STWMichael changed the title Support Sampling in Mirage Supporting TopP and TopK sampling Nov 12, 2025
@yzh119 commented Nov 20, 2025

Please consider the latest version, flashinfer-ai/flashinfer#2119, which significantly improves performance.

@STWMichael STWMichael changed the title Supporting TopP and TopK sampling Support Gumbel-Max Sampling Kernel for Blackwell Nov 24, 2025
@STWMichael STWMichael marked this pull request as ready for review November 24, 2025 02:39
@STWMichael STWMichael changed the title Support Gumbel-Max Sampling Kernel for Blackwell Support Gumbel-Max Sampling Kernel Nov 24, 2025

template <uint32_t BLOCK_THREADS,
uint32_t VEC_SIZE,
int BATCH_SIZE,
Collaborator:
Do we still need this template arg if we are passing in batch size as a function arg?

Collaborator Author:

Removed

constexpr int CONSUMER_NUM_THREADS = 128; // Grace Hopper setting
#else
// Default settings for other architectures
constexpr int WORKER_NUM_THREADS = 128;
Collaborator:

We also have WORKER_NUM_THREADS defined in persistent_kernel.cuh, so this may be a duplicated definition.

Collaborator Author:

Removed

@JackFram (Collaborator) left a comment:

The sampling task implementation in this PR is already in pretty good shape. I left some comments; feel free to address them. Waiting for other reviewers' input.

@@ -0,0 +1,228 @@
/* Copyright 2023-2024 CMU
*

Per Apache 2.0, you have to keep the original copyright notice as well, e.g.

Copyright 2023-2025 FlashInfer contributors

in addition to your own copyright.

Collaborator:

Sorry, I forgot to mention this. @STWMichael, please make sure to add the copyright before merging as well.

Collaborator Author:

Done

@jiazhihao jiazhihao merged commit 5b86df9 into mirage-project:mpk Nov 25, 2025
6 of 7 checks passed
4 participants