diff --git a/modules/trino/testcontainers/trino/__init__.py b/modules/trino/testcontainers/trino/__init__.py index 97e3f9de4..ce1801601 100644 --- a/modules/trino/testcontainers/trino/__init__.py +++ b/modules/trino/testcontainers/trino/__init__.py @@ -13,41 +13,28 @@ import re from testcontainers.core.config import testcontainers_config as c -from testcontainers.core.generic import DbContainer -from testcontainers.core.waiting_utils import wait_container_is_ready, wait_for_logs -from trino.dbapi import connect +from testcontainers.core.generic import DockerContainer +from testcontainers.core.wait_strategies import LogMessageWaitStrategy -class TrinoContainer(DbContainer): +class TrinoContainer(DockerContainer): def __init__( self, image="trinodb/trino:latest", user: str = "test", port: int = 8080, + container_start_timeout: int = 30, **kwargs, ): super().__init__(image=image, **kwargs) self.user = user self.port = port self.with_exposed_ports(self.port) - - @wait_container_is_ready() - def _connect(self) -> None: - wait_for_logs( - self, - re.compile(".*======== SERVER STARTED ========.*", re.MULTILINE).search, - c.max_tries, - c.sleep_time, - ) - conn = connect( - host=self.get_container_host_ip(), - port=self.get_exposed_port(self.port), - user=self.user, + self.waiting_for( + LogMessageWaitStrategy(re.compile(".*======== SERVER STARTED ========.*", re.MULTILINE)) + .with_poll_interval(c.sleep_time) + .with_startup_timeout(container_start_timeout) ) - cur = conn.cursor() - cur.execute("SELECT 1") - cur.fetchall() - conn.close() def get_connection_url(self): return f"trino://{self.user}@{self.get_container_host_ip()}:{self.port}"