diff --git a/.github/workflows/healthcheck-Tests.yml b/.github/workflows/healthcheck-Tests.yml index f2ae49c24..f2972237f 100644 --- a/.github/workflows/healthcheck-Tests.yml +++ b/.github/workflows/healthcheck-Tests.yml @@ -12,6 +12,12 @@ on: jobs: run-tests: + env: + ZIMFARM_API_URL: http://localhost:8001 + ZIMFARM_USERNAME: admin + ZIMFARM_PASSWORD: admin_pass + ZIMFARM_DATABASE_URL: postgresql+psycopg://zimfarm:zimpass@localhost:5432/zimtest + runs-on: ubuntu-24.04 steps: - name: Retrieve source code @@ -23,6 +29,16 @@ jobs: python-version-file: healthcheck/pyproject.toml architecture: x64 + - name: Install dependencies (and project) + working-directory: healthcheck + run: | + pip install -U pip + pip install -e .[test,scripts] + + - name: Run tests + working-directory: healthcheck + run: inv coverage --args "-vvv" + - name: Build healthcheck Docker image working-directory: healthcheck run: docker build -t zimfarm-healthcheck:test . @@ -30,10 +46,10 @@ jobs: - name: Run healthcheck container run: | docker run -d --name zimfarm-healthcheck-test \ - -e ZIMFARM_API_URL=http://localhost:8001 \ - -e ZIMFARM_USERNAME=admin \ - -e ZIMFARM_PASSWORD=admin_pass \ - -e ZIMFARM_DATABASE_URL=postgresql+psycopg://zimfarm:zimpass@localhost:5432/zimtest \ + -e ZIMFARM_API_URL \ + -e ZIMFARM_USERNAME \ + -e ZIMFARM_PASSWORD \ + -e ZIMFARM_DATABASE_URL \ -p 8000:80 \ zimfarm-healthcheck:test # wait for container to be ready diff --git a/healthcheck/pyproject.toml b/healthcheck/pyproject.toml index 1937bd9c8..19774d3b4 100644 --- a/healthcheck/pyproject.toml +++ b/healthcheck/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "humanfriendly == 10.0", "jinja2 == 3.1.6", "psycopg[binary,pool] == 3.2.9", + "diskcache == 5.6.3", ] dynamic = ["version"] @@ -37,14 +38,15 @@ lint = [ ] check = [ "pyright == 1.1.400", - "types-humanfriendly == 10.0.0" + "types-humanfriendly == 10.0.0", + "diskcache-stubs == 5.6.3.6.20240818", ] test = [ "coverage == 7.8.0", "httpx == 0.28.0", "pytest == 8.3.5", - "pytest-asyncio == 0.26.0", "pytest-env == 1.1.5", + "pytest-asyncio == 1.2.0", "Faker==37.3.0", ] dev = [ diff --git a/healthcheck/src/healthcheck/__init__.py b/healthcheck/src/healthcheck/__init__.py index 10ac97c06..47045a88d 100644 --- a/healthcheck/src/healthcheck/__init__.py +++ b/healthcheck/src/healthcheck/__init__.py @@ -1,12 +1,29 @@ import datetime import logging +import os +from typing import Any + + +def getenv(key: str, *, mandatory: bool = False, default: Any = None) -> Any: + value = os.getenv(key) or default + + if mandatory and not value: + raise OSError(f"Please set the {key} environment variable") + + return value + + +def parse_bool(value: Any) -> bool: + """Parse value into boolean.""" + return str(value).lower() in ("true", "1", "yes", "y", "on") -from healthcheck.constants import DEBUG logger = logging.getLogger("healthcheck") if not logger.hasHandlers(): - logger.setLevel(logging.DEBUG if DEBUG else logging.INFO) + logger.setLevel( + logging.DEBUG if parse_bool(getenv("DEBUG", default="false")) else logging.INFO + ) handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("[%(asctime)s: %(levelname)s] %(message)s")) logger.addHandler(handler) diff --git a/healthcheck/src/healthcheck/cache.py b/healthcheck/src/healthcheck/cache.py new file mode 100644 index 000000000..41cc20b01 --- /dev/null +++ b/healthcheck/src/healthcheck/cache.py @@ -0,0 +1,75 @@ +from collections.abc import Awaitable, Callable +from functools import wraps +from typing import ParamSpec, TypeVar + +from diskcache import FanoutCache + +from healthcheck.constants import ( + CACHE_KEY_PREFIX, + CACHE_LOCATION, + DEFAULT_CACHE_EXPIRATION, +) + +P = ParamSpec("P") +R = TypeVar("R") + +# As per the docs, writers can block other writers to the cache. The FanoutCache as +# opposed to the simpler Cache uses sharding to decrease block writes. This makes +# it a good candidate for our usage because the functions we want to memoize are run +# "concurrently" using asyncio.gather. +_cache: FanoutCache | None = None + + +def init_cache() -> FanoutCache: + """Get or create the disk cache instance.""" + global _cache # noqa: PLW0603 + if _cache is None: + _cache = FanoutCache(CACHE_LOCATION) + return _cache + + +def close_cache() -> None: + """Close the disk cache instance.""" + global _cache # noqa: PLW0603 + if _cache is not None: + _cache.close() + _cache = None + + +def memoize( + key: str, + expire: float = DEFAULT_CACHE_EXPIRATION, + *, + cache_only_on_success: bool = True, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """Memoize function calls with results at CACHE_KEY_PREFIX:key. + + Results are considered successful if they have a success attribute and it is truthy. + """ + + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + cache = init_cache() + location = f"{CACHE_KEY_PREFIX}:{key}" + # Types other than the basic types like floats, ints, bytes, strings are + # are stored using pickle by default. Thus, we can save our results + # (pydantic models) directly to the cache and get it back as is. + if (result := cache.get(location)) is not None: + return result + + result = await func(*args, **kwargs) + + if cache_only_on_success: + if ( + hasattr(result, "success") + and result.success # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + ): + cache.set(location, result, expire=expire) + else: + cache.set(location, result, expire=expire) + return result + + return wrapper + + return decorator diff --git a/healthcheck/src/healthcheck/constants.py b/healthcheck/src/healthcheck/constants.py index bb1897829..4cce28566 100644 --- a/healthcheck/src/healthcheck/constants.py +++ b/healthcheck/src/healthcheck/constants.py @@ -1,24 +1,8 @@ -import os -from typing import Any +from pathlib import Path from humanfriendly import parse_timespan - -def getenv(key: str, *, mandatory: bool = False, default: Any = None) -> Any: - value = os.getenv(key) or default - - if mandatory and not value: - raise OSError(f"Please set the {key} environment variable") - - return value - - -def parse_bool(value: Any) -> bool: - """Parse value into boolean.""" - return str(value).lower() in ("true", "1", "yes", "y", "on") - - -DEBUG = parse_bool(getenv("DEBUG", default="false")) +from healthcheck import getenv REQUESTS_TIMEOUT = parse_timespan(getenv("REQUESTS_TIMEOUT", default="1m")) @@ -26,3 +10,9 @@ def parse_bool(value: Any) -> bool: ZIMFARM_USERNAME = getenv("ZIMFARM_USERNAME", mandatory=True) ZIMFARM_PASSWORD = getenv("ZIMFARM_PASSWORD", mandatory=True) ZIMFARM_DATABASE_URL = getenv("ZIMFARM_DATABASE_URL", mandatory=True) + +CACHE_LOCATION = Path(getenv("CACHE_LOCATION", default="/data/cache")) +CACHE_KEY_PREFIX = getenv("CACHE_KEY_PREFIX", default="healthcheck") +DEFAULT_CACHE_EXPIRATION = parse_timespan( + getenv("DEFAULT_CACHE_EXPIRATION", default="1m") +) diff --git a/healthcheck/src/healthcheck/main.py b/healthcheck/src/healthcheck/main.py index 5213c17b3..03e4562f4 100644 --- a/healthcheck/src/healthcheck/main.py +++ b/healthcheck/src/healthcheck/main.py @@ -1,11 +1,20 @@ +from contextlib import asynccontextmanager from pathlib import Path from fastapi import APIRouter, FastAPI from fastapi.staticfiles import StaticFiles +from healthcheck.cache import close_cache, init_cache from healthcheck.router import router +@asynccontextmanager +async def lifespan(_: FastAPI): + init_cache() + yield + close_cache() + + def create_app(*, debug: bool = True): app = FastAPI( debug=debug, @@ -15,6 +24,7 @@ def create_app(*, debug: bool = True): description=( "Service for checking health status of Zimfarm components and dependencies" ), + lifespan=lifespan, ) main_router = APIRouter() diff --git a/healthcheck/src/healthcheck/status/auth.py b/healthcheck/src/healthcheck/status/auth.py index 0ac37a3d6..534ce754b 100644 --- a/healthcheck/src/healthcheck/status/auth.py +++ b/healthcheck/src/healthcheck/status/auth.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from healthcheck.cache import memoize from healthcheck.constants import ZIMFARM_API_URL, ZIMFARM_PASSWORD, ZIMFARM_USERNAME from healthcheck.requests import query_api from healthcheck.status import Result @@ -16,6 +17,7 @@ class Token(BaseModel): token_type: str = "Bearer" +@memoize("ZIMFARM_AUTH") async def authenticate() -> Result[Token]: """Check if authentication is sucessful with Zimfarm""" response = await query_api( diff --git a/healthcheck/src/healthcheck/status/database.py b/healthcheck/src/healthcheck/status/database.py index c38c1b375..952f3d65e 100644 --- a/healthcheck/src/healthcheck/status/database.py +++ b/healthcheck/src/healthcheck/status/database.py @@ -5,6 +5,7 @@ from sqlalchemy.sql import text from healthcheck import logger +from healthcheck.cache import memoize from healthcheck.constants import ZIMFARM_DATABASE_URL as DATABASE_URL from healthcheck.status import Result @@ -21,6 +22,7 @@ class DatabaseConnectionInfo(BaseModel): version: str +@memoize("ZIMFARM_DATABASE") async def check_database_connection() -> Result[DatabaseConnectionInfo]: """Check if we can connect to the database and run a simple query.""" try: diff --git a/healthcheck/src/healthcheck/status/workers.py b/healthcheck/src/healthcheck/status/workers.py index 27cb7a214..5e6c5f8ef 100644 --- a/healthcheck/src/healthcheck/status/workers.py +++ b/healthcheck/src/healthcheck/status/workers.py @@ -3,6 +3,7 @@ from pydantic import BaseModel +from healthcheck.cache import memoize from healthcheck.constants import ZIMFARM_API_URL from healthcheck.requests import query_api from healthcheck.status import Result @@ -25,6 +26,7 @@ def check_worker_online(worker: Worker) -> bool: return worker.status == "online" +@memoize("ZIMFARM_WORKERS_STATUS") async def get_workers_status() -> Result[WorkersStatus]: """Fetch the list of workers and check their online status.""" response = await query_api( diff --git a/healthcheck/tests/conftest.py b/healthcheck/tests/conftest.py new file mode 100644 index 000000000..33ecb768b --- /dev/null +++ b/healthcheck/tests/conftest.py @@ -0,0 +1,23 @@ +from pathlib import Path + +import pytest + +from healthcheck import cache as cache_module +from healthcheck.cache import close_cache, init_cache + + +@pytest.fixture(autouse=True) +def cache_dir(tmp_path: Path) -> Path: + """Create a temporary directory for cache files.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + return cache_dir + + +@pytest.fixture(autouse=True) +def cache(cache_dir: Path, monkeypatch: pytest.MonkeyPatch): + """Configure cache to use temporary directory and ensure it's closed after test.""" + monkeypatch.setattr(cache_module, "CACHE_LOCATION", cache_dir) + cache = init_cache() + yield cache + close_cache() diff --git a/healthcheck/tests/test_cache.py b/healthcheck/tests/test_cache.py new file mode 100644 index 000000000..14a2ab8b4 --- /dev/null +++ b/healthcheck/tests/test_cache.py @@ -0,0 +1,145 @@ +import asyncio +from http import HTTPStatus +from typing import Any + +import pytest +from diskcache import FanoutCache + +from healthcheck.cache import CACHE_KEY_PREFIX, memoize +from healthcheck.status import Result + + +@pytest.mark.asyncio +async def test_memoize_successful_result() -> None: + """Test that successful results are cached.""" + counter = 0 + + @memoize("test-success") + async def get_success() -> Result[str]: + nonlocal counter + counter += 1 + return Result( + success=True, + status_code=HTTPStatus.OK, + data="success", + ) + + result1 = await get_success() + assert result1.success + assert result1.data == "success" + assert counter == 1 + + result2 = await get_success() + assert result2.success + assert result2.data == "success" + assert counter == 1 # Counter shouldn't increment + + +@pytest.mark.asyncio +async def test_failed_results_are_not_memoized() -> None: + """Test that failed results are not cached.""" + counter = 0 + + @memoize("test-failure") + async def get_failure() -> Result[Any]: + nonlocal counter + counter += 1 + return Result( + success=False, + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + data=None, + ) + + result1 = await get_failure() + assert not result1.success + assert result1.data is None + assert counter == 1 + + result2 = await get_failure() + assert not result2.success + assert result2.data is None + assert counter == 2 # Counter should increment + + +@pytest.mark.asyncio +async def test_memoize_failed_result() -> None: + """Test that failed results can be memoized with setting.""" + counter = 0 + + @memoize("test-failure", cache_only_on_success=False) + async def get_failure() -> Result[Any]: + nonlocal counter + counter += 1 + return Result( + success=False, + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + data=None, + ) + + result1 = await get_failure() + assert not result1.success + assert result1.data is None + assert counter == 1 + + result2 = await get_failure() + assert not result2.success + assert result2.data is None + assert counter == 1 # Counter should not increment + + +@pytest.mark.asyncio +async def test_memoize_cache_expiry() -> None: + """Test that cached values expire.""" + counter = 0 + + @memoize("test-expiry", expire=0.1) + async def get_data() -> Result[int]: + nonlocal counter + counter += 1 + return Result( + success=True, + status_code=HTTPStatus.OK, + data=counter, + ) + + result1 = await get_data() + assert result1.success + assert result1.data == 1 + assert counter == 1 + + # Immediate second call should use cache + result2 = await get_data() + assert result2.data == 1 + assert counter == 1 + + # Wait for cache to expire + await asyncio.sleep(0.2) + + # Call after expiry should execute function again + result3 = await get_data() + assert result3.data == 2 + assert counter == 2 + + +@pytest.mark.asyncio +async def test_cache_key_prefix(cache: FanoutCache) -> None: + """Test that cache keys include the configured prefix.""" + + @memoize("test-prefix") + async def get_data() -> Result[str]: + return Result( + success=True, + status_code=HTTPStatus.OK, + data="test", + ) + + await get_data() + + # Check that the key exists with correct prefix + key = f"{CACHE_KEY_PREFIX}:test-prefix" + assert key in cache + + cached_value = cache.get(key) + assert isinstance(cached_value, Result) + assert cached_value.success + assert cached_value.data == "test" # pyright: ignore[reportUnknownMemberType]