diff --git a/src/app/api/helper/auth.py b/src/app/api/helper/auth.py
index cd1474e..20f6276 100644
--- a/src/app/api/helper/auth.py
+++ b/src/app/api/helper/auth.py
@@ -2,15 +2,13 @@
 from fastapi import Header, HTTPException
 
 
-TOKEN = os.getenv("TOKEN")
-
-# Dependency to verify the Authorization header
-def verify_authorization_header(authorization: str = Header("Authorization")):
+async def verify_authorization_header(authorization: str | None = Header(default=None)) -> str:
     if not authorization or not authorization.startswith("Bearer "):
         raise HTTPException(
             status_code=401, detail="Invalid or missing Authorization header"
         )
     token = authorization.split("Bearer ")[1]
-    if token != TOKEN:
+    expected = os.getenv("TOKEN")
+    if not expected or token != expected:
         raise HTTPException(status_code=401, detail="Invalid token")
     return token
diff --git a/src/app/main.py b/src/app/main.py
index 18107ae..8b8657d 100644
--- a/src/app/main.py
+++ b/src/app/main.py
@@ -1,21 +1,46 @@
+import os
+from contextlib import asynccontextmanager
+from collections.abc import AsyncIterator
+
 from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse
 
 from .api import router as api_router
 from .api.response.response import ok, error, http_exception
 from .logger import log
+from .traffic_control import TrafficControl, TrafficControlMiddleware
+
+
+@asynccontextmanager
+async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
+    traffic_control: TrafficControl | None = getattr(app.state, "traffic_control", None)
+    owns_traffic_control = traffic_control is None
+    if owns_traffic_control:
+        vllm_base_url = os.getenv("VLLM_BASE_URL", "http://vllm:8000")
+        traffic_control = TrafficControl(vllm_base_url)
+        app.state.traffic_control = traffic_control
+        await traffic_control.start()
+    assert traffic_control is not None
+    try:
+        yield
+    finally:
+        if owns_traffic_control:
+            await traffic_control.stop()
+
 
-app = FastAPI()
+app = FastAPI(lifespan=_lifespan)
+app.add_middleware(TrafficControlMiddleware)
 app.include_router(api_router)
 
 
 @app.get("/")
-async def root():
+async def root() -> dict:
     return ok()
 
 
 # Custom global error handler
 @app.exception_handler(Exception)
-async def global_exception_handler(request: Request, exc: Exception):
+async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
     """
     Handle all uncaught exceptions globally.
     """
