diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 97d8c91b4659..0190505b8a8f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -338,11 +338,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
+            generator = torch.Generator(device=self.device)
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
-                generator = None
+                generator.initial_seed()

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
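
For reference, a minimal standalone sketch (not part of the patch) of the torch.Generator calls this hunk touches; it uses a CPU generator so it runs without a GPU, whereas the patched code constructs the generator on self.device.

    import torch

    gen = torch.Generator()           # analogous to torch.Generator(device=self.device)
    gen.manual_seed(42)               # manual_seed() reseeds the generator deterministically
    a = torch.randn(3, generator=gen)

    gen.manual_seed(42)               # same seed -> same random sequence
    b = torch.randn(3, generator=gen)
    assert torch.equal(a, b)

    # initial_seed() only reads back the seed the generator currently holds;
    # it does not reseed, so the generator's random state is left unchanged.
    print(gen.initial_seed())         # 42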