|
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | | -import contextlib |
9 | | -import io |
10 | 8 | import json |
11 | 9 | import os |
12 | | -import signal |
13 | 10 | import sys |
14 | 11 | import threading |
15 | 12 | import types |
16 | | -from collections.abc import Callable, Mapping, Sequence |
| 13 | +from collections.abc import Mapping, Sequence |
17 | 14 | from logging import Logger, getLogger |
18 | 15 | from pathlib import Path |
19 | | -from select import PIPE_BUF, select |
20 | | -from subprocess import Popen, TimeoutExpired |
21 | 16 | from tempfile import NamedTemporaryFile |
22 | 17 | from time import sleep |
23 | | -from types import FrameType |
24 | | -from typing import TYPE_CHECKING, Any, Generic, Self, TypedDict, TypeVar |
| 18 | +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar |
| 19 | + |
| 20 | +from ert.services.ert_server import ErtServerConnectionInfo, _Proc |
25 | 21 |
|
26 | 22 | if TYPE_CHECKING: |
27 | 23 | pass |
28 | 24 |
|
29 | 25 | T = TypeVar("T", bound="BaseService") |
30 | 26 |
|
31 | 27 |
|
32 | | -class ErtServerConnectionInfo(TypedDict): |
33 | | - urls: list[str] |
34 | | - authtoken: str |
35 | | - host: str |
36 | | - port: str |
37 | | - cert: str |
38 | | - auth: str |
39 | | - |
40 | | - |
41 | | -SERVICE_CONF_PATHS: set[str] = set() |
42 | | - |
43 | | - |
44 | | -class BaseServiceExit(OSError): |
45 | | - pass |
46 | | - |
47 | | - |
48 | | -def cleanup_service_files(signum: int, frame: FrameType | None) -> None: |
49 | | - for file_path in SERVICE_CONF_PATHS: |
50 | | - file = Path(file_path) |
51 | | - if file.exists(): |
52 | | - file.unlink() |
53 | | - raise BaseServiceExit(f"Signal {signum} received.") |
54 | | - |
55 | | - |
56 | | -if threading.current_thread() is threading.main_thread(): |
57 | | - signal.signal(signal.SIGTERM, cleanup_service_files) |
58 | | - signal.signal(signal.SIGINT, cleanup_service_files) |
59 | | - |
60 | | - |
61 | 28 | def local_exec_args(script_args: str | list[str]) -> list[str]: |
62 | 29 | """ |
63 | 30 | Convenience function that returns the exec_args for executing a Python |
@@ -96,134 +63,6 @@ def __exit__( |
96 | 63 | return exc_type is None |
97 | 64 |
|
98 | 65 |
|
99 | | -class _Proc(threading.Thread): |
100 | | - def __init__( |
101 | | - self, |
102 | | - service_name: str, |
103 | | - exec_args: Sequence[str], |
104 | | - timeout: int, |
105 | | - on_connection_info_received: Callable[ |
106 | | - [ErtServerConnectionInfo | Exception | None], None |
107 | | - ], |
108 | | - project: Path, |
109 | | - ) -> None: |
110 | | - super().__init__() |
111 | | - |
112 | | - self._shutdown = threading.Event() |
113 | | - |
114 | | - self._service_name = service_name |
115 | | - self._exec_args = exec_args |
116 | | - self._timeout = timeout |
117 | | - self._propagate_connection_info_from_childproc = on_connection_info_received |
118 | | - self._service_config_path = project / f"{self._service_name}_server.json" |
119 | | - |
120 | | - fd_read, fd_write = os.pipe() |
121 | | - self._comm_pipe = os.fdopen(fd_read) |
122 | | - |
123 | | - env = os.environ.copy() |
124 | | - env["ERT_COMM_FD"] = str(fd_write) |
125 | | - |
126 | | - SERVICE_CONF_PATHS.add(str(self._service_config_path)) |
127 | | - |
128 | | - # The process is waited for in _do_shutdown() |
129 | | - self._childproc = Popen( |
130 | | - self._exec_args, |
131 | | - pass_fds=(fd_write,), |
132 | | - env=env, |
133 | | - close_fds=True, |
134 | | - ) |
135 | | - os.close(fd_write) |
136 | | - |
137 | | - def run(self) -> None: |
138 | | - comm = self._read_connection_info_from_process(self._childproc) |
139 | | - |
140 | | - if comm is None: |
141 | | - self._propagate_connection_info_from_childproc(TimeoutError()) |
142 | | - return # _read_conn_info() has already cleaned up in this case |
143 | | - |
144 | | - conn_info: ErtServerConnectionInfo | Exception | None = None |
145 | | - try: |
146 | | - conn_info = json.loads(comm) |
147 | | - except json.JSONDecodeError: |
148 | | - conn_info = ServerBootFail() |
149 | | - except Exception as exc: |
150 | | - conn_info = exc |
151 | | - |
152 | | - try: |
153 | | - self._propagate_connection_info_from_childproc(conn_info) |
154 | | - |
155 | | - while True: |
156 | | - if self._childproc.poll() is not None: |
157 | | - break |
158 | | - if self._shutdown.wait(1): |
159 | | - self._do_shutdown() |
160 | | - break |
161 | | - |
162 | | - except Exception as e: |
163 | | - print(str(e)) |
164 | | - self.logger.exception(e) |
165 | | - |
166 | | - finally: |
167 | | - self._ensure_connection_info_file_is_deleted() |
168 | | - |
169 | | - def shutdown(self) -> int: |
170 | | - """Shutdown the server.""" |
171 | | - self._shutdown.set() |
172 | | - self.join() |
173 | | - |
174 | | - return self._childproc.returncode |
175 | | - |
176 | | - def _read_connection_info_from_process(self, proc: Popen[bytes]) -> str | None: |
177 | | - comm_buf = io.StringIO() |
178 | | - first_iter = True |
179 | | - while first_iter or proc.poll() is None: |
180 | | - first_iter = False |
181 | | - ready = select([self._comm_pipe], [], [], self._timeout) |
182 | | - |
183 | | - # Timeout reached, exit with a failure |
184 | | - if ready == ([], [], []): |
185 | | - self._do_shutdown() |
186 | | - self._ensure_connection_info_file_is_deleted() |
187 | | - return None |
188 | | - |
189 | | - x = self._comm_pipe.read(PIPE_BUF) |
190 | | - if not x: # EOF |
191 | | - break |
192 | | - comm_buf.write(x) |
193 | | - return comm_buf.getvalue() |
194 | | - |
195 | | - def _do_shutdown(self) -> None: |
196 | | - if self._childproc is None: |
197 | | - return |
198 | | - try: |
199 | | - self._childproc.terminate() |
200 | | - self._childproc.wait(10) # Give it 10s to shut down cleanly.. |
201 | | - except TimeoutExpired: |
202 | | - try: |
203 | | - self._childproc.kill() # ... then kick it harder... |
204 | | - self._childproc.wait(self._timeout) # ... and wait again |
205 | | - except TimeoutExpired: |
206 | | - self.logger.error( |
207 | | - f"waiting for child-process exceeded timeout {self._timeout}s" |
208 | | - ) |
209 | | - |
210 | | - def _ensure_connection_info_file_is_deleted(self) -> None: |
211 | | - """ |
212 | | - Ensure that the JSON connection information file is deleted |
213 | | - """ |
214 | | - with contextlib.suppress(OSError): |
215 | | - if self._service_config_path.exists(): |
216 | | - self._service_config_path.unlink() |
217 | | - |
218 | | - @property |
219 | | - def logger(self) -> Logger: |
220 | | - return getLogger(f"ert.shared.{self._service_name}") |
221 | | - |
222 | | - |
223 | | -class ServerBootFail(RuntimeError): |
224 | | - pass |
225 | | - |
226 | | - |
227 | 66 | class BaseService: |
228 | 67 | """ |
229 | 68 | BaseService provides a block-only-when-needed mechanism for starting and |
|
0 commit comments