File tree (Expand / Collapse) — 1 file changed: +18 −2 lines
tensorrt_llm/bench/dataclasses — 1 file changed: +18 −2 lines

Original file line number | Diff line number | Diff line change
@@ -84,8 +84,24 @@ def get_llm_args(self) -> Dict:
8484 backend_cache_config = llm_args .pop ("kv_cache_config" , {})
8585 llm_args ["kv_cache_config" ] = backend_cache_config | kv_cache_config
8686
87- return update_llm_args_with_extra_options (llm_args ,
88- self .extra_llm_api_options )
87+ updated_llm_args = update_llm_args_with_extra_options (
88+ llm_args , self .extra_llm_api_options )
89+
90+ if self .backend == "pytorch" :
91+ cuda_graph_config = updated_llm_args .pop (
92+ "cuda_graph_config" , llm_args ["cuda_graph_config" ])
93+ # Use runtime max_batch_size as cuda_graph_config.max_batch_size
94+ # if both max_batch_size and batch_sizes are not set.
95+ batch_sizes_set = cuda_graph_config .get ("batch_sizes" ,
96+ None ) is not None
97+ max_batch_size_set = cuda_graph_config .get ("max_batch_size" ,
98+ None ) is not None
99+ if not batch_sizes_set and not max_batch_size_set :
100+ cuda_graph_config [
101+ "max_batch_size" ] = self .settings_config .max_batch_size
102+ updated_llm_args ["cuda_graph_config" ] = cuda_graph_config
103+
104+ return updated_llm_args
89105
90106 @model_validator (mode = "after" )
91107 def validate_full_config (self ) -> RuntimeConfig :
You can’t perform that action at this time.
0 commit comments