Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/app/api/helper/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

from fastapi import Header, HTTPException

TOKEN = os.getenv("TOKEN")

# Dependency to verify the Authorization header
def verify_authorization_header(authorization: str = Header("Authorization")):
async def verify_authorization_header(authorization: str | None = Header(default=None)) -> str:
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(
status_code=401, detail="Invalid or missing Authorization header"
)
token = authorization.split("Bearer ")[1]
if token != TOKEN:
expected = os.getenv("TOKEN")
if not expected or token != expected:
raise HTTPException(status_code=401, detail="Invalid token")
return token
31 changes: 28 additions & 3 deletions src/app/main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,46 @@
import os
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

from .api import router as api_router
from .api.response.response import ok, error, http_exception
from .logger import log
from .traffic_control import TrafficControl, TrafficControlMiddleware


@asynccontextmanager
async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan: ensure a TrafficControl instance exists on app.state.

    If a controller was pre-installed on ``app.state`` (e.g. by tests), it is
    reused and its lifecycle is NOT managed here; otherwise one is created,
    started on entry, and stopped on shutdown.
    """
    traffic_control: TrafficControl | None = getattr(app.state, "traffic_control", None)
    # We only own (start/stop) the controller when we created it ourselves.
    owns_traffic_control = traffic_control is None
    if owns_traffic_control:
        vllm_base_url = os.getenv("VLLM_BASE_URL", "http://vllm:8000")
        traffic_control = TrafficControl(vllm_base_url)
        app.state.traffic_control = traffic_control
        await traffic_control.start()
    # Type-narrowing only: both branches above guarantee a controller here.
    assert traffic_control is not None
    try:
        yield
    finally:
        # Never stop a controller that was injected from outside.
        if owns_traffic_control:
            await traffic_control.stop()


app = FastAPI()
app = FastAPI(lifespan=_lifespan)
app.add_middleware(TrafficControlMiddleware)
app.include_router(api_router)


@app.get("/")
async def root():
async def root() -> dict:
return ok()


# Custom global error handler
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""
Handle all uncaught exceptions globally.
"""
Expand Down
239 changes: 239 additions & 0 deletions src/app/traffic_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import asyncio
import os
import time
from collections.abc import Callable
from typing import Optional

import httpx
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from app.logger import log


def _parse_num_requests_waiting(metrics_text: str) -> float:
    """Return the largest ``vllm:num_requests_waiting`` sample in a Prometheus
    text exposition, or 0.0 when no sample parses.

    Lines that start with the metric name but whose last whitespace-separated
    field is not a number are skipped.
    """
    waiting = 0.0
    for metric_line in metrics_text.splitlines():
        if not metric_line.startswith("vllm:num_requests_waiting"):
            continue
        fields = metric_line.split()
        try:
            sample = float(fields[-1])
        except (IndexError, ValueError):
            continue
        waiting = max(waiting, sample)
    return waiting


def _normalize_user_tier(raw: Optional[str]) -> str:
    """Normalize a raw ``X-User-Tier`` header value to ``"premium"`` or
    ``"basic"``; anything missing, empty, or unrecognized is basic."""
    is_premium = bool(raw) and raw.strip().lower() == "premium"
    return "premium" if is_premium else "basic"


def _basic_max_request() -> int:
    """Maximum concurrent basic-tier requests.

    Read from env ``BASIC_MAX_REQUEST`` (default 20); non-integer values fall
    back to the default, and the result is floored at 1.
    """
    configured = os.getenv("BASIC_MAX_REQUEST", "20")
    try:
        limit = int(configured.strip())
    except ValueError:
        limit = 20
    return limit if limit >= 1 else 1


def _metrics_timeout_seconds() -> float:
    """Per-request timeout for a metrics scrape.

    Read from env ``METRICS_TIMEOUT_SECONDS`` (default 2); unparsable values
    fall back to the default. NOTE(review): the value is not clamped, so a
    negative setting would be passed through to httpx as-is — confirm intended.
    """
    configured = os.getenv("METRICS_TIMEOUT_SECONDS", "2")
    try:
        return float(configured.strip())
    except ValueError:
        return 2.0


def _metrics_poll_interval_seconds() -> float:
    """Sleep interval between metrics polls.

    Read from env ``METRICS_POLL_INTERVAL_SECONDS`` (default 2); unparsable
    values fall back to the default. Clamped to [0.5, 10.0] so a bad setting
    can neither busy-spin the poller nor leave the metrics permanently stale.
    """
    configured = os.getenv("METRICS_POLL_INTERVAL_SECONDS", "2")
    try:
        seconds = float(configured.strip())
    except ValueError:
        seconds = 2.0
    return max(0.5, min(10.0, seconds))


