@@ -712,14 +712,62 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
 
     def _use_captured_graph(self,
                             batch_size: int,
+                            decode_only: bool,
                             max_decode_seq_len: int,
                             max_encoder_seq_len: int = 0) -> bool:
-        return (self.decode_only and not self.runner.model_config.enforce_eager
+        return (decode_only and not self.runner.model_config.enforce_eager
                 and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)
 
+    def _get_cuda_graph_pad_size(self,
+                                 num_seqs: int,
+                                 max_decode_seq_len: int,
+                                 max_encoder_seq_len: int = 0) -> int:
+        """
+        Determine the number of padding sequences required for running in
+        CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
+
+        In the multi-step + chunked-prefill case, only the first step
+        has Prefills (if any). The rest of the steps are guaranteed to be all
+        decodes. In this case, we set up the padding as if all the sequences
+        are decodes so we may run all steps except the first step in CUDA graph
+        mode. The padding is accounted for in the multi-step `advance_step`
+        family of functions.
+
+        Args:
+            num_seqs (int): Number of sequences scheduled to run.
+            max_decode_seq_len (int): Greatest of all the decode sequence
+                lengths. Used only in checking the viability of using
+                CUDA graphs.
+            max_encoder_seq_len (int, optional): Greatest of all the encode
+                sequence lengths. Defaults to 0. Used only in checking the
+                viability of using CUDA graphs.
+        Returns:
+            int: Returns the determined number of padding sequences. If
+                CUDA graphs is not viable, returns -1.
+        """
+        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
+                        self.runner.scheduler_config.chunked_prefill_enabled
+        decode_only = self.decode_only or is_mscp
+        if not decode_only:
+            # Early exit so we can treat num_seqs as the batch_size below.
+            return -1
+
+        # batch_size out of this function refers to the number of input
+        # tokens being scheduled. This conflation of num_seqs as batch_size
+        # is valid as this is a decode-only case.
+        batch_size = num_seqs
+        if not self._use_captured_graph(batch_size, decode_only,
+                                        max_decode_seq_len,
+                                        max_encoder_seq_len):
+            return -1
+
+        graph_batch_size = _get_graph_batch_size(batch_size)
+        assert graph_batch_size >= batch_size
+        return graph_batch_size - batch_size
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
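For reference, the padding this hunk computes can be reproduced with a small standalone sketch. The capture sizes assumed below (1, 2, 4, then multiples of 8) and the helpers graph_batch_size / cuda_graph_pad_size are illustrative stand-ins for _BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size, and _get_cuda_graph_pad_size; they approximate the rounding contract rather than the real definitions:

# Standalone sketch of the padding math (illustrative only; assumes graphs
# are captured for batch sizes 1, 2, 4 and then multiples of 8).
_CAPTURE_SIZES = [1, 2, 4] + [8 * i for i in range(1, 9)]

def graph_batch_size(batch_size: int) -> int:
    # Smallest captured batch size that can hold `batch_size` sequences.
    return next(s for s in _CAPTURE_SIZES if s >= batch_size)

def cuda_graph_pad_size(num_seqs: int, decode_only: bool) -> int:
    # -1 signals "run eagerly"; otherwise return the number of dummy
    # sequences needed to reach the captured batch size.
    if not decode_only or num_seqs > _CAPTURE_SIZES[-1]:
        return -1
    return graph_batch_size(num_seqs) - num_seqs

assert cuda_graph_pad_size(6, decode_only=True) == 2    # pad 6 -> 8
assert cuda_graph_pad_size(16, decode_only=True) == 0   # already a captured size
assert cuda_graph_pad_size(6, decode_only=False) == -1  # prefill present: eager mode

The -1 sentinel mirrors the early exits above: a batch that still contains prefills (outside the multi-step + chunked-prefill case) or exceeds the largest captured size falls back to eager execution.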
@@ -778,21 +826,17 @@ def build(self) -> ModelInputForGPU:
             for data in self.inter_data_list
         }
 
-        batch_size = len(input_tokens)
-        use_captured_graph = self._use_captured_graph(
-            batch_size,
-            max_decode_seq_len,
+        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
+            num_seqs=len(seq_lens),
+            max_decode_seq_len=max_decode_seq_len,
             max_encoder_seq_len=max_encoder_seq_len)
 
-        # If cuda graph can be used, pad tensors accordingly.
-        # See `capture_model` API for more details.
-        # vLLM uses cuda graph only for decoding requests.
-        cuda_graph_pad_size = -1
-        if use_captured_graph:
-            graph_batch_size = _get_graph_batch_size(batch_size)
-            assert graph_batch_size >= batch_size
-            cuda_graph_pad_size = graph_batch_size - batch_size
-            batch_size = graph_batch_size
+        batch_size = len(input_tokens)
+        if cuda_graph_pad_size != -1:
+            # If cuda graph can be used, pad tensors accordingly.
+            # See `capture_model` API for more details.
+            # vLLM uses cuda graph only for decoding requests.
+            batch_size += cuda_graph_pad_size
 
         # Tokens and positions.
         if cuda_graph_pad_size:
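Downstream of this hunk (the "# Tokens and positions." block that follows), the pad size grows the flat input buffers so their length matches the captured graph's batch size. A minimal sketch of that contract, using a hypothetical helper (not a vLLM function) and assuming 0 as the dummy token id:

from typing import List

def pad_for_cuda_graph(input_tokens: List[int],
                       cuda_graph_pad_size: int) -> List[int]:
    # Hypothetical helper: -1 means eager mode, so the buffer keeps its
    # natural length; otherwise append dummy tokens so the tensor matches
    # the captured graph's batch size.
    if cuda_graph_pad_size == -1:
        return list(input_tokens)
    return list(input_tokens) + [0] * cuda_graph_pad_size

assert pad_for_cuda_graph([11, 22, 33], -1) == [11, 22, 33]
assert pad_for_cuda_graph([11, 22, 33, 44, 55, 66], 2) == [11, 22, 33, 44, 55, 66, 0, 0]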