diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f78582495814..eebdbcc621c6 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2828,7 +2828,7 @@ def _get_mm_dummy_batch(
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
@@ -2844,6 +2844,8 @@ def _dummy_run(
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+            - if not set, the cudagraph mode will be determined by
+                self.cudagraph_dispatcher.
             - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
             - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
             - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ def _dummy_run(
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }

@@ -2899,10 +2901,6 @@ def _dummy_run(
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
@@ -3043,18 +3041,20 @@ def _dummy_run(
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)

-        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert cudagraph_runtime_mode == _cg_mode, (
+
+        # filter out the valid batch descriptor
+        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+            BatchDescriptor(num_tokens=num_tokens,
+                            uniform_decode=uniform_decode))
+        if cudagraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees, to
+            # support warm-up runs before cudagraph capture
+            assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                cudagraph_runtime_mode == _cg_mode, (
                 f"Cudagraph runtime mode mismatch at dummy_run. "
                 f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+        else:
+            cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
             num_tokens = num_tokens // 2
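
The behavioral core of this change is the dispatch-or-override flow at the bottom of `_dummy_run`: the dispatcher is now always consulted, a caller-supplied mode may only override it with `NONE` (the warm-up case), and an unset mode inherits the dispatcher's decision. The sketch below isolates that flow outside vLLM under stated assumptions: `ToyDispatcher`, `resolve_mode`, and `captured_sizes` are made-up stand-ins (the real dispatcher keys on more than batch size); only `CUDAGraphMode` and `BatchDescriptor` mirror names from the diff.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool = False


class ToyDispatcher:
    """Hypothetical stand-in for vLLM's cudagraph dispatcher: maps a batch
    shape to the cudagraph mode it has a captured graph for."""

    def __init__(self, captured_sizes: set[int]):
        self.captured_sizes = captured_sizes

    def dispatch(
        self, desc: BatchDescriptor
    ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
        # A captured size can replay a full cudagraph; anything else
        # falls back to eager execution (NONE) with no batch descriptor.
        if desc.num_tokens in self.captured_sizes:
            return CUDAGraphMode.FULL, desc
        return CUDAGraphMode.NONE, None


def resolve_mode(
    dispatcher: ToyDispatcher,
    num_tokens: int,
    uniform_decode: bool = False,
    forced_mode: Optional[CUDAGraphMode] = None,
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
    # Always consult the dispatcher first, mirroring the new _dummy_run flow.
    cg_mode, batch_descriptor = dispatcher.dispatch(
        BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode))
    if forced_mode is not None:
        # Forcing NONE is always legal (warm-up before capture); any other
        # forced mode must agree with the dispatcher's decision.
        assert forced_mode == CUDAGraphMode.NONE or forced_mode == cg_mode, (
            f"Cudagraph runtime mode mismatch: expected {cg_mode}, "
            f"got {forced_mode}.")
        return forced_mode, batch_descriptor
    return cg_mode, batch_descriptor


if __name__ == "__main__":
    dispatcher = ToyDispatcher(captured_sizes={8, 16})
    # Caller leaves the mode unset: the dispatcher decides (FULL here).
    print(resolve_mode(dispatcher, num_tokens=16))
    # Warm-up: force NONE even though the dispatcher would choose FULL.
    print(resolve_mode(dispatcher, num_tokens=16,
                       forced_mode=CUDAGraphMode.NONE))
```

With this flow, capture-time warm-ups can call the same entry point while forcing `CUDAGraphMode.NONE`, whereas the old code skipped dispatch entirely whenever `NONE` was passed and left `batch_descriptor` as `None`.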