def _metrics_max_age_seconds() -> float:
    """Maximum age before a metrics sample is considered stale.

    Read from env ``METRICS_MAX_AGE_SECONDS`` (default 30); unparsable values
    fall back to the default. Clamped to [10.0, 30.0].
    """
    configured = os.getenv("METRICS_MAX_AGE_SECONDS", "30")
    try:
        seconds = float(configured.strip())
    except ValueError:
        seconds = 30.0
    return max(10.0, min(30.0, seconds))


class TrafficControl:
    """Tracks vLLM backend congestion and enforces per-tier admission control.

    A background task polls the backend's Prometheus metrics endpoint and
    records whether any requests are queued (``vllm:num_requests_waiting``).
    ``start_inference`` admits premium requests unconditionally; basic
    requests are rejected with HTTP 429 when the backend is congested or the
    concurrent in-flight count reaches ``BASIC_MAX_REQUEST``.

    Not thread-safe; intended for use from a single asyncio event loop.
    """

    def __init__(self, vllm_base_url: str) -> None:
        base = vllm_base_url.rstrip("/")
        # Candidate metrics endpoints, tried in order per refresh.
        self._metrics_urls = (f"{base}/v1/metrics", f"{base}/metrics")

        # True when the most recent successful scrape showed queued requests.
        self._busy = False
        # Monotonic timestamp of the last successful scrape (None = never).
        self._last_ok_monotonic: Optional[float] = None
        # Count of admitted, not-yet-finished inference requests (all tiers).
        self._in_flight = 0

        self._task: Optional[asyncio.Task[None]] = None
        self._stop = asyncio.Event()

    def _update_from_metrics(self, metrics_text: str) -> None:
        """Record backend queue state from a successful metrics scrape."""
        waiting = _parse_num_requests_waiting(metrics_text)
        self._busy = waiting > 0
        self._last_ok_monotonic = time.monotonic()

    def _backend_busy_or_stale(self) -> bool:
        """Return True only when a *fresh* sample reports queued requests.

        Missing or stale samples return False — i.e. this fails open and
        admits traffic when the backend cannot be observed.
        """
        last_ok = self._last_ok_monotonic
        if last_ok is None or (time.monotonic() - last_ok) > _metrics_max_age_seconds():
            return False
        return self._busy

    def start_inference(self, user_tier_header: Optional[str]) -> Callable[[], None]:
        """Admit one inference request and return an idempotent finish callback.

        :param user_tier_header: raw ``X-User-Tier`` header value (or None).
        :raises HTTPException: 429 for basic-tier requests when the backend is
            congested or the basic concurrency limit is reached.
        """
        tier = _normalize_user_tier(user_tier_header)
        if tier != "premium":
            if self._backend_busy_or_stale():
                raise HTTPException(status_code=429, detail="Upstream vLLM is congested")
            if self._in_flight >= _basic_max_request():
                raise HTTPException(status_code=429, detail="Too many concurrent requests")
        # Premium requests are admitted unconditionally but still counted,
        # so they contribute to the limit seen by basic requests.
        self._in_flight += 1

        finished = False

        def finish() -> None:
            # Idempotent: callers may invoke this from both a response hook
            # and a finally block without double-decrementing.
            nonlocal finished
            if finished:
                return
            finished = True
            self._in_flight = max(0, self._in_flight - 1)

        return finish

    async def start(self) -> None:
        """Start the background metrics poller (no-op if already running)."""
        if self._task is not None:
            return
        self._stop.clear()
        self._task = asyncio.create_task(self._run())

    async def stop(self) -> None:
        """Stop the poller and wait for it to exit (no-op if not running)."""
        if self._task is None:
            return
        self._stop.set()
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            # A crashed poller must not break shutdown, but leave a trace
            # instead of swallowing the error silently.
            log.debug(f"metrics poller exited with error: {exc}")
        self._task = None

    async def refresh_once(self) -> None:
        """Scrape backend metrics once, trying each candidate URL in order.

        Network errors, non-200 responses (including 404 on servers that lack
        a given metrics path), and parse failures all fall through to the next
        URL; the method never raises for an unreachable backend.
        """
        timeout = httpx.Timeout(_metrics_timeout_seconds())
        async with httpx.AsyncClient(timeout=timeout) as client:
            for url in self._metrics_urls:
                try:
                    resp = await client.get(url)
                except Exception:
                    continue
                if resp.status_code != 200:
                    continue
                try:
                    self._update_from_metrics(resp.text)
                except Exception:
                    continue
                return

    async def _run(self) -> None:
        """Poll metrics at a fixed interval until stop() is requested."""
        interval = _metrics_poll_interval_seconds()
        while not self._stop.is_set():
            try:
                await self.refresh_once()
            except Exception as exc:
                log.debug(f"metrics refresh failed: {exc}")
            try:
                # Sleep, but wake immediately when stop() is signalled.
                await asyncio.wait_for(self._stop.wait(), timeout=interval)
            except asyncio.TimeoutError:
                pass


