@@ -255,7 +255,7 @@ def __init__(
255255 override_neuron_config : Optional [dict [str , Any ]] = None ,
256256 override_pooler_config : Optional ["PoolerConfig" ] = None ,
257257 logits_processor_pattern : Optional [str ] = None ,
258- generation_config : Optional [ str ] = None ,
258+ generation_config : str = "auto" ,
259259 enable_sleep_mode : bool = False ,
260260 override_generation_config : Optional [dict [str , Any ]] = None ,
261261 model_impl : Union [str , ModelImpl ] = ModelImpl .AUTO ,
@@ -951,7 +951,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
951951 return self .multimodal_config
952952
953953 def try_get_generation_config (self ) -> dict [str , Any ]:
954- if self .generation_config is None or self . generation_config == "auto" :
954+ if self .generation_config in ( "auto" , "vllm" ) :
955955 config = try_get_generation_config (
956956 self .hf_config_path or self .model ,
957957 trust_remote_code = self .trust_remote_code ,
@@ -971,17 +971,14 @@ def try_get_generation_config(self) -> dict[str, Any]:
971971 def get_diff_sampling_param (self ) -> dict [str , Any ]:
972972 """
973973 This method returns a dictionary containing the parameters
974- that differ from the default sampling parameters, but only
975- if `generation_config` is set. If `generation_config` is not
976- set, an empty dictionary is returned.
974+ that differ from the default sampling parameters. If
975+ `generation_config` is `"vllm"`, an empty dictionary is returned.
977976
978977 Returns:
979978 dict[str, Any]: A dictionary with the differing sampling
980- parameters if `generation_config` is set, otherwise an
981- empty dictionary.
979+ parameters, if `generation_config` is `"vllm"` an empty dictionary.
982980 """
983- if self .generation_config is None :
984- # When generation_config is not set
981+ if self .generation_config == "vllm" :
985982 config = {}
986983 else :
987984 config = self .try_get_generation_config ()
0 commit comments