Commit 85b4ae2

[https://nvbugs/5451342][fix] Use runtime max_batch_size when cuda_graph_config.max_batch_size is not provided in trtllm-bench (#7031)
Signed-off-by: Jiagan Cheng <[email protected]>
1 parent 7409d56

File tree

1 file changed (+18 -2)

tensorrt_llm/bench/dataclasses/configuration.py

Lines changed: 18 additions & 2 deletions
@@ -84,8 +84,24 @@ def get_llm_args(self) -> Dict:
         backend_cache_config = llm_args.pop("kv_cache_config", {})
         llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config

-        return update_llm_args_with_extra_options(llm_args,
-                                                  self.extra_llm_api_options)
+        updated_llm_args = update_llm_args_with_extra_options(
+            llm_args, self.extra_llm_api_options)
+
+        if self.backend == "pytorch":
+            cuda_graph_config = updated_llm_args.pop(
+                "cuda_graph_config", llm_args["cuda_graph_config"])
+            # Use the runtime max_batch_size as cuda_graph_config.max_batch_size
+            # if neither max_batch_size nor batch_sizes is set.
+            batch_sizes_set = cuda_graph_config.get("batch_sizes",
+                                                    None) is not None
+            max_batch_size_set = cuda_graph_config.get("max_batch_size",
+                                                       None) is not None
+            if not batch_sizes_set and not max_batch_size_set:
+                cuda_graph_config[
+                    "max_batch_size"] = self.settings_config.max_batch_size
+            updated_llm_args["cuda_graph_config"] = cuda_graph_config
+
+        return updated_llm_args

     @model_validator(mode="after")
     def validate_full_config(self) -> RuntimeConfig:
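
To make the change concrete, here is a minimal, self-contained sketch of the fallback rule this diff introduces. The function name resolve_cuda_graph_config and the literal values are hypothetical stand-ins for illustration, not part of TensorRT-LLM; only the precedence logic mirrors the committed code: an explicit batch_sizes or max_batch_size in the user's cuda_graph_config always wins, and the benchmark's runtime max_batch_size is used only when both are absent.

from typing import Any, Dict

def resolve_cuda_graph_config(cuda_graph_config: Dict[str, Any],
                              runtime_max_batch_size: int) -> Dict[str, Any]:
    # Fall back to the runtime max_batch_size only when the user set
    # neither batch_sizes nor max_batch_size explicitly.
    batch_sizes_set = cuda_graph_config.get("batch_sizes") is not None
    max_batch_size_set = cuda_graph_config.get("max_batch_size") is not None
    if not batch_sizes_set and not max_batch_size_set:
        cuda_graph_config["max_batch_size"] = runtime_max_batch_size
    return cuda_graph_config

# Neither field provided: the benchmark's runtime value is applied.
print(resolve_cuda_graph_config({}, 256))            # {'max_batch_size': 256}

# An explicit user max_batch_size takes precedence over the runtime value.
print(resolve_cuda_graph_config({"max_batch_size": 64}, 256))
                                                     # {'max_batch_size': 64}

# Explicit batch_sizes likewise suppress the fallback.
print(resolve_cuda_graph_config({"batch_sizes": [1, 2, 4]}, 256))
                                                     # {'batch_sizes': [1, 2, 4]}

Note that the fallback only takes effect on the pytorch backend branch; for other backends, the cuda_graph_config merged by update_llm_args_with_extra_options is passed through unchanged.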