diff --git a/src/app/traffic_control.py b/src/app/traffic_control.py
new file mode 100644
index 0000000..a409979
--- /dev/null
+++ b/src/app/traffic_control.py
@@ -0,0 +1,239 @@
+import asyncio
+import os
+import time
+from collections.abc import Callable
+from typing import Optional
+
+import httpx
+from fastapi import HTTPException
+from fastapi.responses import JSONResponse
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+from app.logger import log
+
+
+def _parse_num_requests_waiting(metrics_text: str) -> float:
+    max_waiting = 0.0
+    for line in metrics_text.splitlines():
+        if not line.startswith("vllm:num_requests_waiting"):
+            continue
+        try:
+            value = float(line.split()[-1])
+        except (IndexError, ValueError):
+            continue
+        if value > max_waiting:
+            max_waiting = value
+    return max_waiting
+
+
+def _normalize_user_tier(raw: Optional[str]) -> str:
+    if raw and raw.strip().lower() == "premium":
+        return "premium"
+    return "basic"
+
+
+def _basic_max_request() -> int:
+    raw = os.getenv("BASIC_MAX_REQUEST", "20").strip()
+    try:
+        value = int(raw)
+    except ValueError:
+        value = 20
+    return max(1, value)
+
+
+def _metrics_timeout_seconds() -> float:
+    raw = os.getenv("METRICS_TIMEOUT_SECONDS", "2").strip()
+    try:
+        return float(raw)
+    except ValueError:
+        return 2.0
+
+
+def _metrics_poll_interval_seconds() -> float:
+    raw = os.getenv("METRICS_POLL_INTERVAL_SECONDS", "2").strip()
+    try:
+        value = float(raw)
+    except ValueError:
+        value = 2.0
+    return max(0.5, min(10.0, value))
+
+
+def _metrics_max_age_seconds() -> float:
+    raw = os.getenv("METRICS_MAX_AGE_SECONDS", "30").strip()
+    try:
+        value = float(raw)
+    except ValueError:
+        value = 30.0
+    return max(10.0, min(30.0, value))
+
+
+class TrafficControl:
+    def __init__(self, vllm_base_url: str) -> None:
+        base = vllm_base_url.rstrip("/")
+        self._metrics_urls = (f"{base}/v1/metrics", f"{base}/metrics")
+
+        self._busy = False
+        self._last_ok_monotonic: Optional[float] = None
+        self._in_flight = 0
+
+        self._task: Optional[asyncio.Task[None]] = None
+        self._stop = asyncio.Event()
+
+    def _update_from_metrics(self, metrics_text: str) -> None:
+        waiting = _parse_num_requests_waiting(metrics_text)
+        now = time.monotonic()
+        self._busy = waiting > 0
+        self._last_ok_monotonic = now
+
+    def _backend_busy_or_stale(self) -> bool:
+        last_ok = self._last_ok_monotonic
+        if last_ok is None or (time.monotonic() - last_ok) > _metrics_max_age_seconds():
+            return False
+        return self._busy
+
+    def start_inference(self, user_tier_header: Optional[str]) -> Callable[[], None]:
+        tier = _normalize_user_tier(user_tier_header)
+        if tier != "premium":
+            if self._backend_busy_or_stale():
+                raise HTTPException(status_code=429, detail="Upstream vLLM is congested")
+            limit = _basic_max_request()
+            if self._in_flight >= limit:
+                raise HTTPException(status_code=429, detail="Too many concurrent requests")
+        self._in_flight += 1
+
+        finished = False
+
+        def finish() -> None:
+            nonlocal finished
+            if finished:
+                return
+            finished = True
+            self._in_flight = max(0, self._in_flight - 1)
+
+        return finish
+
+    async def start(self) -> None:
+        if self._task is not None:
+            return
+        self._stop.clear()
+        self._task = asyncio.create_task(self._run())
+
+    async def stop(self) -> None:
+        if self._task is None:
+            return
+        self._stop.set()
+        self._task.cancel()
+        try:
+            await self._task
+        except asyncio.CancelledError:
+            pass
+        except Exception:
+            pass
+        self._task = None
+
+    async def refresh_once(self) -> None:
+        timeout = httpx.Timeout(_metrics_timeout_seconds())
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            for url in self._metrics_urls:
+                try:
+                    resp = await client.get(url)
+                except Exception:
+                    continue
+                if resp.status_code == 404:
+                    continue
+                if resp.status_code != 200:
+                    continue
+                try:
+                    self._update_from_metrics(resp.text)
+                except Exception:
+                    continue
+                return
+
+    async def _run(self) -> None:
+        interval = _metrics_poll_interval_seconds()
+        while not self._stop.is_set():
+            try:
+                await self.refresh_once()
+            except Exception as exc:
+                log.debug(f"metrics refresh failed: {exc}")
+            try:
+                await asyncio.wait_for(self._stop.wait(), timeout=interval)
+            except asyncio.TimeoutError:
+                pass
+
+
+def _get_header(scope: Scope, name: bytes) -> Optional[bytes]:
+    for k, v in scope.get("headers", ()):
+        if k == name or k.lower() == name:
+            return v
+    return None
+
+
+def _authorized(scope: Scope) -> bool:
+    auth = _get_header(scope, b"authorization")
+    if not auth:
+        return False
+    try:
+        auth_str = auth.decode("latin-1")
+    except Exception:
+        return False
+    if not auth_str.startswith("Bearer "):
+        return False
+    token = auth_str.split("Bearer ", 1)[1]
+    return token == os.getenv("TOKEN")
+
+
+class TrafficControlMiddleware:
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if scope.get("type") != "http":
+            await self.app(scope, receive, send)
+            return
+
+        path = scope.get("path") or ""
+        method = scope.get("method") or ""
+        if method != "POST" or path not in ("/v1/chat/completions", "/v1/completions"):
+            await self.app(scope, receive, send)
+            return
+
+        if not _authorized(scope):
+            await self.app(scope, receive, send)
+            return
+
+        app_obj = scope.get("app")
+        traffic_control: Optional[TrafficControl] = None
+        if app_obj is not None and hasattr(app_obj, "state"):
+            traffic_control = getattr(app_obj.state, "traffic_control", None)
+        if traffic_control is None:
+            await self.app(scope, receive, send)
+            return
+
+        tier = _get_header(scope, b"x-user-tier")
+        tier_str = tier.decode("latin-1") if tier else None
+        try:
+            finish = traffic_control.start_inference(tier_str)
+        except HTTPException as exc:
+            response = JSONResponse(
+                status_code=exc.status_code,
+                content={"error": {"message": str(exc.detail), "type": "rate_limit"}},
+            )
+            await response(scope, receive, send)
+            return
+
+        finished = False
+
+        async def send_wrapper(message: Message) -> None:
+            nonlocal finished
+            await send(message)
+            if message.get("type") == "http.response.body" and not message.get("more_body", False):
+                if not finished:
+                    finished = True
+                    finish()
+
+        try:
+            await self.app(scope, receive, send_wrapper)
+        finally:
+            if not finished:
+                finish()
diff --git a/tests/app/test_end_to_end_chain.py b/tests/app/test_end_to_end_chain.py
index 632cb05..d794626 100644
--- a/tests/app/test_end_to_end_chain.py
+++ b/tests/app/test_end_to_end_chain.py
@@ -12,7 +12,6 @@
 import pytest
 from fastapi.testclient import TestClient
 
