Skip to content

Commit 4c34ce8

Browse files
authored
[Kernel] Remove marlin moe templating on thread_m_blocks (#8573)
Co-authored-by: [email protected]
1 parent 0d47bf3 commit 4c34ce8

File tree

1 file changed

+28
-51
lines changed

1 file changed

+28
-51
lines changed

csrc/moe/marlin_moe_ops.cu

Lines changed: 28 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle(
13421342

13431343
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
13441344
const int threads, // number of threads in a threadblock
1345-
const int thread_m_blocks, // number of 16x16 blocks in the m
1346-
// dimension (batchsize) of the
1347-
// threadblock
13481345
const int thread_n_blocks, // same for n dimension (output)
13491346
const int thread_k_blocks, // same for k dimension (reduction)
13501347
const int stages, // number of stages for the async global->shared
@@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
14591456

14601457
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
14611458
const int threads, // number of threads in a threadblock
1462-
const int thread_m_blocks, // number of 16x16 blocks in the m
1463-
// dimension (batchsize) of the
1464-
// threadblock
14651459
const int thread_n_blocks, // same for n dimension (output)
14661460
const int thread_k_blocks, // same for k dimension (reduction)
14671461
const int stages, // number of stages for the async global->shared
@@ -1515,26 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory
15151509
static constexpr int min_thread_n = 64;
15161510
static constexpr int min_thread_k = 64;
15171511

1518-
#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
1519-
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \
1520-
NUM_THREADS) \
1521-
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
1522-
thread_n_blocks == THREAD_N_BLOCKS && \
1523-
thread_k_blocks == THREAD_K_BLOCKS && \
1524-
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
1525-
num_threads == NUM_THREADS) { \
1526-
cudaFuncSetAttribute( \
1527-
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
1528-
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
1529-
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
1530-
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
1531-
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
1532-
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
1533-
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
1534-
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
1535-
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
1536-
replicate_input, apply_weights, m_block, max_par, \
1537-
exec_cfg.max_m_blocks); \
1512+
// __CALL_IF_MOE: expands to a single `else if` arm of a kernel-dispatch chain.
// The arm is selected when the caller-scope values q_type, thread_n_blocks,
// thread_k_blocks, has_act_order, group_blocks and num_threads all equal the
// corresponding macro arguments. When selected it:
//   1. calls cudaFuncSetAttribute on the matching MarlinMoE instantiation to
//      set cudaFuncAttributeMaxDynamicSharedMemorySize to max_shared_mem, and
//   2. launches that instantiation with grid `blocks`, NUM_THREADS threads per
//      block, max_shared_mem as the dynamic shared-memory size, on `stream`.
// Because the expansion starts with `else`, it must directly follow an `if`
// statement or another arm produced by this macro.
// NOTE(review): the cudaFuncSetAttribute return value is discarded and no
// launch-error check is visible inside this macro - confirm the caller checks.
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
1513+
GROUP_BLOCKS, NUM_THREADS) \
1514+
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
1515+
thread_k_blocks == THREAD_K_BLOCKS && \
1516+
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
1517+
num_threads == NUM_THREADS) { \
1518+
cudaFuncSetAttribute( \
1519+
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1520+
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
1521+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
1522+
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1523+
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
1524+
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
1525+
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
1526+
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
1527+
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
1528+
replicate_input, apply_weights, m_block, max_par, \
1529+
exec_cfg.max_m_blocks); \
15381530
}
15391531

15401532
typedef struct {
@@ -1711,31 +1703,16 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
17111703
return exec_config_t{0, {-1, -1, -1}};
17121704
}
17131705

1714-
#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
1715-
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
1716-
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
1717-
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
1718-
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
1719-
\
1720-
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
1721-
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
1722-
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
1723-
__CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
1724-
\
1725-
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
1726-
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
1727-
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
1728-
__CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
1729-
\
1730-
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
1731-
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
1732-
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
1733-
__CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
1734-
\
1735-
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
1736-
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
1737-
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
1738-
__CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
1706+
// CALL_IF_MOE: emits the `else if` dispatch arms for one
// (W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) combination by invoking
// __CALL_IF_MOE once per supported (HAS_ACT_ORDER, GROUP_BLOCKS) pair:
//   - (true, 0)
//   - (false, -1), (false, 2), (false, 4), (false, 8)
// The (true, 0) arm is listed exactly once. Repeating it would only add
// `else if` arms with the same condition as the first one, which can never be
// taken, and would name the same MarlinMoE instantiation again.
// Like __CALL_IF_MOE, the expansion begins with `else` and must follow an `if`
// statement or a previous arm.
#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)            \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)       \
                                                                        \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)     \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)      \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)      \
  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
17391716

17401717
void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
17411718
const void* sorted_ids, const void* topk_weights,

0 commit comments

Comments
 (0)