From b673b8c64baa942918a12f33b3f3e99a251d1463 Mon Sep 17 00:00:00 2001 From: Stas Moreinis Date: Wed, 24 Dec 2025 14:28:59 -0800 Subject: [PATCH] Add proper readiness checks to health interceptor - Separate liveness (/healthz) from readiness (/readyz, /healthcheck) - Liveness probe returns 200 immediately (sub-millisecond, no deps) - Readiness probes check PostgreSQL, Redis, and MongoDB connectivity - Add per-check timeout (2s) and overall timeout (5s) - Return 503 when dependencies are unhealthy or not initialized - Run all dependency checks concurrently with asyncio.gather() - Update tests to mock GlobalDependencies and verify behavior --- agentex/src/api/health_interceptor.py | 263 +++++++++++++++--- .../tests/unit/api/test_health_interceptor.py | 106 ++++++- 2 files changed, 319 insertions(+), 50 deletions(-) diff --git a/agentex/src/api/health_interceptor.py b/agentex/src/api/health_interceptor.py index 0273479..784fe8f 100644 --- a/agentex/src/api/health_interceptor.py +++ b/agentex/src/api/health_interceptor.py @@ -1,21 +1,35 @@ """ -Pure ASGI middleware for fast health check responses. +Pure ASGI middleware for health check responses. This middleware intercepts health check requests at the ASGI level, bypassing all Starlette/FastAPI middleware for maximum performance. -Kubernetes probes hit these endpoints frequently, so sub-millisecond -response time is critical. + +Health check endpoints: +- /healthz: Liveness probe - fast, no dependency checks (sub-millisecond) +- /readyz: Readiness probe - checks DB, Redis, MongoDB connectivity +- /healthcheck: Alias for readiness probe """ +import asyncio +import json +from typing import Any + from starlette.types import ASGIApp, Receive, Scope, Send -HEALTH_CHECK_PATHS: frozenset[str] = frozenset( - { - "/healthcheck", - "/healthz", - "/readyz", - } -) +# Liveness probes - fast, no dependency checks +LIVENESS_PATHS: frozenset[str] = frozenset({"/healthz"}) + +# Readiness probes - check dependencies +READINESS_PATHS: frozenset[str] = frozenset({"/readyz", "/healthcheck"}) + +# All health check paths +HEALTH_CHECK_PATHS: frozenset[str] = LIVENESS_PATHS | READINESS_PATHS + +# Timeout for individual dependency checks (seconds) +DEPENDENCY_CHECK_TIMEOUT = 2.0 + +# Total timeout for all readiness checks (seconds) +READINESS_CHECK_TIMEOUT = 5.0 class HealthCheckInterceptor: @@ -23,9 +37,14 @@ class HealthCheckInterceptor: Pure ASGI middleware that intercepts health check requests before they reach the FastAPI middleware stack. - This provides sub-millisecond response times for Kubernetes probes - by avoiding BaseHTTPMiddleware task group overhead, logging, - and request body parsing. + Liveness (/healthz): + Returns 200 immediately - used by Kubernetes to detect stuck processes. + Sub-millisecond response time. + + Readiness (/readyz, /healthcheck): + Checks database, Redis, and MongoDB connectivity. + Returns 200 if all dependencies are healthy, 503 otherwise. + Used by Kubernetes to decide whether to route traffic. Only GET requests are intercepted. Other methods fall through to FastAPI for proper 405 Method Not Allowed handling. @@ -37,37 +56,193 @@ def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] == "http" and scope["path"] in HEALTH_CHECK_PATHS: - if scope.get("method") == "GET": - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [], - } - ) - await send( - { - "type": "http.response.body", - "body": b"", - } - ) - return + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + path = scope["path"] + method = scope.get("method") + + # Handle liveness probe - fast path + if path in LIVENESS_PATHS: + if method == "GET": + await self._send_response(send, 200, {"status": "ok"}) else: - # Return 405 Method Not Allowed for non-GET requests - await send( - { - "type": "http.response.start", - "status": 405, - "headers": [(b"allow", b"GET")], - } - ) - await send( - { - "type": "http.response.body", - "body": b"", - } - ) - return + await self._send_method_not_allowed(send) + return + # Handle readiness probe - check dependencies + if path in READINESS_PATHS: + if method == "GET": + await self._handle_readiness_check(send) + else: + await self._send_method_not_allowed(send) + return + + # Pass through to FastAPI await self.app(scope, receive, send) + + async def _handle_readiness_check(self, send: Send) -> None: + """Check all dependencies and return appropriate status.""" + try: + # Run all checks with overall timeout + results = await asyncio.wait_for( + self._check_all_dependencies(), + timeout=READINESS_CHECK_TIMEOUT, + ) + + # Determine overall health + all_healthy = all(r["healthy"] for r in results.values()) + status_code = 200 if all_healthy else 503 + + response_body = { + "status": "ok" if all_healthy else "degraded", + "checks": results, + } + + await self._send_response(send, status_code, response_body) + + except TimeoutError: + await self._send_response( + send, + 503, + { + "status": "timeout", + "error": "Health check timed out", + }, + ) + except Exception as e: + await self._send_response( + send, + 503, + { + "status": "error", + "error": str(e), + }, + ) + + async def _check_all_dependencies(self) -> dict[str, dict[str, Any]]: + """Check all dependencies concurrently.""" + # Import here to avoid circular imports and ensure dependencies are loaded + from src.config.dependencies import GlobalDependencies + + deps = GlobalDependencies() + + # Run all checks concurrently + postgres_task = self._check_postgres(deps) + redis_task = self._check_redis(deps) + mongodb_task = self._check_mongodb(deps) + + results = await asyncio.gather( + postgres_task, + redis_task, + mongodb_task, + return_exceptions=True, + ) + + return { + "postgres": self._format_check_result(results[0]), + "redis": self._format_check_result(results[1]), + "mongodb": self._format_check_result(results[2]), + } + + def _format_check_result( + self, result: dict[str, Any] | Exception + ) -> dict[str, Any]: + """Format a check result, handling exceptions.""" + if isinstance(result, Exception): + return {"healthy": False, "error": str(result)} + return result + + async def _check_postgres(self, deps: Any) -> dict[str, Any]: + """Check PostgreSQL connectivity.""" + from sqlalchemy import text + + try: + engine = deps.database_async_read_write_engine + if engine is None: + return {"healthy": False, "error": "Engine not initialized"} + + async with asyncio.timeout(DEPENDENCY_CHECK_TIMEOUT): + async with engine.connect() as conn: + await conn.execute(text("SELECT 1")) + + return {"healthy": True} + except TimeoutError: + return {"healthy": False, "error": "Connection timeout"} + except Exception as e: + return {"healthy": False, "error": str(e)} + + async def _check_redis(self, deps: Any) -> dict[str, Any]: + """Check Redis connectivity.""" + try: + pool = deps.redis_pool + if pool is None: + return {"healthy": False, "error": "Pool not initialized"} + + import redis.asyncio as redis_lib + + async with asyncio.timeout(DEPENDENCY_CHECK_TIMEOUT): + client = redis_lib.Redis(connection_pool=pool) + await client.ping() + + return {"healthy": True} + except TimeoutError: + return {"healthy": False, "error": "Connection timeout"} + except Exception as e: + return {"healthy": False, "error": str(e)} + + async def _check_mongodb(self, deps: Any) -> dict[str, Any]: + """Check MongoDB connectivity.""" + try: + client = deps.mongodb_client + if client is None: + return {"healthy": False, "error": "Client not initialized"} + + # MongoDB client is synchronous, run in thread pool + async with asyncio.timeout(DEPENDENCY_CHECK_TIMEOUT): + await asyncio.to_thread(client.admin.command, "ping") + + return {"healthy": True} + except TimeoutError: + return {"healthy": False, "error": "Connection timeout"} + except Exception as e: + return {"healthy": False, "error": str(e)} + + async def _send_response( + self, send: Send, status: int, body: dict[str, Any] + ) -> None: + """Send a JSON response.""" + body_bytes = json.dumps(body).encode("utf-8") + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body_bytes)).encode()), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": body_bytes, + } + ) + + async def _send_method_not_allowed(self, send: Send) -> None: + """Send 405 Method Not Allowed response.""" + await send( + { + "type": "http.response.start", + "status": 405, + "headers": [(b"allow", b"GET")], + } + ) + await send( + { + "type": "http.response.body", + "body": b"", + } + ) diff --git a/agentex/tests/unit/api/test_health_interceptor.py b/agentex/tests/unit/api/test_health_interceptor.py index 5ca1d5a..d58b919 100644 --- a/agentex/tests/unit/api/test_health_interceptor.py +++ b/agentex/tests/unit/api/test_health_interceptor.py @@ -3,8 +3,15 @@ Tests that health checks bypass the middleware stack. """ +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from src.api.health_interceptor import HEALTH_CHECK_PATHS, HealthCheckInterceptor +from src.api.health_interceptor import ( + HEALTH_CHECK_PATHS, + LIVENESS_PATHS, + READINESS_PATHS, + HealthCheckInterceptor, +) from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.routing import Route @@ -20,23 +27,110 @@ def test_health_paths_constant(self): expected_paths = {"/healthcheck", "/healthz", "/readyz"} assert HEALTH_CHECK_PATHS == expected_paths - def test_intercepts_get_health_requests(self): - """Test that GET requests to health paths are intercepted.""" + def test_liveness_paths_constant(self): + """Verify liveness paths are defined correctly.""" + assert LIVENESS_PATHS == {"/healthz"} + + def test_readiness_paths_constant(self): + """Verify readiness paths are defined correctly.""" + assert READINESS_PATHS == {"/readyz", "/healthcheck"} + + def test_liveness_returns_200_without_dependencies(self): + """Test that liveness probe (/healthz) returns 200 without checking deps.""" def should_not_be_called(request): raise AssertionError("Inner app should not be called for health checks") inner_app = Starlette( - routes=[Route(path, should_not_be_called) for path in HEALTH_CHECK_PATHS] + routes=[Route(path, should_not_be_called) for path in LIVENESS_PATHS] ) wrapped_app = HealthCheckInterceptor(inner_app) client = TestClient(wrapped_app, raise_server_exceptions=True) - for path in HEALTH_CHECK_PATHS: + for path in LIVENESS_PATHS: response = client.get(path) assert response.status_code == 200 - assert response.content == b"" + assert response.json() == {"status": "ok"} + + def test_readiness_returns_503_when_dependencies_unavailable(self): + """Test that readiness probes return 503 when dependencies aren't initialized.""" + + def should_not_be_called(request): + raise AssertionError("Inner app should not be called for health checks") + + inner_app = Starlette( + routes=[Route(path, should_not_be_called) for path in READINESS_PATHS] + ) + wrapped_app = HealthCheckInterceptor(inner_app) + + # Mock GlobalDependencies to return None for all dependencies + mock_deps = MagicMock() + mock_deps.database_async_read_write_engine = None + mock_deps.redis_pool = None + mock_deps.mongodb_client = None + + with patch( + "src.config.dependencies.GlobalDependencies", return_value=mock_deps + ): + client = TestClient(wrapped_app, raise_server_exceptions=True) + + for path in READINESS_PATHS: + response = client.get(path) + assert response.status_code == 503 + data = response.json() + assert data["status"] == "degraded" + assert "checks" in data + + def test_readiness_returns_200_when_all_dependencies_healthy(self): + """Test that readiness probes return 200 when all dependencies are healthy.""" + + def should_not_be_called(request): + raise AssertionError("Inner app should not be called for health checks") + + inner_app = Starlette( + routes=[Route(path, should_not_be_called) for path in READINESS_PATHS] + ) + wrapped_app = HealthCheckInterceptor(inner_app) + + # Mock healthy dependencies + mock_engine = MagicMock() + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock() + mock_engine.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_conn), __aexit__=AsyncMock() + ) + ) + + mock_redis_pool = MagicMock() + + mock_mongodb_client = MagicMock() + mock_mongodb_client.admin.command = MagicMock(return_value={"ok": 1}) + + mock_deps = MagicMock() + mock_deps.database_async_read_write_engine = mock_engine + mock_deps.redis_pool = mock_redis_pool + mock_deps.mongodb_client = mock_mongodb_client + + with ( + patch("src.config.dependencies.GlobalDependencies", return_value=mock_deps), + patch("redis.asyncio.Redis") as mock_redis_class, + ): + mock_redis_instance = AsyncMock() + mock_redis_instance.ping = AsyncMock() + mock_redis_class.return_value = mock_redis_instance + + client = TestClient(wrapped_app, raise_server_exceptions=True) + + for path in READINESS_PATHS: + response = client.get(path) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["checks"]["postgres"]["healthy"] is True + assert data["checks"]["redis"]["healthy"] is True + assert data["checks"]["mongodb"]["healthy"] is True def test_passes_through_non_health_requests(self): """Test that non-health requests pass through to inner app."""