-from app.main import app
 from tests.app.test_helpers import TEST_AUTH_HEADER
 from tests.app.sample_dstack_data import NRAS_SAMPLE_RESPONSE, NRAS_SAMPLE_PPCIE_RESPONSE
 from verifiers.attestation_verifier import check_report_data, check_gpu, check_tdx_quote
@@ -22,6 +21,7 @@ def client():
     if not os.path.exists('/var/run/dstack.sock'):
         pytest.skip("Not in a real TEE environment.")
+    from app.main import app
     return TestClient(app)
 
 
 @pytest.mark.parametrize("nras_response", [NRAS_SAMPLE_RESPONSE, NRAS_SAMPLE_PPCIE_RESPONSE])
@@ -35,7 +35,11 @@ def test_chain_of_trust_end_to_end(client, nras_response):
     upstream_payload = {"id": "chatcmpl-test-001", "object": "chat.completion", "choices": [{"message": {"role": "assistant", "content": "Hi there!"}, "index": 0, "finish_reason": "stop"}]}
     respx.mock.post(vllm_url).mock(return_value=httpx.Response(200, json=upstream_payload))
 
-    response = client.post("/v1/chat/completions", json=request_payload, headers={"Authorization": TEST_AUTH_HEADER})
+    response = client.post(
+        "/v1/chat/completions",
+        json=request_payload,
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "premium"},
+    )
 
     assert response.status_code == 200
     chat_id = response.json()["id"]
diff --git a/tests/app/test_openai.py b/tests/app/test_openai.py
index 911bc04..9ea5e9d 100644
--- a/tests/app/test_openai.py
+++ b/tests/app/test_openai.py
@@ -1,8 +1,8 @@
 from unittest.mock import patch, AsyncMock
 import httpx
 import pytest
-from fastapi.testclient import TestClient
 import json
