diff --git a/core/testcontainers/compose/compose.py b/core/testcontainers/compose/compose.py index fb697692..f2a4debc 100644 --- a/core/testcontainers/compose/compose.py +++ b/core/testcontainers/compose/compose.py @@ -247,7 +247,9 @@ def docker_compose_command(self) -> list[str]: @cached_property def compose_command_property(self) -> list[str]: - docker_compose_cmd = [self.docker_command_path, "compose"] if self.docker_command_path else ["docker", "compose"] + docker_compose_cmd = ( + [self.docker_command_path, "compose"] if self.docker_command_path else ["docker", "compose"] + ) if self.compose_file_name: for file in self.compose_file_name: docker_compose_cmd += ["-f", file] diff --git a/core/testcontainers/core/container.py b/core/testcontainers/core/container.py index dcf576ea..d1854402 100644 --- a/core/testcontainers/core/container.py +++ b/core/testcontainers/core/container.py @@ -20,7 +20,7 @@ from testcontainers.core.network import Network from testcontainers.core.utils import is_arm, setup_logger from testcontainers.core.wait_strategies import LogMessageWaitStrategy -from testcontainers.core.waiting_utils import WaitStrategy, wait_container_is_ready +from testcontainers.core.waiting_utils import WaitStrategy if TYPE_CHECKING: from docker.models.containers import Container @@ -247,8 +247,13 @@ def get_container_host_ip(self) -> str: # ensure that we covered all possible connection_modes assert_never(connection_mode) - @wait_container_is_ready() def get_exposed_port(self, port: int) -> int: + from testcontainers.core.wait_strategies import ContainerStatusWaitStrategy as C + + C().wait_until_ready(self) + return self._get_exposed_port(port) + + def _get_exposed_port(self, port: int) -> int: if self.get_docker_client().get_connection_mode().use_mapped_port: c = self._container assert c is not None diff --git a/core/testcontainers/core/docker_client.py b/core/testcontainers/core/docker_client.py index bf7b506c..12384c94 100644 --- a/core/testcontainers/core/docker_client.py +++ b/core/testcontainers/core/docker_client.py @@ -174,7 +174,7 @@ def get_container(self, container_id: str) -> dict[str, Any]: """ Get the container with a given identifier. """ - containers = self.client.api.containers(filters={"id": container_id}) + containers = self.client.api.containers(all=True, filters={"id": container_id}) if not containers: raise RuntimeError(f"Could not get container with id {container_id}") return cast("dict[str, Any]", containers[0]) diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index e427c2ad..591a4a8a 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -16,7 +16,6 @@ from testcontainers.core.container import DockerContainer from testcontainers.core.exceptions import ContainerStartException from testcontainers.core.utils import raise_for_deprecated_parameter -from testcontainers.core.waiting_utils import wait_container_is_ready ADDITIONAL_TRANSIENT_ERRORS = [] try: @@ -34,8 +33,11 @@ class DbContainer(DockerContainer): Generic database container. """ - @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) def _connect(self) -> None: + from testcontainers.core.wait_strategies import ContainerStatusWaitStrategy as C + + C().with_transient_exceptions(*ADDITIONAL_TRANSIENT_ERRORS).wait_until_ready(self) + import sqlalchemy engine = sqlalchemy.create_engine(self.get_connection_url()) diff --git a/core/testcontainers/core/wait_strategies.py b/core/testcontainers/core/wait_strategies.py index a1f5b112..cac2a2ef 100644 --- a/core/testcontainers/core/wait_strategies.py +++ b/core/testcontainers/core/wait_strategies.py @@ -31,14 +31,20 @@ import time from datetime import timedelta from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen +from typing_extensions import Self + +from testcontainers.compose import DockerCompose from testcontainers.core.utils import setup_logger # Import base classes from waiting_utils to make them available for tests -from .waiting_utils import WaitStrategy, WaitStrategyTarget +from testcontainers.core.waiting_utils import WaitStrategy, WaitStrategyTarget + +if TYPE_CHECKING: + from testcontainers.core.container import DockerContainer logger = setup_logger(__name__) @@ -77,22 +83,6 @@ def __init__( self._times = times self._predicate_streams_and = predicate_streams_and - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "LogMessageWaitStrategy": - """Set the maximum time to wait for the container to be ready.""" - if isinstance(timeout, timedelta): - self._startup_timeout = int(timeout.total_seconds()) - else: - self._startup_timeout = timeout - return self - - def with_poll_interval(self, interval: Union[float, timedelta]) -> "LogMessageWaitStrategy": - """Set how frequently to check if the container is ready.""" - if isinstance(interval, timedelta): - self._poll_interval = interval.total_seconds() - else: - self._poll_interval = interval - return self - def wait_until_ready(self, container: "WaitStrategyTarget") -> None: """ Wait until the specified message appears in the container logs. @@ -198,22 +188,6 @@ def __init__(self, port: int, path: Optional[str] = "/") -> None: self._body: Optional[str] = None self._insecure_tls = False - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "HttpWaitStrategy": - """Set the maximum time to wait for the container to be ready.""" - if isinstance(timeout, timedelta): - self._startup_timeout = int(timeout.total_seconds()) - else: - self._startup_timeout = timeout - return self - - def with_poll_interval(self, interval: Union[float, timedelta]) -> "HttpWaitStrategy": - """Set how frequently to check if the container is ready.""" - if isinstance(interval, timedelta): - self._poll_interval = interval.total_seconds() - else: - self._poll_interval = interval - return self - @classmethod def from_url(cls, url: str) -> "HttpWaitStrategy": """ @@ -483,22 +457,6 @@ class HealthcheckWaitStrategy(WaitStrategy): def __init__(self) -> None: super().__init__() - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "HealthcheckWaitStrategy": - """Set the maximum time to wait for the container to be ready.""" - if isinstance(timeout, timedelta): - self._startup_timeout = int(timeout.total_seconds()) - else: - self._startup_timeout = timeout - return self - - def with_poll_interval(self, interval: Union[float, timedelta]) -> "HealthcheckWaitStrategy": - """Set how frequently to check if the container is ready.""" - if isinstance(interval, timedelta): - self._poll_interval = interval.total_seconds() - else: - self._poll_interval = interval - return self - def wait_until_ready(self, container: WaitStrategyTarget) -> None: """ Wait until the container's health check reports as healthy. @@ -581,22 +539,6 @@ def __init__(self, port: int) -> None: super().__init__() self._port = port - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "PortWaitStrategy": - """Set the maximum time to wait for the container to be ready.""" - if isinstance(timeout, timedelta): - self._startup_timeout = int(timeout.total_seconds()) - else: - self._startup_timeout = timeout - return self - - def with_poll_interval(self, interval: Union[float, timedelta]) -> "PortWaitStrategy": - """Set how frequently to check if the container is ready.""" - if isinstance(interval, timedelta): - self._poll_interval = interval.total_seconds() - else: - self._poll_interval = interval - return self - def wait_until_ready(self, container: WaitStrategyTarget) -> None: """ Wait until the specified port is available for connection. @@ -654,22 +596,6 @@ def __init__(self, file_path: Union[str, Path]) -> None: super().__init__() self._file_path = Path(file_path) - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "FileExistsWaitStrategy": - """Set the maximum time to wait for the container to be ready.""" - if isinstance(timeout, timedelta): - self._startup_timeout = int(timeout.total_seconds()) - else: - self._startup_timeout = timeout - return self - - def with_poll_interval(self, interval: Union[float, timedelta]) -> "FileExistsWaitStrategy": - """Set how frequently to check if the container is ready.""" - if isinstance(interval, timedelta): - self._poll_interval = interval.total_seconds() - else: - self._poll_interval = interval - return self - def wait_until_ready(self, container: WaitStrategyTarget) -> None: """ Wait until the specified file exists on the host filesystem. @@ -718,6 +644,65 @@ def wait_until_ready(self, container: WaitStrategyTarget) -> None: time.sleep(self._poll_interval) +class ContainerStatusWaitStrategy(WaitStrategy): + """ + The possible values for the container status are: + created + running + paused + restarting + exited + removing + dead + https://docs.docker.com/reference/cli/docker/container/ls/#status + """ + + CONTINUE_STATUSES = frozenset(("created", "restarting")) + + def __init__(self) -> None: + super().__init__() + + def wait_until_ready(self, container: WaitStrategyTarget) -> None: + result = self._poll(lambda: self.running(self.get_status(container))) + if not result: + raise TimeoutError("container did not become running") + + @staticmethod + def running(status: str) -> bool: + if status == "running": + logger.debug("status is now running") + return True + if status in ContainerStatusWaitStrategy.CONTINUE_STATUSES: + logger.debug( + "status is %s, which is valid for continuing (%s)", + status, + ContainerStatusWaitStrategy.CONTINUE_STATUSES, + ) + return False + raise StopIteration(f"container status not valid for continuing: {status}") + + def get_status(self, container: Any) -> str: + from testcontainers.core.container import DockerContainer + + if isinstance(container, DockerContainer): + return self._get_status_tc_container(container) + if isinstance(container, DockerCompose): + return self._get_status_compose_container(container) + raise TypeError(f"not supported operation: 'get_status' for type: {type(container)}") + + @staticmethod + def _get_status_tc_container(container: "DockerContainer") -> str: + logger.debug("fetching status of container %s", container) + wrapped = container.get_wrapped_container() + wrapped.reload() + return cast("str", wrapped.status) + + @staticmethod + def _get_status_compose_container(container: DockerCompose) -> str: + logger.debug("fetching status of compose container %s", container) + raise NotImplementedError + + class CompositeWaitStrategy(WaitStrategy): """ Wait for multiple conditions to be satisfied in sequence. @@ -748,42 +733,22 @@ def __init__(self, *strategies: WaitStrategy) -> None: super().__init__() self._strategies = list(strategies) - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "CompositeWaitStrategy": - """ - Set the startup timeout for all contained strategies. - - Args: - timeout: Maximum time to wait in seconds - - Returns: - self for method chaining - """ - if isinstance(timeout, timedelta): - self._startup_timeout = int(timeout.total_seconds()) - else: - self._startup_timeout = timeout - - for strategy in self._strategies: - strategy.with_startup_timeout(timeout) + def with_poll_interval(self, interval: Union[float, timedelta]) -> Self: + super().with_poll_interval(interval) + for _strategy in self._strategies: + _strategy.with_poll_interval(interval) return self - def with_poll_interval(self, interval: Union[float, timedelta]) -> "CompositeWaitStrategy": - """ - Set the poll interval for all contained strategies. - - Args: - interval: How frequently to check in seconds - - Returns: - self for method chaining - """ - if isinstance(interval, timedelta): - self._poll_interval = interval.total_seconds() - else: - self._poll_interval = interval + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> Self: + super().with_startup_timeout(timeout) + for _strategy in self._strategies: + _strategy.with_startup_timeout(timeout) + return self - for strategy in self._strategies: - strategy.with_poll_interval(interval) + def with_transient_exceptions(self, *transient_exceptions: type[Exception]) -> Self: + super().with_transient_exceptions(*transient_exceptions) + for _strategy in self._strategies: + _strategy.with_transient_exceptions(*transient_exceptions) return self def wait_until_ready(self, container: WaitStrategyTarget) -> None: @@ -816,6 +781,7 @@ def wait_until_ready(self, container: WaitStrategyTarget) -> None: __all__ = [ "CompositeWaitStrategy", + "ContainerStatusWaitStrategy", "FileExistsWaitStrategy", "HealthcheckWaitStrategy", "HttpWaitStrategy", diff --git a/core/testcontainers/core/waiting_utils.py b/core/testcontainers/core/waiting_utils.py index 7775fce9..9942854a 100644 --- a/core/testcontainers/core/waiting_utils.py +++ b/core/testcontainers/core/waiting_utils.py @@ -20,8 +20,9 @@ from typing import Any, Callable, Optional, Protocol, TypeVar, Union, cast import wrapt +from typing_extensions import Self -from testcontainers.core.config import testcontainers_config as config +from testcontainers.core.config import testcontainers_config from testcontainers.core.utils import setup_logger logger = setup_logger(__name__) @@ -73,10 +74,11 @@ class WaitStrategy(ABC): """Base class for all wait strategies.""" def __init__(self) -> None: - self._startup_timeout: float = config.timeout - self._poll_interval: float = config.sleep_time + self._startup_timeout: float = testcontainers_config.timeout + self._poll_interval: float = testcontainers_config.sleep_time + self._transient_exceptions: list[type[Exception]] = [*TRANSIENT_EXCEPTIONS] - def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "WaitStrategy": + def with_startup_timeout(self, timeout: Union[int, timedelta]) -> Self: """Set the maximum time to wait for the container to be ready.""" if isinstance(timeout, timedelta): self._startup_timeout = float(int(timeout.total_seconds())) @@ -84,7 +86,7 @@ def with_startup_timeout(self, timeout: Union[int, timedelta]) -> "WaitStrategy" self._startup_timeout = float(timeout) return self - def with_poll_interval(self, interval: Union[float, timedelta]) -> "WaitStrategy": + def with_poll_interval(self, interval: Union[float, timedelta]) -> Self: """Set how frequently to check if the container is ready.""" if isinstance(interval, timedelta): self._poll_interval = interval.total_seconds() @@ -92,11 +94,46 @@ def with_poll_interval(self, interval: Union[float, timedelta]) -> "WaitStrategy self._poll_interval = interval return self + def with_transient_exceptions(self, *transient_exceptions: type[Exception]) -> Self: + self._transient_exceptions.extend(transient_exceptions) + return self + @abstractmethod def wait_until_ready(self, container: WaitStrategyTarget) -> None: """Wait until the container is ready.""" pass + def _poll(self, check: Callable[[], bool], transient_exceptions: Optional[list[type[Exception]]] = None) -> bool: + if not transient_exceptions: + all_te_types = self._transient_exceptions + else: + all_te_types = [*self._transient_exceptions, *(transient_exceptions or [])] + + start = time.time() + while True: + start_attempt = time.time() + duration = start_attempt - start + if duration > self._startup_timeout: + return False + + # noinspection PyBroadException + try: + result = check() + if result: + return result + except StopIteration: + return False + except Exception as e: # noqa: E722, RUF100 + is_transient = False + for et in all_te_types: + if isinstance(e, et): + is_transient = True + if not is_transient: + raise RuntimeError(f"exception while checking for strategy {self}") from e + + seconds_left_until_next = self._poll_interval - (time.time() - start_attempt) + time.sleep(max(0.0, seconds_left_until_next)) + # Keep existing wait_container_is_ready but make it use the new system internally def wait_container_is_ready(*transient_exceptions: type[Exception]) -> Callable[[F], F]: @@ -194,7 +231,7 @@ def wait_for(condition: Callable[..., bool]) -> bool: def wait_for_logs( container: WaitStrategyTarget, predicate: Union[Callable[[str], bool], str, WaitStrategy], - timeout: float = config.timeout, + timeout: float = testcontainers_config.timeout, interval: float = 1, predicate_streams_and: bool = False, raise_on_exit: bool = False, @@ -261,7 +298,7 @@ def wait_for_logs( # Original implementation for backwards compatibility re_predicate: Optional[Callable[[str], Any]] = None if timeout is None: - timeout = config.timeout + timeout = testcontainers_config.timeout if isinstance(predicate, str): re_predicate = re.compile(predicate, re.MULTILINE).search elif callable(predicate): diff --git a/core/tests/test_container.py b/core/tests/test_container.py index f87bb94c..30b80f79 100644 --- a/core/tests/test_container.py +++ b/core/tests/test_container.py @@ -67,14 +67,14 @@ def fake_mapped(container_id: str, port: int) -> int: monkeypatch.setattr(client, "port", fake_mapped) monkeypatch.setattr(client, "get_connection_mode", lambda: mode) - assert container.get_exposed_port(8080) == 45678 + assert container._get_exposed_port(8080) == 45678 def test_get_exposed_port_original(container: DockerContainer, monkeypatch: pytest.MonkeyPatch) -> None: client = container._docker monkeypatch.setattr(client, "get_connection_mode", lambda: ConnectionMode.bridge_ip) - assert container.get_exposed_port(8080) == 8080 + assert container._get_exposed_port(8080) == 8080 @pytest.mark.parametrize( diff --git a/core/tests/test_new_docker_api.py b/core/tests/test_new_docker_api.py index 26a79aa9..675ebab9 100644 --- a/core/tests/test_new_docker_api.py +++ b/core/tests/test_new_docker_api.py @@ -4,12 +4,11 @@ def test_docker_custom_image(): - container = DockerContainer("mysql:5.7.17") - container.with_exposed_ports(3306) - container.with_env("MYSQL_ROOT_PASSWORD", "root") + container = DockerContainer("nginx:alpine-slim") + container.with_exposed_ports(80) with container: - port = container.get_exposed_port(3306) + port = container.get_exposed_port(80) assert int(port) > 0 diff --git a/core/tests/test_wait_strategies.py b/core/tests/test_wait_strategies.py index 8e70e254..da62f1fb 100644 --- a/core/tests/test_wait_strategies.py +++ b/core/tests/test_wait_strategies.py @@ -484,7 +484,7 @@ def test_wait_until_ready(self, mock_sleep, mock_time, mock_socket, connection_s strategy.wait_until_ready(mock_container) mock_socket_instance.connect.assert_called_once_with(("localhost", 8080)) else: - with pytest.raises(TimeoutError, match="Port 8080 not available within 1 seconds"): + with pytest.raises(TimeoutError, match="Port 8080 not available within 1.0 seconds"): strategy.wait_until_ready(mock_container) diff --git a/core/tests/test_waiting_utils.py b/core/tests/test_waiting_utils.py index bd77fc25..635c58c4 100644 --- a/core/tests/test_waiting_utils.py +++ b/core/tests/test_waiting_utils.py @@ -1,6 +1,7 @@ import pytest from testcontainers.core.container import DockerContainer +from testcontainers.core.wait_strategies import ContainerStatusWaitStrategy from testcontainers.core.waiting_utils import wait_for_logs, wait_for, wait_container_is_ready @@ -28,8 +29,15 @@ def simple_check() -> bool: def test_wait_container_is_ready_decorator_with_container() -> None: """Test wait_container_is_ready decorator with a real container.""" - @wait_container_is_ready() def check_container_logs(container: DockerContainer) -> bool: + # wait until it becomes running. + # if it is too late, it is actually fine in this case. + # we are happy with an exited (even crashed) container that has logs. + try: + ContainerStatusWaitStrategy().wait_until_ready(container) + except TimeoutError: + pass + stdout, stderr = container.get_logs() return b"Hello from Docker!" in stdout or b"Hello from Docker!" in stderr