
Commit 1e96c33

Add extra punica sizes to support bigger vocabs (#4015)
1 parent 95e7d4a · commit 1e96c33

File tree

5 files changed: +109 -48 lines

  csrc/punica/bgmv/bgmv_config.h
  csrc/punica/punica_ops.cc
  tests/lora/test_layers.py
  tests/lora/test_punica.py
  vllm/lora/layers.py

csrc/punica/bgmv/bgmv_config.h

Lines changed: 11 additions & 1 deletion

@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 33024) \
     f(in_T, out_T, W_T, narrow, 36864) \
     f(in_T, out_T, W_T, narrow, 49152) \
-// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+    f(in_T, out_T, W_T, narrow, 64000) \
+    f(in_T, out_T, W_T, narrow, 64256) \
+    f(in_T, out_T, W_T, narrow, 64512) \
+    f(in_T, out_T, W_T, narrow, 102400) \
+    f(in_T, out_T, W_T, narrow, 102656) \
+    f(in_T, out_T, W_T, narrow, 102912) \
+    f(in_T, out_T, W_T, narrow, 128000) \
+    f(in_T, out_T, W_T, narrow, 128256) \
+    f(in_T, out_T, W_T, narrow, 128512) \
+// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
+// and vllm/tests/lora/test_punica.py
 
 // Keep this in sync with vllm/config::LoRAConfig
 #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
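Note on how these entries take effect: the size list is an X-macro, so every f(in_T, out_T, W_T, narrow, N) row is stamped out into a concrete bgmv_kernel instantiation (and a matching dispatch case in punica_ops.cc). The sketch below is a simplified, self-contained illustration of that pattern with hypothetical names — it is not the vLLM source — but it shows why a vocab size must appear in this list before the runtime can dispatch to it.

// Simplified sketch of the X-macro pattern (hypothetical names, not vLLM code).
#include <cstdio>

// Stand-in for the templated CUDA kernel launcher.
template <int feat_in, int feat_out>
void bgmv_kernel_stub(const float * /*X*/, float * /*Y*/) {
  std::printf("bgmv kernel for %d x %d\n", feat_in, feat_out);
}

// The list macro: f is applied once per supported wide (vocab) dimension,
// mirroring the f(in_T, out_T, W_T, narrow, N) rows above.
#define FOR_DEMO_WIDE(f, narrow) \
  f(narrow, 32000)               \
  f(narrow, 64000)               \
  f(narrow, 128256)

// One possible expansion: emit an explicit instantiation per list entry.
#define INSTANTIATE(narrow, wide) \
  template void bgmv_kernel_stub<narrow, wide>(const float *, float *);

FOR_DEMO_WIDE(INSTANTIATE, 16)
#undef INSTANTIATE

int main() {
  // Only sizes present in the list have been instantiated and are reachable.
  bgmv_kernel_stub<16, 128256>(nullptr, nullptr);
  return 0;
}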

csrc/punica/punica_ops.cc

Lines changed: 7 additions & 7 deletions

@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
   }
 }
 
-inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
-  return (uint32_t(a) << 16) | uint32_t(b);
+inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
+  return (uint64_t(a) << 32) | uint64_t(b);
 }
 
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 template <typename in_T, typename out_T, typename W_T>
 inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                                const int64_t *lora_indices,
-                               uint16_t in_features, uint16_t out_features,
+                               uint32_t in_features, uint32_t out_features,
                                int64_t y_offset, int64_t full_y_size,
                                int64_t batch_size, int64_t num_layers,
                                int64_t layer_idx, float scale) {
-  switch (pack_u16(in_features, out_features)) {
+  switch (pack_u32(in_features, out_features)) {
 #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
-  case pack_u16(feat_in, feat_out): \
+  case pack_u32(feat_in, feat_out): \
     bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
                                    full_y_size, batch_size, num_layers, \
                                    layer_idx, scale); \
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
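Why the pack_u16 → pack_u32 change is needed: the new widest dimension, 128512, no longer fits in a uint16_t (maximum 65535), so the switch key must be widened to two 32-bit halves packed into a uint64_t — and the h_in/h_out guard correspondingly moves from < 65536 to <= 128512, the largest size listed in bgmv_config.h. Below is a minimal standalone sketch of the packing scheme (illustrative only, not the vLLM build):

#include <cassert>
#include <cstdint>

// Old key: two 16-bit halves. Any dimension above 65535 would be truncated.
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

// New key: two 32-bit halves in a 64-bit value, so 128512 fits comfortably.
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  return (uint64_t(a) << 32) | uint64_t(b);
}

int main() {
  // 128512 exceeds the 16-bit range, so the old packing cannot represent it.
  static_assert(128512 > UINT16_MAX, "new vocab sizes overflow uint16_t");

  // The packed key is constexpr and unique per (in_features, out_features)
  // pair, so it can serve directly as a switch case label, as in punica_ops.cc.
  constexpr uint64_t key = pack_u32(16, 128512);
  switch (key) {
    case pack_u32(16, 128512):
      break;  // here the real code would launch bgmv_kernel<16, 128512>
    default:
      assert(false && "unsupported feature-size pair");
  }
  return 0;
}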

tests/lora/test_layers.py

Lines changed: 44 additions & 34 deletions

@@ -170,7 +170,8 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding.weight.data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
         lora_embedding.create_lora_weights(max_loras, lora_config)
 
@@ -203,12 +204,13 @@ def create_random_embedding_layer():
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info)
 
     lora_result = lora_embedding(torch.cat(inputs))
@@ -240,12 +242,13 @@ def create_random_embedding_layer():
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
     lora_result = lora_embedding(torch.cat(inputs))
@@ -263,7 +266,9 @@ def create_random_embedding_layer():
 # reason="Fails when loras are in any slot other than the first.")
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
+                                        vocab_size) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding_data = torch.rand_like(embedding.weight.data)
         embedding.weight.data = embedding_data
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         expanded_embedding = VocabParallelEmbedding(
-            512 + lora_config.lora_extra_vocab_size * max_loras,
+            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
             256,
-            org_num_embeddings=512)
-        expanded_embedding.weight.data[:512, :] = embedding_data
+            org_num_embeddings=vocab_size)
+        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
         # We need to deepcopy the embedding as it will be modified
         # in place
         lora_embedding = VocabParallelEmbeddingWithLoRA(
@@ -298,7 +303,7 @@ def create_random_embedding_layer():
         id_to_index,
         layer=lora_embedding,
         layer_weights=torch.zeros(
-            (256, 512 + lora_config.lora_extra_vocab_size)),
+            (256, vocab_size + lora_config.lora_extra_vocab_size)),
         generate_embeddings_tensor=256,
     )
 
@@ -316,7 +321,7 @@ def create_random_embedding_layer():
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -327,16 +332,18 @@ def create_random_embedding_layer():
     for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                 prompt_mapping):
         embedding_id = lora_id - 1
-        input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
-        original_input_[-1] = 512
-        input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
-        original_input_[-2] = 512 + embeddings_tensor_len - 1
+        input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
+        original_input_[-1] = vocab_size
+        input_[-2] = vocab_size + (
+            (embedding_id + 1) * embeddings_tensor_len - 1)
+        original_input_[-2] = vocab_size + embeddings_tensor_len - 1
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
-    expanded_embedding.weight[512:512 +
+    expanded_embedding.weight[vocab_size:vocab_size +
                               (embeddings_tensor_len *
                                max_loras)] = torch.cat(embeddings_tensors)
 
@@ -370,14 +377,15 @@ def create_random_embedding_layer():
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     original_inputs = deepcopy(inputs)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
     lora_result = lora_embedding(torch.cat(original_inputs))
@@ -393,7 +401,9 @@ def create_random_embedding_layer():
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_lm_head_logits_processor(dist_init, num_loras, device,
+                                  vocab_size) -> None:
 
     torch.set_default_device(device)
     max_loras = 8
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def _pretest():
-        linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
-                                1024, 32000)
+        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
+                                1024, vocab_size)
         linear.weight.data = torch.rand_like(linear.weight.data)
-        linear.weight.data[:, 32000:] = 0
+        linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
-            32000 + lora_config.lora_extra_vocab_size, 32000)
+            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
         lora_logits_processor = LogitsProcessorWithLoRA(
             logits_processor, 1024, linear.weight.dtype, linear.weight.device)
         lora_logits_processor.create_lora_weights(max_loras, lora_config)
@@ -444,7 +454,7 @@ def _pretest():
         lora_mapping,
         id_to_index,
         max_loras,
-        32000,
+        vocab_size,
         lora_config.lora_extra_vocab_size,
     )
     lora_logits_processor.set_mapping(*mapping_info, )
@@ -460,19 +470,19 @@ def _pretest():
                   org_vocab_size:logits_processor.org_vocab_size +
                   embeddings_tensor_len] = embeddings_tensor
 
-    logits_processor.org_vocab_size = (32000 +
+    logits_processor.org_vocab_size = (vocab_size +
                                        lora_config.lora_extra_vocab_size)
     expected_results = []
     for input_, lora_id in zip(inputs, prompt_mapping):
         lora = lora_dict[lora_id]
         result = logits_processor._get_logits(hidden_states=input_,
                                               embedding=linear.weight,
                                               embedding_bias=None)
-        result[:, 32000 + embeddings_tensor_len:] = float("-inf")
+        result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
         result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
         expected_results.append(result)
     expected_result = torch.cat(expected_results)
-    logits_processor.org_vocab_size = 32000
+    logits_processor.org_vocab_size = vocab_size
 
     # Check that resetting the lora weights succeeds
 
@@ -489,14 +499,14 @@ def _pretest():
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   32000,
+                                   vocab_size,
                                    lora_config.lora_extra_vocab_size)
     lora_logits_processor.set_mapping(*mapping_info, )
 
     lora_result = lora_logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
-        embedding_bias=None)[:, :32000]
+        embedding_bias=None)[:, :vocab_size]
     expected_result = logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,

tests/lora/test_punica.py

Lines changed: 45 additions & 4 deletions

@@ -43,10 +43,51 @@ def _lora_ref_impl(
 
 
 H1 = H2 = [
-    128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
-    3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
-    10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
-    32768, 33024
+    128,
+    256,
+    512,
+    1024,
+    1152,
+    1280,
+    1536,
+    2048,
+    2304,
+    2560,
+    2752,
+    3072,
+    3456,
+    3584,
+    4096,
+    4608,
+    5120,
+    5504,
+    5632,
+    6144,
+    6848,
+    6912,
+    7168,
+    8192,
+    9216,
+    10240,
+    11008,
+    13824,
+    14336,
+    22016,
+    24576,
+    27392,
+    32000,
+    32256,
+    32512,
+    32768,
+    33024,
+    36864,
+    49152,
+    64000,
+    64256,
+    102400,
+    102656,
+    128000,
+    128256,
 ]
 SEED = [0xabcdabcd987]
 CUDA_DEVICES = [

vllm/lora/layers.py

Lines changed: 2 additions & 2 deletions

@@ -935,9 +935,9 @@ def create_lora_weights(
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
-        if 32000 < self.base_layer.vocab_size > 33024:
+        if 32000 < self.base_layer.vocab_size > 128512:
             raise ValueError("When using LoRA, vocab size must be "
-                             "32000 >= vocab_size <= 33024")
+                             "32000 >= vocab_size <= 128512")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
