@@ -2989,13 +2989,19 @@ def _dummy_run(
29892989 # We currently only microbatch if the number of tokens is
29902990 # over a certain threshold.
29912991 if self .parallel_config .enable_dbo and allow_microbatching :
2992- ubatch_slices , num_tokens_after_padding = ubatch_split (
2992+ ubatch_slices , ubatch_num_tokens_after_padding = ubatch_split (
29932993 num_scheduled_tokens ,
29942994 total_num_scheduled_tokens ,
29952995 total_num_scheduled_tokens ,
29962996 uniform_decode = uniform_decode ,
29972997 vllm_config = self .vllm_config ,
29982998 )
2999+ # Currently when DBO is enabled, `ubatch_split` returns
3000+ # the num_tokens_after_padding for a single ubatch, but we have 2 ubatches, so the total is double that.
3001+ # TODO(sage,lucas): this is cruft that should be addressed in the
3002+ # padding refactor.
3003+ if ubatch_num_tokens_after_padding is not None :
3004+ num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
29993005
30003006 # If we failed to microbatch, currently need to resynchronize
30013007 # TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3112,8 +3118,9 @@ def _dummy_run(
31123118
31133119 # filter out the valid batch descriptor
31143120 _cg_mode , batch_descriptor = self .cudagraph_dispatcher .dispatch (
3115- BatchDescriptor (num_tokens = num_tokens ,
3116- uniform_decode = uniform_decode ))
3121+ BatchDescriptor (num_tokens = num_tokens_after_padding ,
3122+ uniform_decode = uniform_decode )) \
3123+ if not is_profile else (CUDAGraphMode .NONE , None )
31173124 if cudagraph_runtime_mode is not None :
31183125 # we allow forcing NONE when the dispatcher disagrees to support
31193126 # warm ups for cudagraph capture
@@ -3125,7 +3132,13 @@ def _dummy_run(
31253132 cudagraph_runtime_mode = _cg_mode
31263133
31273134 if ubatch_slices is not None :
3128- num_tokens = num_tokens // 2
3135+ # Adjust values to reflect a single ubatch.
3136+ # TODO(sage,lucas): this is cruft that should be addressed in
3137+ # the padding refactor.
3138+ num_tokens_after_padding = ubatch_slices [0 ].num_tokens
3139+ if num_tokens_across_dp is not None :
3140+ num_tokens_across_dp [:] = num_tokens_after_padding
3141+
31293142 with self .maybe_randomize_inputs (input_ids ), set_forward_context (
31303143 attn_metadata ,
31313144 self .vllm_config ,
0 commit comments