File tree Expand file tree Collapse file tree 3 files changed +10
-7
lines changed
Expand file tree Collapse file tree 3 files changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -124,8 +124,9 @@ def _construct_expected_sampling_metadata(
124124 if req .sampling_params .allowed_token_ids :
125125 allowed_token_ids_mask [index_in_input_batch ][
126126 req .sampling_params .allowed_token_ids ] = True
127- bad_words_token_ids [
128- index_in_input_batch ] = req .sampling_params .bad_words_token_ids
127+ if req .sampling_params .bad_words_token_ids :
128+ bad_words_token_ids [
129+ index_in_input_batch ] = req .sampling_params .bad_words_token_ids
129130
130131 return SamplingMetadata (
131132 temperature = torch .tensor (temperature , dtype = torch .float ,
Original file line number Diff line number Diff line change @@ -235,7 +235,7 @@ class SamplingParams(
235235
236236 # Fields used for bad words
237237 bad_words : Optional [list [str ]] = None
238- _bad_words_token_ids : list [list [int ]] = msgspec . field ( default_factory = list )
238+ _bad_words_token_ids : Optional [ list [list [int ]]] = None
239239
240240 @staticmethod
241241 def from_optional (
@@ -464,8 +464,9 @@ def update_from_generation_config(
464464 self .stop_token_ids = list (eos_ids )
465465
466466 def update_from_tokenizer (self , tokenizer : AnyTokenizer ) -> None :
467- if self .bad_words is None :
467+ if not self .bad_words :
468468 return
469+ self ._bad_words_token_ids = []
469470 for bad_word in self .bad_words :
470471 # To prohibit words both at the beginning
471472 # and in the middle of text
@@ -516,7 +517,7 @@ def all_stop_token_ids(self) -> set[int]:
516517 return self ._all_stop_token_ids
517518
518519 @property
519- def bad_words_token_ids (self ) -> list [list [int ]]:
520+ def bad_words_token_ids (self ) -> Optional [ list [list [int ] ]]:
520521 # For internal use only. Backward compatibility not guaranteed
521522 return self ._bad_words_token_ids
522523
Original file line number Diff line number Diff line change @@ -324,8 +324,9 @@ def add_request(
324324 self .allowed_token_ids_mask_cpu_tensor [req_index ][
325325 sampling_params .allowed_token_ids ] = False
326326
327- self .bad_words_token_ids [
328- req_index ] = sampling_params .bad_words_token_ids
327+ if sampling_params .bad_words_token_ids :
328+ self .bad_words_token_ids [
329+ req_index ] = sampling_params .bad_words_token_ids
329330
330331 # Add request lora ID
331332 if request .lora_request :
You can’t perform that action at this time.
0 commit comments