def _get_header(scope: Scope, name: bytes) -> Optional[bytes]:
    """Return the first header value in *scope* whose name matches *name*
    exactly or after lowercasing, or None when absent.

    *name* is expected as lowercase bytes (ASGI header names are bytes).
    """
    for header_name, header_value in scope.get("headers", ()):
        if header_name == name or header_name.lower() == name:
            return header_value
    return None


def _authorized(scope: Scope) -> bool:
    """Check the request's ``Authorization: Bearer <token>`` header against
    the ``TOKEN`` environment variable.

    Fails closed when ``TOKEN`` is unset *or empty*: previously an empty
    TOKEN combined with a bare ``Bearer `` header compared ``"" == ""`` and
    authorized the request. This also mirrors the check in
    ``verify_authorization_header`` (api/helper/auth.py).
    """
    auth = _get_header(scope, b"authorization")
    if not auth:
        return False
    try:
        auth_str = auth.decode("latin-1")
    except Exception:
        return False
    if not auth_str.startswith("Bearer "):
        return False
    token = auth_str.split("Bearer ", 1)[1]
    expected = os.getenv("TOKEN")
    # An unset or empty TOKEN must never authorize anything.
    return bool(expected) and token == expected


class TrafficControlMiddleware:
    """ASGI middleware applying TrafficControl admission to inference POSTs.

    Only authorized POST requests to /v1/chat/completions and /v1/completions
    are subject to admission control; everything else passes straight through.
    Rejections are answered directly with a 429 JSON body without invoking
    the wrapped app.
    """

    def __init__(self, app: ASGIApp) -> None:
        # The downstream ASGI application to wrap.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Non-HTTP scopes (lifespan, websocket) are never rate limited.
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        path = scope.get("path") or ""
        method = scope.get("method") or ""
        # Only the inference endpoints are subject to admission control.
        if method != "POST" or path not in ("/v1/chat/completions", "/v1/completions"):
            await self.app(scope, receive, send)
            return

        # Unauthorized requests pass through so the route's own auth
        # dependency can produce the proper 401 response.
        if not _authorized(scope):
            await self.app(scope, receive, send)
            return

        # Look up the controller installed on app.state by the lifespan.
        app_obj = scope.get("app")
        traffic_control: Optional[TrafficControl] = None
        if app_obj is not None and hasattr(app_obj, "state"):
            traffic_control = getattr(app_obj.state, "traffic_control", None)
        # No controller registered (e.g. lifespan not run) -> fail open.
        if traffic_control is None:
            await self.app(scope, receive, send)
            return

        tier = _get_header(scope, b"x-user-tier")
        tier_str = tier.decode("latin-1") if tier else None
        try:
            # finish() releases the admission slot; it is idempotent.
            finish = traffic_control.start_inference(tier_str)
        except HTTPException as exc:
            # Admission rejected: answer 429 ourselves, OpenAI-style error body.
            response = JSONResponse(
                status_code=exc.status_code,
                content={"error": {"message": str(exc.detail), "type": "rate_limit"}},
            )
            await response(scope, receive, send)
            return

        finished = False

        async def send_wrapper(message: Message) -> None:
            # Release the slot as soon as the final response body chunk
            # has been sent downstream.
            nonlocal finished
            await send(message)
            if message.get("type") == "http.response.body" and not message.get("more_body", False):
                if not finished:
                    finished = True
                    finish()

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Safety net: release the slot even if the app raised before
            # completing the response (finish() itself is idempotent too).
            if not finished:
                finish()
8 changes: 6 additions & 2 deletions tests/app/test_end_to_end_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pytest
from fastapi.testclient import TestClient

from app.main import app
from tests.app.test_helpers import TEST_AUTH_HEADER
from tests.app.sample_dstack_data import NRAS_SAMPLE_RESPONSE, NRAS_SAMPLE_PPCIE_RESPONSE
from verifiers.attestation_verifier import check_report_data, check_gpu, check_tdx_quote
Expand All @@ -22,6 +21,7 @@
def client():
if not os.path.exists('/var/run/dstack.sock'):
pytest.skip("Not in a real TEE environment.")
from app.main import app
return TestClient(app)

@pytest.mark.parametrize("nras_response", [NRAS_SAMPLE_RESPONSE, NRAS_SAMPLE_PPCIE_RESPONSE])
Expand All @@ -35,7 +35,11 @@ def test_chain_of_trust_end_to_end(client, nras_response):
upstream_payload = {"id": "chatcmpl-test-001", "object": "chat.completion", "choices": [{"message": {"role": "assistant", "content": "Hi there!"}, "index": 0, "finish_reason": "stop"}]}
respx.mock.post(vllm_url).mock(return_value=httpx.Response(200, json=upstream_payload))

response = client.post("/v1/chat/completions", json=request_payload, headers={"Authorization": TEST_AUTH_HEADER})
response = client.post(
"/v1/chat/completions",
json=request_payload,
headers={"Authorization": TEST_AUTH_HEADER, "X-User-Tier": "premium"},
)
assert response.status_code == 200
chat_id = response.json()["id"]

Expand Down
Loading