@@ -373,7 +373,9 @@ def step_with_batch_queue(
373373 )
374374 assert isinstance (exec_future , Future )
375375
376- if scheduler_output .needs_structured_output_tokens :
376+ if scheduler_output .pending_structured_output_tokens :
377+ # We need to defer sampling until we have processed the model output
378+ # from the prior step.
377379 deferred_scheduler_output = scheduler_output
378380 grammar_output = None
379381 else :
@@ -383,20 +385,18 @@ def step_with_batch_queue(
383385 # Block-wait for execute to return (continues running async on the GPU).
384386 model_executed = scheduler_output .total_num_scheduled_tokens > 0
385387 with self .log_error_detail (scheduler_output ):
386- model_output = exec_future .result ()
388+ model_output_or_none = exec_future .result ()
387389
388- if deferred_scheduler_output :
389- assert model_output is None
390- else :
391- if model_output is not None :
392- # No sampling required (e.g. all requests finished).
393- future = cast (Future [ModelRunnerOutput ], exec_future )
394- else :
395- # No pending output tokens needed, sample immediately.
390+ if not deferred_scheduler_output :
391+ if model_output_or_none is None :
392+ # No pending output tokens needed here, sample immediately.
396393 sample_future = self .model_executor .sample_tokens (
397394 grammar_output , non_block = True
398395 )
399396 future = cast (Future [ModelRunnerOutput ], sample_future )
397+ else :
398+ # No sampling required (e.g. all requests finished).
399+ future = cast (Future [ModelRunnerOutput ], exec_future )
400400 batch_queue .appendleft ((future , scheduler_output ))
401401 if (
402402 model_executed
@@ -406,6 +406,8 @@ def step_with_batch_queue(
406406 # Don't block on next worker response unless the queue is full
407407 # or there are no more requests to schedule.
408408 return None , True
409+ else :
410+ assert model_output_or_none is None
409411
410412 elif not batch_queue :
411413 # Queue is empty. We should not reach here since this method should
@@ -417,13 +419,12 @@ def step_with_batch_queue(
417419 future , scheduler_output = batch_queue .pop ()
418420 with self .log_error_detail (scheduler_output ):
419421 model_output = future .result ()
420- assert model_output is not None
422+
421423 engine_core_outputs = self .scheduler .update_from_output (
422424 scheduler_output , model_output
423425 )
424426
425427 # TODO TBD return outputs here first?
426-
427428 if deferred_scheduler_output :
428429 # We now have the tokens needed to compute the bitmask for the
429430 # deferred request. Get the bitmask and dispatch sample request.
0 commit comments