|
27 | 27 |
|
28 | 28 | JOIN_TIMEOUT_S = 2 |
29 | 29 |
|
30 | | -mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD |
31 | | -mp = multiprocessing.get_context(mp_method) |
32 | | - |
33 | 30 |
|
34 | 31 | @dataclass |
35 | 32 | class Result(Generic[T]): |
@@ -77,7 +74,7 @@ class ResultHandler(threading.Thread): |
77 | 74 |
|
78 | 75 | def __init__(self) -> None: |
79 | 76 | super().__init__(daemon=True) |
80 | | - self.result_queue = mp.Queue() |
| 77 | + self.result_queue = get_mp_context().Queue() |
81 | 78 | self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} |
82 | 79 |
|
83 | 80 | def run(self): |
@@ -147,10 +144,11 @@ class ProcessWorkerWrapper: |
147 | 144 |
|
148 | 145 | def __init__(self, result_handler: ResultHandler, |
149 | 146 | worker_factory: Callable[[], Any]) -> None: |
150 | | - self._task_queue = mp.Queue() |
| 147 | + self.mp = get_mp_context() |
| 148 | + self._task_queue = self.mp.Queue() |
151 | 149 | self.result_queue = result_handler.result_queue |
152 | 150 | self.tasks = result_handler.tasks |
153 | | - self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] |
| 151 | + self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined] |
154 | 152 | target=_run_worker_process, |
155 | 153 | name="VllmWorkerProcess", |
156 | 154 | kwargs=dict( |
@@ -204,7 +202,7 @@ def _run_worker_process( |
204 | 202 | """Worker process event loop""" |
205 | 203 |
|
206 | 204 | # Add process-specific prefix to stdout and stderr |
207 | | - process_name = mp.current_process().name |
| 205 | + process_name = get_mp_context().current_process().name |
208 | 206 | pid = os.getpid() |
209 | 207 | _add_prefix(sys.stdout, process_name, pid) |
210 | 208 | _add_prefix(sys.stderr, process_name, pid) |
@@ -269,3 +267,8 @@ def write_with_prefix(s: str): |
269 | 267 |
|
270 | 268 | file.start_new_line = True # type: ignore[attr-defined] |
271 | 269 | file.write = write_with_prefix # type: ignore[method-assign] |
| 270 | + |
| 271 | + |
| 272 | +def get_mp_context(): |
| 273 | + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD |
| 274 | + return multiprocessing.get_context(mp_method) |
0 commit comments