Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
690 changes: 690 additions & 0 deletions demo/qwen3/demo_sampling.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions include/mirage/kernel/task_register.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class TaskRegister {
std::vector<int> const &params);
int register_argmax_reduce_sm100_task(threadblock::Graph const &bgraph,
std::vector<int> const &params);
int register_sampling_sm100_task(threadblock::Graph const &bgraph,
std::vector<int> const &params);
int register_tensor_init_task(threadblock::Graph const &bgraph,
std::vector<int> const &params);
int register_moe_topk_softmax_sm100_task(threadblock::Graph const &bgraph,
Expand Down
1 change: 1 addition & 0 deletions include/mirage/persistent_kernel/runtime_header.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ enum TaskType {
TASK_MOE_TOPK_SOFTMAX_SM100 = 260,
TASK_MOE_MUL_SUM_ADD_SM100 = 261,
TASK_TENSOR_INIT = 262,
TASK_SAMPLING_SM100 = 263,
TASK_SM100_TASK_END = 298, // SM100 end placeholder, not a real task
TASK_NVSHMEM_COPY = 199,
TASK_SCHD_TASKS = 200,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
#include "linear_sm100_mpk.cuh"
#include "moe_linear_sm100.cuh"
#include "mul_sum_add_sm100.cuh"
#include "tasks/common/sampling.cuh"
#include "tensor_init.cuh"
#include "topk_softmax_sm100.cuh"
228 changes: 228 additions & 0 deletions include/mirage/persistent_kernel/tasks/common/sampling.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/* Copyright (c) 2025 by CMU.
* Copyright (c) 2025 by FlashInfer team.
*
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per Apache 2.0 you have to keep the original copyright as well.
e.g.
Copyright 2023-2025 FlashInfer contributors

in addition to your own copyright.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I forgot to mention this. @STWMichae, please make sure to add the copyright before merging as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* Sampling from logits using Gumbel-Max trick
* Based on FlashInfer's sampling kernel (Apache License 2.0).
*/

#pragma once

#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda/std/limits>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>

namespace kernel {

using namespace cub;

// Helper function for ceiling division
template <typename T>
__host__ __device__ __forceinline__ T sampling_ceil_div(T a, T b) {
return (a + b - 1) / b;
}

/******************* vec_t - Simplified Vector Type *******************/

template <typename T, size_t vec_size>
struct sampling_vec_t {
T data[vec_size];

__device__ __forceinline__ T &operator[](size_t i) {
return data[i];
}
__device__ __forceinline__ T const &operator[](size_t i) const {
return data[i];
}

__device__ __forceinline__ void fill(T val) {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
data[i] = val;
}
}

__device__ __forceinline__ void cast_load(T const *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
data[i] = ptr[i];
}
}
};

/******************* DataAndIndex Structure *******************/

template <typename DType, typename IdType>
struct SamplingDataAndIndex {
DType data;
IdType index;

__device__ SamplingDataAndIndex
operator+(SamplingDataAndIndex const &other) const {
if (data > other.data) {
return {data, index};
} else {
return {other.data, other.index};
}
}

__device__ SamplingDataAndIndex &
operator+=(SamplingDataAndIndex const &other) {
if (data > other.data) {
return *this;
} else {
data = other.data;
index = other.index;
return *this;
}
}
};

/******************* Gumbel Noise Generation *******************/

template <typename DType, uint32_t VEC_SIZE>
__device__ __forceinline__ sampling_vec_t<DType, VEC_SIZE>
GenerateSamplingGumbelNoise(uint64_t philox_seed,
uint64_t philox_offset,
uint64_t subsequence) {
curandStatePhilox4_32_10_t state;
sampling_vec_t<float, VEC_SIZE> noise;
constexpr float kEPSILON = 1e-20f;
constexpr float kLOG2 = 0.6931471806f;

auto uniform2gumbel = [](float x) {
return -kLOG2 * log2f(-log2f(x + kEPSILON) + kEPSILON);
};

#pragma unroll
for (uint32_t i = 0; i + 4 <= VEC_SIZE; i += 4) {
curand_init(philox_seed, subsequence + i, philox_offset, &state);
float4 noise_vec = curand_uniform4(&state);
noise[i] = uniform2gumbel(noise_vec.x);
noise[i + 1] = uniform2gumbel(noise_vec.y);
noise[i + 2] = uniform2gumbel(noise_vec.z);
noise[i + 3] = uniform2gumbel(noise_vec.w);
}

if constexpr (VEC_SIZE % 4 != 0) {
curand_init(
philox_seed, subsequence + VEC_SIZE / 4 * 4, philox_offset, &state);
float4 noise_vec = curand_uniform4(&state);
if constexpr (VEC_SIZE % 4 == 1) {
noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.x);
} else if constexpr (VEC_SIZE % 4 == 2) {
noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.x);
noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.y);
} else if constexpr (VEC_SIZE % 4 == 3) {
noise[VEC_SIZE - 3] = uniform2gumbel(noise_vec.x);
noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.y);
noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.z);
}
}

