Commit b326906

moe lora early exit

Signed-off-by: gnovack <[email protected]>

1 parent 48eb8eb · commit b326906

10 files changed, +133 -33 lines changed
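In short, this change threads two new tensors, lora_ids and adapter_enabled, through the MoE LoRA alignment kernel, the fused MoE LoRA Triton kernel, and their Python bindings so that blocks scheduled for empty or disabled adapter slots exit early, and it extends the OLMoE LoRA tests to cover base-model-only and mixed LoRA/non-LoRA batches.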

csrc/moe/moe_lora_align_sum_kernels.cu

Lines changed: 16 additions & 11 deletions
@@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t* lora_ids) {
   const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;

-  int lora_id = blockIdx.x;
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
+    return;
+  }
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
   }
 }

-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids) {
   const int topk_num = topk_ids.size(1);

   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
         expert_ids.data_ptr<int32_t>(), topk_num,
-        num_tokens_post_pad.data_ptr<int32_t>());
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
   });
 }
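In effect, each thread block of moe_lora_align_sum_kernel now resolves its adapter through the new lora_ids indirection and returns before touching shared memory when the slot is empty or the adapter is disabled. A rough Python rendering of that per-block guard (the host-side tensor shapes are an assumption based on the tests below):

```python
# Sketch of the early-exit guard added to moe_lora_align_sum_kernel,
# expressed in Python; block_idx plays the role of blockIdx.x.
def should_process_block(block_idx: int,
                         lora_ids: list[int],
                         adapter_enabled: list[int]) -> bool:
    lora_id = lora_ids[block_idx]       # grid slot -> adapter id
    if lora_id == -1:                   # nothing scheduled in this slot
        return False
    if adapter_enabled[lora_id] == 0:   # adapter slot has no weights loaded
        return False
    return True                         # proceed with the alignment work
```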

csrc/moe/moe_ops.h

Lines changed: 7 additions & 8 deletions
@@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   torch::Tensor expert_ids,
                                   torch::Tensor num_tokens_post_pad);

-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,

csrc/moe/torch_bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "                          int max_num_m_blocks, "
      "                          Tensor !sorted_token_ids,"
      "                          Tensor !experts_ids,"
-      "                          Tensor !num_tokens_post_pad) -> () ");
+      "                          Tensor !num_tokens_post_pad,"
+      "                          Tensor !adapter_enabled,"
+      "                          Tensor !lora_ids) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

 #ifndef USE_ROCM

tests/lora/test_fused_moe_lora_kernel.py

Lines changed: 6 additions & 0 deletions
@@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
     )
     expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
     num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
+    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

     # call kernel
     ops.moe_lora_align_block_size(
@@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        adapter_enabled,
+        lora_ids,
     )

     config = {
@@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
         num_tokens_post_padded,
         max_lora_rank,
         top_k_num,
+        lora_ids,
+        adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],

tests/lora/test_moe_lora_align_sum.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
         (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
+    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

     # call kernel
     ops.moe_lora_align_block_size(
@@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )

     # verify values

tests/lora/test_olmoe_tp.py

Lines changed: 57 additions & 7 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+
 import vllm
 from vllm.lora.request import LoRARequest

@@ -28,8 +29,17 @@
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]

+EXPECTED_BASE_MODEL_OUTPUT = [
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID",  # noqa: E501
+    "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1",  # noqa: E501
+]
+

-def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
+def generate_and_test(
+    llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
+) -> None:
     prompts = [
         PROMPT_TEMPLATE.format(context="How many candidates are there?"),
         PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
             context="Return the poll resource associated with the most candidates."
         ),
     ]
+
+    lora_request = None
+    if isinstance(lora_id, int):
+        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
+    elif isinstance(lora_id, list):
+        lora_request = [
+            LoRARequest(str(i), i, lora_path) if i is not None else None
+            for i in lora_id
+        ]
+
     sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
-    )
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
     # Print the outputs.
     generated_texts: list[str] = []
     for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

     for i in range(len(EXPECTED_LORA_OUTPUT)):
-        assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
+        req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
+        expected_output = (
+            EXPECTED_LORA_OUTPUT[i]
+            if req_lora_id is not None
+            else EXPECTED_BASE_MODEL_OUTPUT[i]
+        )
+        assert generated_texts[i].startswith(expected_output)


 def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,34 @@ def test_olmoe_lora(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=2)


+def test_olmoe_lora_base_model(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=None)
+
+
+def test_olmoe_lora_mixed(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
+
+
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files):
     llm = vllm.LLM(
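The new mixed test exercises passing a per-prompt list to lora_request, where None entries fall back to the base model. A condensed sketch of that calling pattern (the model id, prompts, and adapter path below are placeholders, not values from this diff):

```python
import vllm
from vllm.lora.request import LoRARequest

# Placeholder values for illustration; the test builds these from its fixtures.
MODEL_PATH = "allenai/OLMoE-1B-7B-0924"
lora_path = "/path/to/olmoe-sql-lora"
prompts = ["How many candidates are there?", "Count the number of candidates."]

llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_loras=4, max_model_len=1024)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)

# One entry per prompt: a LoRARequest routes that prompt through an adapter,
# while None leaves it on the base model.
lora_request = [LoRARequest("1", 1, lora_path), None]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
```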

vllm/_custom_ops.py

Lines changed: 4 additions & 0 deletions
@@ -1815,6 +1815,8 @@ def moe_lora_align_block_size(
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    adapter_enabled: torch.Tensor,
+    lora_ids: torch.Tensor,
 ) -> None:
     torch.ops._moe_C.moe_lora_align_block_size(
         topk_ids,
@@ -1827,6 +1829,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )

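For reference, the Python wrapper simply forwards the two extra tensors to the C++ op. A hedged call sketch with the argument order taken from the C++ signature above (all shapes and the padding formula here are assumptions modeled on the updated tests, not values from this diff):

```python
import torch
from vllm import _custom_ops as ops

max_loras, num_experts, block_size, top_k, num_tokens = 4, 8, 16, 2, 32

topk_ids = torch.randint(0, num_experts, (num_tokens, top_k),
                         dtype=torch.int32, device="cuda")
token_lora_mapping = torch.zeros(num_tokens, dtype=torch.int32, device="cuda")

# Assumed padding bound; the real caller computes these sizes elsewhere.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size

sorted_token_ids = torch.empty((max_loras * max_num_tokens_padded,),
                               dtype=torch.int32, device="cuda")
expert_ids = torch.full((max_loras * max_num_m_blocks,), num_experts,
                        dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

ops.moe_lora_align_block_size(
    topk_ids, token_lora_mapping, num_experts, block_size, max_loras,
    max_num_tokens_padded, max_num_m_blocks, sorted_token_ids, expert_ids,
    num_tokens_post_pad, adapter_enabled, lora_ids,
)
```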

vllm/lora/layers/fused_moe.py

Lines changed: 9 additions & 1 deletion
@@ -120,6 +120,7 @@ def wrapper(*args, **kwargs):
                 config["BLOCK_SIZE_M"],
                 global_num_experts,
                 max_loras,
+                self.adapter_enabled,
                 expert_map,
             )

@@ -147,6 +148,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
             )

             result = func(*args, **kwargs)
@@ -205,6 +207,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
                 True,
             )

@@ -239,6 +242,9 @@ def create_lora_weights(
         assert not self.base_layer.use_ep, (
             "EP support for Fused MoE LoRA is not implemented yet."
         )
+        self.adapter_enabled = torch.tensor(
+            [0] * (max_loras + 1), dtype=torch.int, device=self.device
+        )

         self.w1_lora_a_stacked = torch.zeros(
             (
@@ -326,6 +332,7 @@ def reset_lora(self, index: int):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+        self.adapter_enabled[index] = 0

     def set_lora(
         self,
@@ -335,8 +342,9 @@ def set_lora(
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
-        self.reset_lora(index)
         """Overwrites lora tensors at index."""
+        self.reset_lora(index)
+        self.adapter_enabled[index] = 1
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
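The layer keeps a per-slot enable mask whose lifecycle mirrors the stacked LoRA weights: allocated as zeros in create_lora_weights, set when an adapter's weights are written in set_lora, and cleared again in reset_lora. A stripped-down sketch of that bookkeeping (a hypothetical stand-in class; the real logic lives inside vLLM's fused MoE LoRA layer):

```python
import torch


class AdapterEnableMask:
    """Illustrative stand-in for the adapter_enabled bookkeeping."""

    def __init__(self, max_loras: int, device: str = "cuda"):
        # One flag per adapter slot, plus one extra slot (mirroring the diff).
        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=device
        )

    def set_lora(self, index: int) -> None:
        # Adapter weights were loaded into slot `index`; kernels may use it.
        self.reset_lora(index)
        self.adapter_enabled[index] = 1

    def reset_lora(self, index: int) -> None:
        # Slot `index` was evicted; kernels will now skip it early.
        self.adapter_enabled[index] = 0
```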

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 19 additions & 4 deletions
@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@ def _fused_moe_lora_kernel(
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return
     max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

@@ -97,12 +104,12 @@ def _fused_moe_lora_kernel(
     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m

-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return

     # get the expert_id to process curr shard
-    ind = lora_idx * stride_el + pid_m
+    ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
@@ -116,7 +123,7 @@ def _fused_moe_lora_kernel(
     offs_k = tl.arange(0, BLOCK_SIZE_K)

     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    token_ind = stride_tl * lora_idx + offs_token_id
+    token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
         sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
@@ -129,7 +136,7 @@ def _fused_moe_lora_kernel(

     b_ptrs = (
         cur_b_ptr
-        + lora_idx * stride_bl
+        + lora_id * stride_bl
         + expert_id * stride_be
         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     )
@@ -180,6 +187,8 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,  # (max_loras, )
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
@@ -268,6 +277,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -315,6 +326,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(0),
         a_intermediate_cache1.stride(1),
         w1_lora_b_stacked.stride(0),
@@ -348,6 +361,8 @@ def _fused_moe_lora_fake(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
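Inside the Triton kernel the guard sits at the top of each program: axis 2 of the launch grid indexes adapter slots, lora_ids translates the slot to an adapter id, and programs for empty or disabled slots return before any pointer arithmetic. A minimal standalone sketch of the same pattern (the real kernel's GEMM body is replaced here by a counter, and the -1 check is done before indexing adapter_enabled to keep the toy example in bounds):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _early_exit_kernel(lora_ids_ptr, adapter_enabled_ptr, counter_ptr):
    # Axis 2 indexes adapter slots, as in _fused_moe_lora_kernel.
    lora_idx = tl.program_id(axis=2)
    lora_id = tl.load(lora_ids_ptr + lora_idx)
    if lora_id == -1:
        return  # nothing scheduled in this slot
    if tl.load(adapter_enabled_ptr + lora_id) == 0:
        return  # adapter has no weights loaded
    # Stand-in for the real GEMM work: count how many programs survived.
    tl.atomic_add(counter_ptr, 1)


max_loras = 4
lora_ids = torch.tensor([2, -1, 0, 1], dtype=torch.int32, device="cuda")
adapter_enabled = torch.tensor([1, 0, 1, 0, 0], dtype=torch.int32, device="cuda")
counter = torch.zeros(1, dtype=torch.int32, device="cuda")

# One program per adapter slot on axis 2; only slots 0 and 2 pass the guard.
_early_exit_kernel[(1, 1, max_loras)](lora_ids, adapter_enabled, counter)
print(counter.item())  # expected: 2
```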