+import pytest_asyncio
 
 # Import and setup test environment before importing app
 from tests.app.test_helpers import setup_test_environment, TEST_AUTH_HEADER
@@ -20,7 +20,12 @@
 from app.api.v1.openai import VLLM_URL, VLLM_BASE_URL
 from tests.app.mock_quote import ED25519, ECDSA, ecdsa_quote, ed25519_quote
 
-client = TestClient(app)
+
+@pytest_asyncio.fixture
+async def client():
+    transport = httpx.ASGITransport(app=app)
+    async with httpx.AsyncClient(transport=transport, base_url="http://test") as async_client:
+        yield async_client
 
 
 async def yield_sse_response(data_list):
@@ -30,7 +35,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_stream_chat_completions_success(respx_mock):
+async def test_stream_chat_completions_success(respx_mock, client):
     # Test request data
     request_data = {
         "model": "test-model",
@@ -78,7 +83,7 @@
     )
 
     # Make request
-    response = client.post(
+    response = await client.post(
         "/v1/chat/completions",
         json=request_data,
         headers={"Authorization": TEST_AUTH_HEADER},
@@ -106,7 +111,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_stream_chat_completions_upstream_error(respx_mock):
+async def test_stream_chat_completions_upstream_error(respx_mock, client):
     # Test request data
     request_data = {
         "model": "test-model",
@@ -127,7 +132,7 @@
     )
 
     # Make request
-    response = client.post(
+    response = await client.post(
         "/v1/chat/completions",
         json=request_data,
         headers={"Authorization": TEST_AUTH_HEADER},
@@ -145,7 +150,7 @@
 
 
 @pytest.mark.asyncio
-async def test_signature_default_algo():
+async def test_signature_default_algo(client):
     # Setup test data
     chat_id = "test-chat-123"
     test_data = "test request:response data"
@@ -167,7 +172,7 @@
         mock_cache.get_chat.return_value = cache_data
 
         # Make request
-        response = client.get(
+        response = await client.get(
             f"/v1/signature/{chat_id}", headers={"Authorization": TEST_AUTH_HEADER}
         )
 
@@ -180,7 +185,7 @@
 
 
 @pytest.mark.asyncio
-async def test_signature_explicit_algo():
+async def test_signature_explicit_algo(client):
     # Setup test data
     chat_id = "test-chat-123"
     test_data = "test request:response data"
@@ -203,7 +208,7 @@
 
         # Make request with explicit algorithm
         explicit_algo = ED25519  # Use ED25519 explicitly
-        response = client.get(
+        response = await client.get(
             f"/v1/signature/{chat_id}?signing_algo={explicit_algo}",
             headers={"Authorization": TEST_AUTH_HEADER},
         )
@@ -217,7 +222,7 @@
 
 
 @pytest.mark.asyncio
-async def test_signature_invalid_algo():
+async def test_signature_invalid_algo(client):
     chat_id = "test-chat-123"
 
     # Create properly formatted cache data
@@ -236,7 +241,7 @@
         mock_cache.get_chat.return_value = cache_data
 
         # Make request with invalid algorithm
-        response = client.get(
+        response = await client.get(
             f"/v1/signature/{chat_id}?signing_algo=invalid-algo",
             headers={"Authorization": TEST_AUTH_HEADER},
         )
@@ -249,7 +254,7 @@
 
 
 @pytest.mark.asyncio
-async def test_signature_chat_not_found():
+async def test_signature_chat_not_found(client):
     chat_id = "nonexistent-chat"
 
     # Mock the cache to return None for chat not found
@@ -257,7 +262,7 @@
         mock_cache.get_chat.return_value = None
 
         # Make request
-        response = client.get(
+        response = await client.get(
             f"/v1/signature/{chat_id}", headers={"Authorization": TEST_AUTH_HEADER}
         )
 
@@ -270,7 +275,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_chat_completions_with_request_hash_streaming(respx_mock):
+async def test_chat_completions_with_request_hash_streaming(respx_mock, client):
     # Test request data
     request_data = {
         "model": "test-model",
@@ -327,7 +332,7 @@
     ) as mock_log:
 
         # Make request with X-Request-Hash header
-        response = client.post(
+        response = await client.post(
             "/v1/chat/completions",
             json=request_data,
             headers={
@@ -351,7 +356,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_chat_completions_with_request_hash_non_streaming(respx_mock):
+async def test_chat_completions_with_request_hash_non_streaming(respx_mock, client):
     # Test request data
     request_data = {
         "model": "test-model",
@@ -389,7 +394,7 @@
     ) as mock_log:
 
         # Make request with X-Request-Hash header
-        response = client.post(
+        response = await client.post(
             "/v1/chat/completions",
             json=request_data,
             headers={
@@ -413,7 +418,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_completions_with_request_hash_streaming(respx_mock):
+async def test_completions_with_request_hash_streaming(respx_mock, client):
     # Test request data
     request_data = {"model": "test-model", "prompt": "Hello", "stream": True}
 
@@ -461,7 +466,7 @@
     ) as mock_log:
 
         # Make request with X-Request-Hash header
-        response = client.post(
+        response = await client.post(
             "/v1/completions",
             json=request_data,
             headers={
@@ -485,7 +490,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_completions_with_request_hash_non_streaming(respx_mock):
+async def test_completions_with_request_hash_non_streaming(respx_mock, client):
     # Test request data
     request_data = {"model": "test-model", "prompt": "Hello", "stream": False}
 
@@ -513,7 +518,7 @@ async def test_completions_with_request_hash_non_streaming(respx_mock):
     ) as mock_log:
 
         # Make request with X-Request-Hash header
-        response = client.post(
+        response = await client.post(
             "/v1/completions",
             json=request_data,
             headers={
@@ -537,7 +542,7 @@
 
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_chat_completions_without_request_hash(respx_mock):
+async def test_chat_completions_without_request_hash(respx_mock, client):
     # Test request data without X-Request-Hash header
     request_data = {
         "model": "test-model",
@@ -572,7 +577,7 @@
     ) as mock_log:
 
         # Make request without X-Request-Hash header
-        response = client.post(
+        response = await client.post(
             "/v1/chat/completions",
             json=request_data,
             headers={"Authorization": TEST_AUTH_HEADER},
diff --git a/tests/app/test_traffic_control_integration.py b/tests/app/test_traffic_control_integration.py
new file mode 100644
index 0000000..a087a71
--- /dev/null
+++ b/tests/app/test_traffic_control_integration.py
@@ -0,0 +1,272 @@
+import asyncio
+import sys
+import time
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Any
+from unittest.mock import patch
+
+import httpx
+import pytest
+import pytest_asyncio
+from fastapi import FastAPI, Request
+from fastapi.responses import PlainTextResponse
+
+from tests.app.test_helpers import TEST_AUTH_HEADER, setup_test_environment
+
+
+setup_test_environment()
+sys.modules["app.quote.quote"] = __import__("tests.app.mock_quote", fromlist=[""])
+
+from app.main import app  # noqa: E402
+from app.traffic_control import TrafficControl  # noqa: E402
+from app.api.v1 import openai as openai_module  # noqa: E402
+
+
+@dataclass
+class MockVLLMState:
+    num_requests_waiting: float = 0.0
+    metrics_enabled: bool = True
+    delay_seconds: float = 0.0
+    chat_calls: int = 0
+    completion_calls: int = 0
+    last_chat_payload: dict[str, Any] | None = None
+    last_completion_payload: dict[str, Any] | None = None
+    seen_models: list[str] = field(default_factory=list)
+
+    def reset(self) -> None:
+        self.num_requests_waiting = 0.0
+        self.metrics_enabled = True
+        self.delay_seconds = 0.0
+        self.chat_calls = 0
+        self.completion_calls = 0
+        self.last_chat_payload = None
+        self.last_completion_payload = None
+        self.seen_models.clear()
+
+
+mock_vllm_state = MockVLLMState()
+mock_vllm_app = FastAPI()
+
+
+def _metrics_text() -> str:
+    model = mock_vllm_state.seen_models[-1] if mock_vllm_state.seen_models else "test-model"
+    return (
+        "# HELP vllm:num_requests_waiting Number of requests waiting\n"
+        "# TYPE vllm:num_requests_waiting gauge\n"
+        f'vllm:num_requests_waiting{{engine="0",model_name="{model}"}} {mock_vllm_state.num_requests_waiting}\n'
+    )
+
+
+@mock_vllm_app.get("/metrics")
+@mock_vllm_app.get("/v1/metrics")
+async def metrics() -> PlainTextResponse:
+    if not mock_vllm_state.metrics_enabled:
+        return PlainTextResponse("not found", status_code=404)
+    return PlainTextResponse(_metrics_text())
+
+
+@mock_vllm_app.post("/v1/chat/completions")
+async def chat_completions(request: Request) -> dict[str, Any]:
+    payload = await request.json()
+    mock_vllm_state.chat_calls += 1
+    mock_vllm_state.last_chat_payload = payload
+    model = payload.get("model", "test-model")
+    mock_vllm_state.seen_models.append(model)
+    if mock_vllm_state.delay_seconds > 0:
+        await asyncio.sleep(mock_vllm_state.delay_seconds)
+    return {
+        "id": f"chatcmpl-mock-{int(time.time() * 1000)}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [{"index": 0, "finish_reason": "stop", "message": {"role": "assistant", "content": "ok"}}],
+    }
+
+
+@mock_vllm_app.post("/v1/completions")
+async def completions(request: Request) -> dict[str, Any]:
+    payload = await request.json()
+    mock_vllm_state.completion_calls += 1
+    mock_vllm_state.last_completion_payload = payload
+    model = payload.get("model", "test-model")
+    mock_vllm_state.seen_models.append(model)
+    if mock_vllm_state.delay_seconds > 0:
+        await asyncio.sleep(mock_vllm_state.delay_seconds)
+    return {
+        "id": f"cmpl-mock-{int(time.time() * 1000)}",
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [{"index": 0, "finish_reason": "stop", "text": "ok"}],
+    }
+
+
+@pytest.fixture
+def mock_vllm_base_url(monkeypatch) -> str:
+    base_url = "http://mock-vllm"
+    monkeypatch.setattr(openai_module, "VLLM_BASE_URL", base_url)
+    monkeypatch.setattr(openai_module, "VLLM_URL", f"{base_url}/v1/chat/completions")
+    monkeypatch.setattr(openai_module, "VLLM_COMPLETIONS_URL", f"{base_url}/v1/completions")
+    monkeypatch.setattr(openai_module, "VLLM_METRICS_URL", f"{base_url}/metrics")
+    monkeypatch.setattr(openai_module, "VLLM_MODELS_URL", f"{base_url}/v1/models")
+    return base_url
+
+
+@pytest_asyncio.fixture
+async def proxy_client(mock_vllm_base_url):
+    mock_vllm_state.reset()
+    vllm_transport = httpx.ASGITransport(app=mock_vllm_app)
+    original_async_client = httpx.AsyncClient
+    patched_async_client = partial(original_async_client, transport=vllm_transport)
+
+    with patch("httpx.AsyncClient", patched_async_client):
+        traffic_control = TrafficControl(mock_vllm_base_url)
+        old_traffic_control = getattr(app.state, "traffic_control", None)
+        app.state.traffic_control = traffic_control
+
+        transport = httpx.ASGITransport(app=app)
+        async with original_async_client(transport=transport, base_url="http://test") as client:
+            yield client
+        app.state.traffic_control = old_traffic_control
+
+
+@pytest.mark.asyncio
+async def test_basic_rejected_when_backend_busy(proxy_client):
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "basic"},
+    )
+
+    assert resp.status_code == 429
+    assert mock_vllm_state.chat_calls == 0
+
+
+@pytest.mark.asyncio
+async def test_missing_tier_defaults_to_basic(proxy_client):
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER},
+    )
+
+    assert resp.status_code == 429
+    assert mock_vllm_state.chat_calls == 0
+
+
+@pytest.mark.asyncio
+async def test_unknown_tier_treated_as_basic(proxy_client):
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "unknown"},
+    )
+
+    assert resp.status_code == 429
+    assert mock_vllm_state.chat_calls == 0
+
+
+@pytest.mark.asyncio
+async def test_premium_always_passes_even_when_backend_busy(proxy_client):
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "premium"},
+    )
+
+    assert resp.status_code == 200
+    assert mock_vllm_state.chat_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_basic_rejected_when_over_concurrency_limit(proxy_client, monkeypatch):
+    mock_vllm_state.num_requests_waiting = 0.0
+    mock_vllm_state.delay_seconds = 0.25
+    monkeypatch.setenv("BASIC_MAX_REQUEST", "1")
+    await app.state.traffic_control.refresh_once()
+
+    payload = {"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}
+    headers = {"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "basic"}
+
+    t1 = asyncio.create_task(proxy_client.post("/v1/chat/completions", json=payload, headers=headers))
+    await asyncio.sleep(0.05)
+    t2 = asyncio.create_task(proxy_client.post("/v1/chat/completions", json=payload, headers=headers))
+    r1, r2 = await asyncio.gather(t1, t2)
+
+    assert sorted([r1.status_code, r2.status_code]) == [200, 429]
+    assert mock_vllm_state.chat_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_traffic_control_applies_to_completions(proxy_client):
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    resp = await proxy_client.post(
+        "/v1/completions",
+        json={"model": "test-model", "prompt": "hi"},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "basic"},
+    )
+
+    assert resp.status_code == 429
+    assert mock_vllm_state.completion_calls == 0
+
+
+@pytest.mark.asyncio
+async def test_fail_open_when_metrics_unknown(proxy_client):
+    mock_vllm_state.num_requests_waiting = 1.0
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "basic"},
+    )
+
+    assert resp.status_code == 200
+    assert mock_vllm_state.chat_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_fail_open_when_no_metrics_endpoint(proxy_client):
+    mock_vllm_state.metrics_enabled = False
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "basic"},
+    )
+
+    assert resp.status_code == 200
+    assert mock_vllm_state.chat_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_fail_open_when_metrics_stale(proxy_client, monkeypatch):
+    mock_vllm_state.num_requests_waiting = 1.0
+    await app.state.traffic_control.refresh_once()
+
+    monkeypatch.setattr(app.state.traffic_control, "_busy", True)
+    monkeypatch.setattr(app.state.traffic_control, "_last_ok_monotonic", time.monotonic() - 3600)
+
+    resp = await proxy_client.post(
+        "/v1/chat/completions",
+        json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
+        headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "basic"},
+    )
+
+    assert resp.status_code == 200
+    assert mock_vllm_state.chat_calls == 1
diff --git a/tests/conftest.py b/tests/conftest.py
index da962ba..6cb1ac1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@
 # Set required environment variables before any imports
 os.environ["MODEL_NAME"] = "test-model"
-os.environ["VLLM_BASE_URL"] = "http://localhost:8001" 
+os.environ["VLLM_BASE_URL"] = "http://localhost:8001"
 os.environ["CHAT_CACHE_EXPIRATION"] = "1200"
 os.environ["REDIS_HOST"] = "localhost"
 os.environ["REDIS_PORT"] = "6379"
@@ -19,4 +19,4 @@
 # Pytest configuration
 def pytest_configure(config):
     """Configure pytest with custom markers."""
-    config.addinivalue_line("markers", "asyncio: mark test as an asyncio test")
\ No newline at end of file
+    config.addinivalue_line("markers", "asyncio: mark test as an asyncio test")