if constexpr (std::is_same_v<DType, float>) {
return noise;
} else {
sampling_vec_t<DType, VEC_SIZE> ret;
#pragma unroll
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
ret[i] = static_cast<DType>(noise[i]);
}
return ret;
}
}

/******************* Sampling From Logits Kernel *******************/

constexpr BlockScanAlgorithm SAMPLING_SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
constexpr BlockReduceAlgorithm SAMPLING_REDUCE_ALGO =
BLOCK_REDUCE_WARP_REDUCTIONS;

template <uint32_t BLOCK_THREADS,
uint32_t VEC_SIZE,
typename DType,
typename IdType>
__device__ __forceinline__ void
sampling_from_logits_kernel(DType *logits,
IdType *output,
uint32_t d,
uint64_t philox_seed,
uint64_t philox_offset,
int batch_size) {
const uint32_t tx = threadIdx.x;

using SharedMem = typename BlockReduce<SamplingDataAndIndex<DType, IdType>,
BLOCK_THREADS,
SAMPLING_REDUCE_ALGO>::TempStorage;
extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling_logit[];
auto &temp_storage = reinterpret_cast<SharedMem &>(smem_sampling_logit);

// Loop over all batches
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
sampling_vec_t<DType, VEC_SIZE> logits_vec;
SamplingDataAndIndex<DType, IdType> max_data = {
-cuda::std::numeric_limits<DType>::infinity(), 0};

// Process logits in chunks with vectorized loads
for (uint32_t i = 0; i < sampling_ceil_div(d, BLOCK_THREADS * VEC_SIZE);
++i) {
logits_vec.fill(-cuda::std::numeric_limits<DType>::infinity());

// Load logits vector if within bounds
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
logits_vec.cast_load(logits + batch_idx * d +
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}

// Generate Gumbel noise
sampling_vec_t<DType, VEC_SIZE> gumbel_noise =
GenerateSamplingGumbelNoise<DType, VEC_SIZE>(
philox_seed,
philox_offset,
static_cast<uint64_t>(batch_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE));

// Add noise to logits and prepare for reduction
SamplingDataAndIndex<DType, IdType> cur_data[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
cur_data[j].data = (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d
? logits_vec[j] + gumbel_noise[j]
: -cuda::std::numeric_limits<DType>::infinity();
cur_data[j].index = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
}

// Find maximum across block
max_data += BlockReduce<SamplingDataAndIndex<DType, IdType>,
BLOCK_THREADS,
SAMPLING_REDUCE_ALGO>(temp_storage)
.template Sum<VEC_SIZE>(cur_data);
}

// Write output for this batch
if (tx == 0) {
output[batch_idx] = max_data.index;
}

// Sync before next batch iteration to reuse shared memory
__syncthreads();
}
}

} // namespace kernel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
*/
#pragma once
#include "../common/utils.cuh"
#include "../common/worker_config.h"
namespace kernel {

template <typename T,
Expand Down
22 changes: 21 additions & 1 deletion python/mirage/persistent_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,27 @@ def argmax_reduce_layer(
self.kn_graph.register_task(
tb_graph, "argmax_reduce", [self.argmax_partial_output_size]
)


def sampling_sm100_layer(
self,
logits: DTensor, # [batch_size, vocab_size]
output: DTensor, # [batch_size, 1]
grid_dim: tuple,
block_dim: tuple,
seed: int = 42,
):
"""Sampling from logits using Gumbel-Max trick for stochastic token generation."""
assert logits.num_dims == 2 # (batch_size, vocab_size)
assert output.num_dims == 2 # (batch_size, 1)

tb_graph = TBGraph(CyTBGraph(grid_dim, block_dim, 1, 64))
tb_graph.new_input(logits, (0, -1, -1), -1, True)
tb_graph.new_input(output, (0, -1, -1), -1, True)
self.kn_graph.customized([logits, output], tb_graph)

# Register task with seed parameter
self.kn_graph.register_task(tb_graph, "sampling_sm100", [seed])

def find_ngram_partial_layer(
self, input: DTensor, output: DTensor, grid_dim: tuple, block_dim: tuple, ngram_size: int = 3):
# Currently assume that input/output
Expand Down
4 changes: 4 additions & 0 deletions src/kernel/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ void Graph::register_task(char const *task_type, std::vector<int> params) {
customized->bgraph, params);
task_config[op] =
std::make_tuple(2, 1, TASK_ARGMAX_REDUCE_SM100, variant_id);
} else if (name == "sampling_sm100") {
int variant_id =
task_register->register_sampling_sm100_task(customized->bgraph, params);
task_config[op] = std::make_tuple(1, 1, TASK_SAMPLING_SM100, variant_id);
} else if (name == "tensor_init") {
int variant_id =
task_register->register_tensor_init_task(customized->bgraph, params);
Expand Down
1 change: 1 addition & 0 deletions src/kernel/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,7 @@ TaskGraphResult print_task_graph(
task_type_to_name[TASK_ATTN_SM100] = "TASK_ATTN_SM100";
task_type_to_name[TASK_ARGMAX_PARTIAL_SM100] = "TASK_ARGMAX_PARTIAL_SM100";
task_type_to_name[TASK_ARGMAX_REDUCE_SM100] = "TASK_ARGMAX_REDUCE_SM100";
task_type_to_name[TASK_SAMPLING_SM100] = "TASK_SAMPLING_SM100";
task_type_to_name[TASK_TENSOR_INIT] = "TASK_TENSOR_INIT";
task_type_to_name[TASK_MOE_TOPK_SOFTMAX_SM100] =
"TASK_MOE_TOPK_SOFTMAX_SM100";
Expand Down
35 changes: 35 additions & 0 deletions src/kernel/task_register.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,41 @@ int TaskRegister::register_argmax_reduce_sm100_task(
return register_task_variant(TASK_ARGMAX_REDUCE_SM100, code.to_string());
}

int TaskRegister::register_sampling_sm100_task(threadblock::Graph const &bgraph,
std::vector<int> const &params) {
// params[0]: seed
assert(params.size() == 1);
std::vector<tb::TBInputOp *> input_ops;
std::vector<tb::TBInputOp *> output_ops;
int num_inputs = 1;
int num_outputs = 1;

assert(bgraph.operators.size() == (size_t)num_inputs + num_outputs);
for (auto const &op : bgraph.operators) {
assert(op->op_type == mirage::type::TB_INPUT_OP);
if (input_ops.size() < (size_t)num_inputs) {
input_ops.push_back(static_cast<tb::TBInputOp *>(op));
} else {
output_ops.push_back(static_cast<tb::TBInputOp *>(op));
}
}
assert(input_ops[0]->output_tensors[0].num_dims == 2);
int batch_size = input_ops[0]->output_tensors[0].dim[0];
int vocab_size = input_ops[0]->output_tensors[0].dim[1];
int seed = params[0];

mirage::transpiler::CodeKeeper code;
code.inc_indent();
code.e("kernel::sampling_from_logits_kernel<256, 4, bfloat16, int>(");
code.e(" static_cast<bfloat16*>(task_desc->input_ptrs[0]),");
code.e(" static_cast<int*>(task_desc->output_ptrs[0]),");
code.e(" $,", vocab_size);
code.e(" $,", seed);
code.e(" 0, // philox_offset");
code.e(" $);", batch_size);
return register_task_variant(TASK_SAMPLING_SM100, code.to_string());
}

int TaskRegister::register_tensor_init_task(threadblock::Graph const &bgraph,
std::vector<int> const &params) {
assert(params.size() == 0);
Expand Down
Loading
Loading