@@ -170,7 +170,8 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:

     torch.set_default_device(device)
     max_loras = 8
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)

     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding.weight.data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
         lora_embedding.create_lora_weights(max_loras, lora_config)

@@ -203,12 +204,13 @@ def create_random_embedding_layer():
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info)

     lora_result = lora_embedding(torch.cat(inputs))
@@ -240,12 +242,13 @@ def create_random_embedding_layer():
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )

     lora_result = lora_embedding(torch.cat(inputs))
@@ -263,7 +266,9 @@ def create_random_embedding_layer():
263266# reason="Fails when loras are in any slot other than the first.")
264267@pytest .mark .parametrize ("num_loras" , [1 , 2 , 4 , 8 ])
265268@pytest .mark .parametrize ("device" , CUDA_DEVICES )
266- def test_embeddings_with_new_embeddings (dist_init , num_loras , device ) -> None :
269+ @pytest .mark .parametrize ("vocab_size" , [512 , 32000 , 64000 , 128000 ])
270+ def test_embeddings_with_new_embeddings (dist_init , num_loras , device ,
271+ vocab_size ) -> None :
267272
268273 torch .set_default_device (device )
269274 max_loras = 8
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)

     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding_data = torch.rand_like(embedding.weight.data)
         embedding.weight.data = embedding_data
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         expanded_embedding = VocabParallelEmbedding(
-            512 + lora_config.lora_extra_vocab_size * max_loras,
+            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
             256,
-            org_num_embeddings=512)
-        expanded_embedding.weight.data[:512, :] = embedding_data
+            org_num_embeddings=vocab_size)
+        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
         # We need to deepcopy the embedding as it will be modified
         # in place
         lora_embedding = VocabParallelEmbeddingWithLoRA(
@@ -298,7 +303,7 @@ def create_random_embedding_layer():
         id_to_index,
         layer=lora_embedding,
         layer_weights=torch.zeros(
-            (256, 512 + lora_config.lora_extra_vocab_size)),
+            (256, vocab_size + lora_config.lora_extra_vocab_size)),
         generate_embeddings_tensor=256,
     )

@@ -316,7 +321,7 @@ def create_random_embedding_layer():
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

@@ -327,16 +332,18 @@ def create_random_embedding_layer():
     for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                 prompt_mapping):
         embedding_id = lora_id - 1
-        input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
-        original_input_[-1] = 512
-        input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
-        original_input_[-2] = 512 + embeddings_tensor_len - 1
+        input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
+        original_input_[-1] = vocab_size
+        input_[-2] = vocab_size + (
+            (embedding_id + 1) * embeddings_tensor_len - 1)
+        original_input_[-2] = vocab_size + embeddings_tensor_len - 1

     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )

-    expanded_embedding.weight[512:512 +
+    expanded_embedding.weight[vocab_size:vocab_size +
                               (embeddings_tensor_len *
                                max_loras)] = torch.cat(embeddings_tensors)

@@ -370,14 +377,15 @@ def create_random_embedding_layer():
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

     original_inputs = deepcopy(inputs)

     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )

     lora_result = lora_embedding(torch.cat(original_inputs))
@@ -393,7 +401,9 @@ def create_random_embedding_layer():
393401@torch .inference_mode ()
394402@pytest .mark .parametrize ("num_loras" , [1 , 2 , 4 , 8 ])
395403@pytest .mark .parametrize ("device" , CUDA_DEVICES )
396- def test_lm_head_logits_processor (dist_init , num_loras , device ) -> None :
404+ @pytest .mark .parametrize ("vocab_size" , [512 , 32000 , 64000 , 128000 ])
405+ def test_lm_head_logits_processor (dist_init , num_loras , device ,
406+ vocab_size ) -> None :
397407
398408 torch .set_default_device (device )
399409 max_loras = 8
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)

     def _pretest():
-        linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
-                                1024, 32000)
+        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
+                                1024, vocab_size)
         linear.weight.data = torch.rand_like(linear.weight.data)
-        linear.weight.data[:, 32000:] = 0
+        linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
-            32000 + lora_config.lora_extra_vocab_size, 32000)
+            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
         lora_logits_processor = LogitsProcessorWithLoRA(
             logits_processor, 1024, linear.weight.dtype, linear.weight.device)
         lora_logits_processor.create_lora_weights(max_loras, lora_config)
@@ -444,7 +454,7 @@ def _pretest():
         lora_mapping,
         id_to_index,
         max_loras,
-        32000,
+        vocab_size,
         lora_config.lora_extra_vocab_size,
     )
     lora_logits_processor.set_mapping(*mapping_info, )
@@ -460,19 +470,19 @@ def _pretest():
                       org_vocab_size:logits_processor.org_vocab_size +
                       embeddings_tensor_len] = embeddings_tensor

-        logits_processor.org_vocab_size = (32000 +
+        logits_processor.org_vocab_size = (vocab_size +
                                            lora_config.lora_extra_vocab_size)
     expected_results = []
     for input_, lora_id in zip(inputs, prompt_mapping):
         lora = lora_dict[lora_id]
         result = logits_processor._get_logits(hidden_states=input_,
                                               embedding=linear.weight,
                                               embedding_bias=None)
-        result[:, 32000 + embeddings_tensor_len:] = float("-inf")
+        result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
         result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
         expected_results.append(result)
     expected_result = torch.cat(expected_results)
-    logits_processor.org_vocab_size = 32000
+    logits_processor.org_vocab_size = vocab_size

     # Check that resetting the lora weights succeeds

@@ -489,14 +499,14 @@ def _pretest():
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   32000,
+                                   vocab_size,
                                    lora_config.lora_extra_vocab_size)
     lora_logits_processor.set_mapping(*mapping_info, )

     lora_result = lora_logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
-        embedding_bias=None)[:, :32000]
+        embedding_bias=None)[:, :vocab_size]
     expected_result = logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
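
Note on the parametrization added above: stacking another pytest.mark.parametrize decorator makes pytest generate the Cartesian product of all parametrized arguments, so every existing num_loras/device combination is now exercised at each of the four vocabulary sizes. The standalone sketch below is illustrative only and not part of the patch; the test name and assertions are made up to show how the expansion works.

import pytest


@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_parametrize_product_sketch(num_loras, vocab_size):
    # pytest generates 4 x 4 = 16 test cases from these two decorators,
    # one for each (num_loras, vocab_size) pair.
    assert num_loras in (1, 2, 4, 8)
    assert vocab_size in (512, 32000, 64000, 128000)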