diff --git a/docs/guards.md b/docs/guards.md new file mode 100644 index 0000000..0317356 --- /dev/null +++ b/docs/guards.md @@ -0,0 +1,565 @@ +# Guards and Authentication + +PyNest supports route guards similar to NestJS, providing a powerful way to implement authentication and authorization logic. Guards are fully compatible with FastAPI's security system and automatically integrate with OpenAPI documentation. + +## Overview + +Guards are classes that implement custom authorization logic and can be applied to controllers or individual routes using the `@UseGuards` decorator. When a guard defines a FastAPI security scheme via the `security_scheme` attribute, the generated OpenAPI schema will mark the route as protected and the interactive docs will show an "Authorize" button. + +## Basic Guard Example + +```python +from fastapi import Request +from nest.core import Controller, Get, UseGuards, BaseGuard + +class AuthGuard(BaseGuard): + def can_activate(self, request: Request, credentials=None) -> bool: + token = request.headers.get("X-Token") + return token == "secret" + +@Controller("/items") +@UseGuards(AuthGuard) +class ItemsController: + @Get("/") + def list_items(self): + return ["a", "b"] +``` + +When the guard returns `False`, a `403 Forbidden` response is sent automatically. + +## FastAPI Security Integration + +PyNest guards support all FastAPI security schemes and automatically appear in OpenAPI documentation: + +### API Key Authentication + +#### API Key in Header (Most Common) + +```python +from fastapi.security import APIKeyHeader + +class APIKeyGuard(BaseGuard): + security_scheme = APIKeyHeader( + name="X-API-Key", + description="API key required for authentication" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + # credentials contains the API key value + valid_keys = {"admin-key-123", "user-key-456"} + return credentials in valid_keys +``` + +#### API Key in Query Parameter + +```python +from fastapi.security import APIKeyQuery + +class APIKeyQueryGuard(BaseGuard): + security_scheme = APIKeyQuery( + name="api_key", + description="API key as query parameter (?api_key=your-key)" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + return credentials == "secret-query-key" +``` + +#### API Key in Cookie + +```python +from fastapi.security import APIKeyCookie + +class SessionGuard(BaseGuard): + security_scheme = APIKeyCookie( + name="session_token", + description="Session token stored in cookie" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + valid_sessions = {"sess_abc123", "sess_def456"} + return credentials in valid_sessions +``` + +### HTTP Authentication + +#### Basic Authentication + +```python +from fastapi.security import HTTPBasic +from fastapi.security.http import HTTPBasicCredentials + +class BasicAuthGuard(BaseGuard): + security_scheme = HTTPBasic( + description="Username and password authentication" + ) + + def can_activate(self, request: Request, credentials: HTTPBasicCredentials = None) -> bool: + if not credentials: + return False + + # In production, use hashed passwords + users = {"admin": "admin123", "user": "user456"} + expected_password = users.get(credentials.username) + return expected_password == credentials.password +``` + +#### Bearer Token Authentication + +```python +from fastapi.security import HTTPBearer +from fastapi.security.http import HTTPAuthorizationCredentials + +class BearerTokenGuard(BaseGuard): + security_scheme = HTTPBearer(description="Bearer token authentication") + + def can_activate(self, request: Request, credentials: HTTPAuthorizationCredentials = None) -> bool: + if not credentials or credentials.scheme != "Bearer": + return False + + token = credentials.credentials + return self.validate_jwt_token(token) + + def validate_jwt_token(self, token: str) -> bool: + # Implement proper JWT validation + return token.startswith("eyJ") and len(token) > 20 +``` + +## JWT Authentication Example + +Complete JWT implementation using third-party libraries: + +```python +import jwt +from fastapi import Request +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from nest.core import BaseGuard + +class JWTGuard(BaseGuard): + security_scheme = HTTPBearer(description="JWT Bearer token") + + def can_activate( + self, request: Request, credentials: HTTPAuthorizationCredentials = None + ) -> bool: + if not credentials: + return False + + try: + payload = jwt.decode( + credentials.credentials, "your-secret", algorithms=["HS256"] + ) + # Attach user info to request for use in controllers + request.state.user = payload.get("sub") + return True + except jwt.PyJWTError: + return False +``` + +Attach the guard with `@UseGuards(JWTGuard)` on controllers or routes to secure them. Because `JWTGuard` specifies a `security_scheme`, the route will display a lock icon in the docs and allow entering a token. + +## OAuth2 Authentication + +### Basic OAuth2 Password Bearer + +```python +from fastapi.security import OAuth2PasswordBearer + +class OAuth2Guard(BaseGuard): + security_scheme = OAuth2PasswordBearer( + tokenUrl="auth/token", + description="OAuth2 password bearer token" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + if not credentials: + return False + # Validate OAuth2 token with your auth server + return self.validate_oauth2_token(credentials) + + def validate_oauth2_token(self, token: str) -> bool: + # Implement OAuth2 token validation + valid_tokens = {"oauth2_token_123", "oauth2_token_456"} + return token in valid_tokens +``` + +### OAuth2 with Scopes (Fine-grained Permissions) + +```python +class OAuth2ScopesGuard(BaseGuard): + security_scheme = OAuth2PasswordBearer( + tokenUrl="auth/token", + scopes={ + "read": "Read access to resources", + "write": "Write access to resources", + "admin": "Full administrative access" + } + ) + + def __init__(self, required_scopes: list = None): + self.required_scopes = required_scopes or [] + + def can_activate(self, request: Request, credentials=None) -> bool: + if not credentials: + return False + + user_scopes = self.get_token_scopes(credentials) + return all(scope in user_scopes for scope in self.required_scopes) + + def get_token_scopes(self, token: str) -> list: + # Extract scopes from token + token_scopes = { + "admin_token": ["read", "write", "admin"], + "user_token": ["read", "write"], + "readonly_token": ["read"] + } + return token_scopes.get(token, []) + +# Usage with specific scopes +@Controller("admin") +@UseGuards(OAuth2ScopesGuard(["admin"])) +class AdminController: + @Get("/users") + def list_users(self): + return {"users": ["user1", "user2"]} +``` + +## Controller vs. Route Guards + +You can attach guards at the controller level so they apply to every route in the controller. Individual routes can also specify their own guards. + +```python +@Controller('/admin') +@UseGuards(AdminGuard) +class AdminController: + @Get('/dashboard') + def dashboard(self): + return {'ok': True} + + @Post('/login') + @UseGuards(PublicOnlyGuard) # Overrides controller guard + def login(self): + return {'logged_in': True} +``` + +In this example `AdminGuard` protects all routes while `PublicOnlyGuard` is applied only to the `login` route. + +## Combining Multiple Guards + +`UseGuards` accepts any number of guard classes. All specified guards must return `True` in order for the request to proceed. + +```python +class TokenGuard(BaseGuard): + security_scheme = APIKeyHeader(name="X-Token") + + def can_activate(self, request: Request, credentials=None) -> bool: + return credentials == "secret" + +class RoleGuard(BaseGuard): + security_scheme = HTTPBearer(description="JWT with role info") + + def can_activate(self, request: Request, credentials=None) -> bool: + # Extract role from JWT token + user_role = self.get_user_role(credentials.credentials) + return user_role == "admin" + +@Controller('/secure') +class SecureController: + @Get('/') + @UseGuards(TokenGuard, RoleGuard) # Both guards must pass + def root(self): + return {'ok': True} +``` + +## Role-Based Access Control + +```python +class RoleBasedGuard(BaseGuard): + security_scheme = HTTPBearer(description="JWT token with role information") + + def __init__(self, allowed_roles: list): + self.allowed_roles = allowed_roles + + def can_activate(self, request: Request, credentials=None) -> bool: + if not credentials: + return False + + user_roles = self.get_user_roles(credentials.credentials) + return any(role in user_roles for role in self.allowed_roles) + + def get_user_roles(self, token: str) -> list: + # Extract roles from JWT or database + role_mapping = { + "admin_token": ["admin", "user"], + "user_token": ["user"], + "guest_token": ["guest"] + } + return role_mapping.get(token, []) + +# Usage +@Controller("api/users") +class UserController: + @Get("/") + @UseGuards(RoleBasedGuard(["user", "admin"])) + def list_users(self): + return {"users": []} + + @Delete("/{user_id}") + @UseGuards(RoleBasedGuard(["admin"])) # Admin only + def delete_user(self, user_id: int): + return {"deleted": user_id} +``` + +## Asynchronous Guards + +Guards can perform asynchronous checks by making `can_activate` async or returning an awaitable: + +```python +class AsyncGuard(BaseGuard): + security_scheme = APIKeyHeader(name="X-Auth-Token") + + async def can_activate(self, request: Request, credentials=None) -> bool: + if not credentials: + return False + + # Async database lookup + user = await self.get_user_from_db(credentials) + return user is not None and user.get("is_active", False) + + async def get_user_from_db(self, token: str): + # Simulate async database call + import asyncio + await asyncio.sleep(0.1) + + users = { + "valid_token_123": {"id": 1, "is_active": True}, + "expired_token_456": {"id": 2, "is_active": False} + } + return users.get(token) +``` + +PyNest automatically awaits the result. + +## Custom Guards Without Security Schemes + +Guards don't always need security schemes. They can implement custom logic like rate limiting: + +```python +from datetime import datetime, timedelta + +class RateLimitGuard(BaseGuard): + # No security_scheme - won't appear in OpenAPI docs + + def __init__(self, max_requests: int = 100, window_minutes: int = 60): + self.max_requests = max_requests + self.window_minutes = window_minutes + self.request_counts = {} + + def can_activate(self, request: Request, credentials=None) -> bool: + client_ip = request.client.host + now = datetime.now() + + # Clean old entries + cutoff = now - timedelta(minutes=self.window_minutes) + + if client_ip not in self.request_counts: + self.request_counts[client_ip] = [] + + # Filter recent requests + recent_requests = [ + t for t in self.request_counts[client_ip] + if t > cutoff + ] + + if len(recent_requests) >= self.max_requests: + return False + + recent_requests.append(now) + self.request_counts[client_ip] = recent_requests + return True + +@Controller("api") +@UseGuards(APIKeyGuard, RateLimitGuard) # API key + rate limiting +class APIController: + @Get("/data") + def get_data(self): + return {"data": "protected and rate limited"} +``` + +## Multi-Method Authentication + +Guards can accept multiple authentication methods: + +```python +class MultiAuthGuard(BaseGuard): + # Primary security scheme for OpenAPI docs + security_scheme = HTTPBearer(description="Bearer token or API key") + + def can_activate(self, request: Request, credentials=None) -> bool: + # Method 1: Bearer token from security scheme + if credentials and self.validate_bearer(credentials.credentials): + return True + + # Method 2: API key in custom header + api_key = request.headers.get("X-API-Key") + if api_key and self.validate_api_key(api_key): + return True + + # Method 3: Session cookie + session = request.cookies.get("session_id") + if session and self.validate_session(session): + return True + + return False + + def validate_bearer(self, token: str) -> bool: + return token in ["jwt-token-1", "jwt-token-2"] + + def validate_api_key(self, key: str) -> bool: + return key in ["api-key-1", "api-key-2"] + + def validate_session(self, session: str) -> bool: + return session in ["session-1", "session-2"] +``` + +## Custom Error Handling + +Override the `__call__` method for custom error responses: + +```python +import inspect +from datetime import datetime + +class CustomErrorGuard(BaseGuard): + security_scheme = APIKeyHeader(name="X-Custom-Key") + + async def __call__(self, request: Request, credentials=None): + try: + result = self.can_activate(request, credentials) + if inspect.isawaitable(result): + result = await result + + if not result: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={ + "error": "INVALID_API_KEY", + "message": "The provided API key is invalid or expired", + "code": "AUTH_001", + "timestamp": datetime.now().isoformat() + }, + headers={"WWW-Authenticate": "ApiKey"} + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "AUTHENTICATION_ERROR", "message": str(e)} + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + return credentials == "valid-custom-key" +``` + +## OpenAPI Integration + +When a guard sets the `security_scheme` attribute, the generated OpenAPI schema includes the corresponding security requirement. The docs page will show: + +- 🔒 Lock icon next to protected routes +- "Authorize" button in the top right +- Input fields for tokens/credentials +- Security requirements in the route documentation + +This works with any `fastapi.security` scheme: +- `APIKeyHeader`, `APIKeyQuery`, `APIKeyCookie` +- `HTTPBasic`, `HTTPBearer`, `HTTPDigest` +- `OAuth2PasswordBearer`, `OAuth2AuthorizationCodeBearer` +- `OpenIdConnect` + +## Testing Guards + +Create mock guards for testing: + +```python +class MockAuthGuard(BaseGuard): + security_scheme = APIKeyHeader(name="X-Test-Key") + + def __init__(self, should_pass: bool = True): + self.should_pass = should_pass + + def can_activate(self, request: Request, credentials=None) -> bool: + return self.should_pass + +# In tests +@UseGuards(MockAuthGuard(should_pass=True)) # Allow access +@UseGuards(MockAuthGuard(should_pass=False)) # Deny access +``` + +## Complete Usage Examples + +### Public API with Mixed Security + +```python +@Controller("api/v1") +class APIController: + @Get("/public") + def public_endpoint(self): + return {"message": "No authentication required"} + + @Get("/protected") + @UseGuards(APIKeyGuard) + def protected_endpoint(self): + return {"message": "API key required"} + + @Get("/admin") + @UseGuards(JWTGuard, RoleBasedGuard(["admin"])) + def admin_endpoint(self): + return {"message": "JWT + admin role required"} +``` + +### Enterprise Security Setup + +```python +# Base authentication +@Controller("enterprise") +@UseGuards(OAuth2Guard, RateLimitGuard) +class EnterpriseController: + + @Get("/reports") + @UseGuards(RoleBasedGuard(["analyst", "admin"])) + def get_reports(self): + return {"reports": []} + + @Post("/admin/system") + @UseGuards(RoleBasedGuard(["admin"]), BasicAuthGuard) # Double auth + def admin_action(self): + return {"message": "System action performed"} +``` + +## Best Practices + +1. **Use Security Schemes**: Always define `security_scheme` for standard authentication methods to get OpenAPI documentation +2. **Layer Security**: Combine multiple guards for defense in depth +3. **Async for Database**: Use async guards when validating against databases +4. **Custom Errors**: Implement custom error handling for better UX +5. **Scope-Based Access**: Use OAuth2 scopes for fine-grained permissions +6. **Rate Limiting**: Combine auth guards with rate limiting guards +7. **Testing**: Create mock guards for unit testing +8. **Principle of Least Privilege**: Grant minimum required permissions + +## Guard Types Summary + +| Guard Type | Security Scheme | Use Case | OpenAPI | +|------------|----------------|----------|---------| +| API Key Header | `APIKeyHeader` | Service-to-service | ✅ | +| API Key Query | `APIKeyQuery` | Webhooks, simple APIs | ✅ | +| API Key Cookie | `APIKeyCookie` | Browser sessions | ✅ | +| Basic Auth | `HTTPBasic` | Simple username/password | ✅ | +| Bearer Token | `HTTPBearer` | JWT tokens | ✅ | +| OAuth2 Password | `OAuth2PasswordBearer` | OAuth2 flows | ✅ | +| OAuth2 Scopes | `OAuth2PasswordBearer` | Permission-based access | ✅ | +| Custom Logic | None | Rate limiting, custom rules | ❌ | +| Multi-Auth | Any | Flexible authentication | ✅ | + +PyNest guards provide a powerful, flexible, and standards-compliant way to secure your APIs while maintaining excellent developer experience and automatic documentation generation. + diff --git a/examples/guard_examples.py b/examples/guard_examples.py new file mode 100644 index 0000000..f0f8176 --- /dev/null +++ b/examples/guard_examples.py @@ -0,0 +1,558 @@ +""" +PyNest Guards Examples - Complete Security Implementation Guide + +This file demonstrates various guard implementations using different FastAPI +security schemes, fully compatible with the FastAPI security system. + +Based on FastAPI Security documentation: +https://fastapi.tiangolo.com/tutorial/security/ +""" + +from fastapi import Request, Depends, HTTPException, status +from fastapi.security import ( + APIKeyHeader, APIKeyQuery, APIKeyCookie, + HTTPBasic, HTTPBearer, HTTPDigest, + OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer, + OpenIdConnect +) +from fastapi.security.http import HTTPBasicCredentials, HTTPAuthorizationCredentials +from typing import Optional +import jwt +from datetime import datetime, timedelta + +from nest.core import BaseGuard, UseGuards, Controller, Get, Post + + +# ============================================================================= +# 1. API KEY GUARDS (Header, Query, Cookie) +# ============================================================================= + +class APIKeyHeaderGuard(BaseGuard): + """API Key in HTTP Header - Most common API authentication method.""" + + security_scheme = APIKeyHeader( + name="X-API-Key", + description="API key required for authentication" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + """Validate API key from header.""" + # credentials contains the X-API-Key header value + valid_api_keys = {"admin-key-123", "user-key-456", "service-key-789"} + return credentials in valid_api_keys + + +class APIKeyQueryGuard(BaseGuard): + """API Key in URL query parameter.""" + + security_scheme = APIKeyQuery( + name="api_key", + description="API key as query parameter (?api_key=your-key)" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + """Validate API key from query parameter.""" + return credentials == "secret-query-key" + + +class APIKeyCookieGuard(BaseGuard): + """API Key in HTTP Cookie - Useful for browser-based applications.""" + + security_scheme = APIKeyCookie( + name="session_token", + description="Session token stored in cookie" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + """Validate session token from cookie.""" + # In real application, validate against session store + valid_sessions = {"sess_abc123", "sess_def456", "sess_ghi789"} + return credentials in valid_sessions + + +# ============================================================================= +# 2. HTTP AUTHENTICATION GUARDS +# ============================================================================= + +class HTTPBasicGuard(BaseGuard): + """HTTP Basic Authentication (RFC 7617).""" + + security_scheme = HTTPBasic( + description="Username and password using HTTP Basic authentication" + ) + + def can_activate(self, request: Request, credentials: HTTPBasicCredentials = None) -> bool: + """Validate username and password.""" + if not credentials: + return False + + # In production, hash passwords and use secure comparison + users = { + "admin": "admin123", + "user": "user456", + "guest": "guest789" + } + + expected_password = users.get(credentials.username) + return expected_password == credentials.password + + +class HTTPBearerGuard(BaseGuard): + """HTTP Bearer Token Authentication - Common for JWT tokens.""" + + security_scheme = HTTPBearer( + description="Bearer token (typically JWT)" + ) + + def can_activate(self, request: Request, credentials: HTTPAuthorizationCredentials = None) -> bool: + """Validate Bearer token.""" + if not credentials or credentials.scheme != "Bearer": + return False + + token = credentials.credentials + return self.validate_jwt_token(token) + + def validate_jwt_token(self, token: str) -> bool: + """Validate JWT token (simplified example).""" + try: + # In production, use proper JWT validation with secret key + if token.startswith("eyJ"): # Simple JWT format check + return len(token) > 20 + return False + except Exception: + return False + + +class CustomHTTPBearerGuard(BaseGuard): + """Advanced Bearer token validation with user extraction.""" + + security_scheme = HTTPBearer( + description="JWT Bearer token with user context" + ) + + def can_activate(self, request: Request, credentials: HTTPAuthorizationCredentials = None) -> bool: + """Validate token and attach user to request.""" + if not credentials or credentials.scheme != "Bearer": + return False + + user = self.get_current_user(credentials.credentials) + if user: + # Attach user to request for use in controllers + request.state.current_user = user + return True + return False + + def get_current_user(self, token: str) -> Optional[dict]: + """Extract user information from JWT token.""" + try: + # Simplified JWT decoding (use proper library in production) + if token == "valid-jwt-token": + return { + "id": 1, + "username": "john_doe", + "email": "john@example.com", + "roles": ["user"] + } + elif token == "admin-jwt-token": + return { + "id": 2, + "username": "admin", + "email": "admin@example.com", + "roles": ["admin", "user"] + } + return None + except Exception: + return None + + +# ============================================================================= +# 3. OAUTH2 GUARDS +# ============================================================================= + +class OAuth2PasswordBearerGuard(BaseGuard): + """OAuth2 with Password flow - Most common OAuth2 implementation.""" + + security_scheme = OAuth2PasswordBearer( + tokenUrl="auth/token", + description="OAuth2 password bearer token" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + """Validate OAuth2 token.""" + if not credentials: + return False + + # Validate token (in production, verify with auth server) + return self.validate_oauth2_token(credentials) + + def validate_oauth2_token(self, token: str) -> bool: + """Validate OAuth2 access token.""" + # In production, validate with OAuth2 server or JWT validation + valid_tokens = { + "oauth2_access_token_123", + "oauth2_access_token_456", + "oauth2_access_token_789" + } + return token in valid_tokens + + +class OAuth2ScopesGuard(BaseGuard): + """OAuth2 with scopes for fine-grained permissions.""" + + security_scheme = OAuth2PasswordBearer( + tokenUrl="auth/token", + scopes={ + "read": "Read access to resources", + "write": "Write access to resources", + "delete": "Delete access to resources", + "admin": "Full administrative access" + }, + description="OAuth2 with permission scopes" + ) + + def __init__(self, required_scopes: list = None): + self.required_scopes = required_scopes or [] + + def can_activate(self, request: Request, credentials=None) -> bool: + """Validate token and check required scopes.""" + if not credentials: + return False + + user_scopes = self.get_token_scopes(credentials) + return all(scope in user_scopes for scope in self.required_scopes) + + def get_token_scopes(self, token: str) -> list: + """Extract scopes from OAuth2 token.""" + # In production, decode JWT or query OAuth2 server + token_scopes = { + "admin_token": ["read", "write", "delete", "admin"], + "user_token": ["read", "write"], + "readonly_token": ["read"] + } + return token_scopes.get(token, []) + + +# ============================================================================= +# 4. CUSTOM AND COMPOSITE GUARDS +# ============================================================================= + +class RoleBasedGuard(BaseGuard): + """Role-based access control guard.""" + + security_scheme = HTTPBearer(description="JWT token with role information") + + def __init__(self, allowed_roles: list): + self.allowed_roles = allowed_roles + + def can_activate(self, request: Request, credentials: HTTPAuthorizationCredentials = None) -> bool: + """Check if user has required role.""" + if not credentials: + return False + + user_roles = self.get_user_roles(credentials.credentials) + return any(role in user_roles for role in self.allowed_roles) + + def get_user_roles(self, token: str) -> list: + """Extract roles from JWT token.""" + # Simplified role extraction + role_mapping = { + "admin_token": ["admin", "user"], + "user_token": ["user"], + "guest_token": ["guest"] + } + return role_mapping.get(token, []) + + +class AsyncDatabaseGuard(BaseGuard): + """Async guard that validates against database.""" + + security_scheme = APIKeyHeader(name="X-Auth-Token") + + async def can_activate(self, request: Request, credentials=None) -> bool: + """Async validation against database.""" + if not credentials: + return False + + # Simulate async database call + user = await self.get_user_by_token(credentials) + return user is not None and user.get("is_active", False) + + async def get_user_by_token(self, token: str) -> Optional[dict]: + """Simulate async database lookup.""" + # In production, use actual database query + import asyncio + await asyncio.sleep(0.1) # Simulate DB delay + + users = { + "db_token_123": {"id": 1, "username": "user1", "is_active": True}, + "db_token_456": {"id": 2, "username": "user2", "is_active": False}, + "db_token_789": {"id": 3, "username": "user3", "is_active": True} + } + return users.get(token) + + +class RateLimitGuard(BaseGuard): + """Rate limiting guard (no security scheme - internal logic only).""" + + # No security_scheme - this guard doesn't appear in OpenAPI + + def __init__(self, max_requests: int = 100, window_minutes: int = 60): + self.max_requests = max_requests + self.window_minutes = window_minutes + self.request_counts = {} + + def can_activate(self, request: Request, credentials=None) -> bool: + """Check rate limit for client IP.""" + client_ip = request.client.host + now = datetime.now() + + # Clean old entries + cutoff = now - timedelta(minutes=self.window_minutes) + self.request_counts = { + ip: times for ip, times in self.request_counts.items() + if any(t > cutoff for t in times) + } + + # Check current IP + if client_ip not in self.request_counts: + self.request_counts[client_ip] = [] + + # Filter recent requests + recent_requests = [ + t for t in self.request_counts[client_ip] + if t > cutoff + ] + + if len(recent_requests) >= self.max_requests: + return False + + # Add current request + recent_requests.append(now) + self.request_counts[client_ip] = recent_requests + return True + + +# ============================================================================= +# 5. USAGE EXAMPLES WITH CONTROLLERS +# ============================================================================= + +@Controller("public") +class PublicController: + """Public endpoints - no authentication required.""" + + @Get("/health") + def health_check(self): + return {"status": "healthy", "timestamp": datetime.now().isoformat()} + + @Get("/info") + def public_info(self): + return {"message": "This is a public endpoint"} + + +@Controller("api/v1/protected") +@UseGuards(APIKeyHeaderGuard, RateLimitGuard) +class ProtectedController: + """All endpoints require API key and are rate limited.""" + + @Get("/data") + def get_protected_data(self): + return {"data": "This is protected data", "auth": "api_key"} + + @Post("/data") + def create_data(self, data: dict): + return {"message": "Data created", "data": data} + + +@Controller("api/v1/user") +class UserController: + """Mixed authentication - different guards per endpoint.""" + + @Get("/profile") + @UseGuards(HTTPBearerGuard) + def get_profile(self, request: Request): + # Access user from request.state if attached by guard + user = getattr(request.state, 'current_user', None) + return {"profile": user or "JWT authenticated user"} + + @Post("/upload") + @UseGuards(OAuth2PasswordBearerGuard) + def upload_file(self): + return {"message": "File uploaded with OAuth2 auth"} + + @Delete("/account") + @UseGuards(HTTPBasicGuard, RoleBasedGuard(["admin"])) + def delete_account(self): + return {"message": "Account deleted - requires basic auth + admin role"} + + +@Controller("api/v1/admin") +@UseGuards(OAuth2ScopesGuard(["admin"])) +class AdminController: + """Admin-only endpoints requiring OAuth2 admin scope.""" + + @Get("/users") + def list_users(self): + return {"users": ["user1", "user2", "user3"]} + + @Post("/system/restart") + @UseGuards(HTTPBasicGuard) # Additional guard for critical operations + def restart_system(self): + return {"message": "System restart initiated"} + + +# ============================================================================= +# 6. ADVANCED GUARD COMBINATIONS +# ============================================================================= + +class MultiAuthGuard(BaseGuard): + """Guard that accepts multiple authentication methods.""" + + # Primary security scheme for OpenAPI documentation + security_scheme = HTTPBearer(description="Bearer token or API key") + + def can_activate(self, request: Request, credentials=None) -> bool: + """Try multiple authentication methods.""" + + # Method 1: Bearer token + if credentials: + if self.validate_bearer_token(credentials.credentials): + return True + + # Method 2: API key in header + api_key = request.headers.get("X-API-Key") + if api_key and self.validate_api_key(api_key): + return True + + # Method 3: Session cookie + session = request.cookies.get("session_id") + if session and self.validate_session(session): + return True + + return False + + def validate_bearer_token(self, token: str) -> bool: + return token in ["valid-jwt-1", "valid-jwt-2"] + + def validate_api_key(self, key: str) -> bool: + return key in ["api-key-1", "api-key-2"] + + def validate_session(self, session: str) -> bool: + return session in ["session-1", "session-2"] + + +@Controller("api/flexible") +class FlexibleAuthController: + """Controller accepting multiple authentication methods.""" + + @Get("/data") + @UseGuards(MultiAuthGuard) + def get_data(self): + return {"message": "Authenticated via Bearer, API key, or session"} + + +# ============================================================================= +# 7. ERROR HANDLING AND CUSTOM RESPONSES +# ============================================================================= + +class CustomErrorGuard(BaseGuard): + """Guard with custom error messages and status codes.""" + + security_scheme = APIKeyHeader(name="X-Custom-Key") + + async def __call__(self, request: Request, credentials=None): + """Override to provide custom error handling.""" + try: + result = self.can_activate(request, credentials) + if inspect.isawaitable(result): + result = await result + + if not result: + # Custom error response + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={ + "error": "INVALID_API_KEY", + "message": "The provided API key is invalid or expired", + "code": "AUTH_001", + "timestamp": datetime.now().isoformat() + }, + headers={"WWW-Authenticate": "ApiKey"} + ) + except HTTPException: + raise + except Exception as e: + # Handle unexpected errors + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "AUTHENTICATION_ERROR", "message": str(e)} + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + return credentials == "valid-custom-key" + + +# ============================================================================= +# 8. TESTING HELPERS +# ============================================================================= + +class MockAuthGuard(BaseGuard): + """Mock guard for testing purposes.""" + + security_scheme = APIKeyHeader(name="X-Test-Key") + + def __init__(self, should_pass: bool = True): + self.should_pass = should_pass + + def can_activate(self, request: Request, credentials=None) -> bool: + """Always returns the configured result for testing.""" + return self.should_pass + + +# Usage in tests: +# @UseGuards(MockAuthGuard(should_pass=True)) # Allow access +# @UseGuards(MockAuthGuard(should_pass=False)) # Deny access + + +""" +SUMMARY OF GUARD TYPES AND USAGE: + +1. **API Key Guards**: Simple token-based authentication + - Header: Most common, secure + - Query: Less secure, useful for webhooks + - Cookie: Browser-friendly + +2. **HTTP Authentication Guards**: Standard HTTP auth methods + - Basic: Username/password + - Bearer: Token-based (JWT) + - Digest: More secure than Basic + +3. **OAuth2 Guards**: Industry standard authentication + - Password Bearer: Most common OAuth2 flow + - Scopes: Fine-grained permissions + - Authorization Code: For third-party integration + +4. **Custom Guards**: Application-specific logic + - Role-based access control + - Rate limiting + - Database validation + - Multi-method authentication + +5. **Guard Combinations**: + - Multiple guards per endpoint + - Controller-level + route-level guards + - Different guards for different routes + +**OpenAPI Integration:** +- Guards with security_scheme appear in Swagger UI +- "Authorize" button for interactive testing +- Automatic documentation generation +- Client code generation includes auth + +**Best Practices:** +- Use security schemes for standard auth methods +- Combine multiple guards for layered security +- Implement proper error handling +- Use async guards for database operations +- Create mock guards for testing +- Follow principle of least privilege with scopes +""" \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 61550a4..b12e051 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -56,6 +56,7 @@ nav: - Modules: modules.md - Controllers: controllers.md - Providers: providers.md + - Guards: guards.md - Dependency Injection: dependency_injection.md - Deployment: - Docker: docker.md diff --git a/nest/core/__init__.py b/nest/core/__init__.py index 50a4cc8..e2dce6e 100644 --- a/nest/core/__init__.py +++ b/nest/core/__init__.py @@ -11,6 +11,7 @@ Post, Put, ) +from nest.core.decorators.guards import BaseGuard, UseGuards from nest.core.pynest_application import PyNestApp from nest.core.pynest_container import PyNestContainer from nest.core.pynest_factory import PyNestFactory diff --git a/nest/core/decorators/controller.py b/nest/core/decorators/controller.py index 8073083..b8590b1 100644 --- a/nest/core/decorators/controller.py +++ b/nest/core/decorators/controller.py @@ -1,10 +1,12 @@ -from typing import Optional, Type +from typing import Optional, Type, List from fastapi.routing import APIRouter +from fastapi import Depends from nest.core.decorators.class_based_view import class_based_view as ClassBasedView from nest.core.decorators.http_method import HTTPMethod from nest.core.decorators.utils import get_instance_variables, parse_dependencies +from nest.core.decorators.guards import BaseGuard def Controller(prefix: Optional[str] = None, tag: Optional[str] = None): @@ -92,7 +94,7 @@ def add_routes(cls: Type, router: APIRouter, route_prefix: str) -> None: if callable(method_function) and hasattr(method_function, "__http_method__"): validate_method_decorator(method_function, method_name) configure_method_route(method_function, route_prefix) - add_route_to_router(router, method_function) + add_route_to_router(router, method_function, cls) def validate_method_decorator(method_function: callable, method_name: str) -> None: @@ -127,7 +129,18 @@ def configure_method_route(method_function: callable, route_prefix: str) -> None method_function.__route_path__ = method_function.__route_path__.rstrip("/") -def add_route_to_router(router: APIRouter, method_function: callable) -> None: +def _collect_guards(cls: Type, method: callable) -> List[BaseGuard]: + guards: List[BaseGuard] = [] + for guard in getattr(cls, "__guards__", []): + guards.append(guard) + for guard in getattr(method, "__guards__", []): + guards.append(guard) + return guards + + +def add_route_to_router( + router: APIRouter, method_function: callable, cls: Type +) -> None: """Add the configured route to the router.""" route_kwargs = { "path": method_function.__route_path__, @@ -139,4 +152,11 @@ def add_route_to_router(router: APIRouter, method_function: callable) -> None: if hasattr(method_function, "status_code"): route_kwargs["status_code"] = method_function.status_code + guards = _collect_guards(cls, method_function) + if guards: + dependencies = route_kwargs.get("dependencies", []) + for guard in guards: + dependencies.append(guard.as_dependency()) + route_kwargs["dependencies"] = dependencies + router.add_api_route(**route_kwargs) diff --git a/nest/core/decorators/guards.py b/nest/core/decorators/guards.py new file mode 100644 index 0000000..a40240d --- /dev/null +++ b/nest/core/decorators/guards.py @@ -0,0 +1,374 @@ +from fastapi import Request, HTTPException, status, Security, Depends +from fastapi.security.base import SecurityBase +from typing import Optional +import inspect + + +class BaseGuard: + """Base class for creating route guards in PyNest. + + Guards provide a way to implement authentication and authorization logic + that can be applied to controllers or individual routes. They are fully + compatible with FastAPI's security system and OpenAPI documentation. + + **Security Scheme Integration:** + + If ``security_scheme`` is set to an instance of ``fastapi.security.SecurityBase``, + the guard will: + - Be injected with credentials from that security scheme + - Appear in the generated OpenAPI schema with appropriate security requirements + - Show an "Authorize" button in Swagger UI + - Allow users to authenticate through the interactive documentation + + **Supported Security Schemes:** + + PyNest guards support all FastAPI security schemes: + + * **API Keys** (``fastapi.security.APIKeyHeader``, ``APIKeyQuery``, ``APIKeyCookie``) + * **HTTP Authentication** (``HTTPBasic``, ``HTTPBearer``, ``HTTPDigest``) + * **OAuth2** (``OAuth2PasswordBearer``, ``OAuth2AuthorizationCodeBearer``) + * **OpenID Connect** (``OpenIdConnect``) + + **Examples:** + + **1. Simple Guard (No Security Scheme):** + + ```python + class SimpleGuard(BaseGuard): + def can_activate(self, request: Request, credentials=None) -> bool: + # Custom logic without OpenAPI documentation + return request.headers.get("X-Custom-Header") == "allowed" + ``` + + **2. API Key Header Guard:** + + ```python + from fastapi.security import APIKeyHeader + + class APIKeyGuard(BaseGuard): + security_scheme = APIKeyHeader( + name="X-API-Key", + description="API Key for authentication" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + # credentials contains the API key value + return credentials == "your-secret-api-key" + ``` + + **3. Bearer Token Guard:** + + ```python + from fastapi.security import HTTPBearer + + class BearerTokenGuard(BaseGuard): + security_scheme = HTTPBearer(description="Bearer token") + + def can_activate(self, request: Request, credentials=None) -> bool: + # credentials is an HTTPAuthorizationCredentials object + if credentials and credentials.scheme == "Bearer": + return self.validate_jwt_token(credentials.credentials) + return False + + def validate_jwt_token(self, token: str) -> bool: + # Implement JWT validation logic + return token == "valid-jwt-token" + ``` + + **4. Basic Authentication Guard:** + + ```python + from fastapi.security import HTTPBasic + + class BasicAuthGuard(BaseGuard): + security_scheme = HTTPBasic(description="Basic HTTP authentication") + + def can_activate(self, request: Request, credentials=None) -> bool: + # credentials is an HTTPBasicCredentials object + if credentials: + return (credentials.username == "admin" and + credentials.password == "secret") + return False + ``` + + **5. OAuth2 Password Bearer Guard:** + + ```python + from fastapi.security import OAuth2PasswordBearer + + class OAuth2Guard(BaseGuard): + security_scheme = OAuth2PasswordBearer( + tokenUrl="token", + description="OAuth2 with Password and Bearer" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + # credentials contains the OAuth2 token + return self.validate_oauth2_token(credentials) + + def validate_oauth2_token(self, token: str) -> bool: + # Implement OAuth2 token validation + return token and len(token) > 20 + ``` + + **6. Multi-Scope OAuth2 Guard:** + + ```python + from fastapi.security import OAuth2PasswordBearer + + class AdminGuard(BaseGuard): + security_scheme = OAuth2PasswordBearer( + tokenUrl="token", + scopes={ + "read": "Read access", + "write": "Write access", + "admin": "Admin access" + } + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + # Validate token and check for admin scope + return self.has_admin_scope(credentials) + ``` + + **7. Cookie Authentication Guard:** + + ```python + from fastapi.security import APIKeyCookie + + class CookieGuard(BaseGuard): + security_scheme = APIKeyCookie( + name="session_id", + description="Session cookie" + ) + + def can_activate(self, request: Request, credentials=None) -> bool: + # credentials contains the cookie value + return self.validate_session(credentials) + ``` + + **8. Async Guard Example:** + + ```python + class AsyncGuard(BaseGuard): + security_scheme = APIKeyHeader(name="X-Token") + + async def can_activate(self, request: Request, credentials=None) -> bool: + # Async validation (e.g., database lookup) + user = await self.get_user_by_token(credentials) + return user is not None + + async def get_user_by_token(self, token: str): + # Async database call + pass + ``` + + **Usage with Controllers:** + + ```python + @Controller("users") + @UseGuards(APIKeyGuard, AdminGuard) # Multiple guards + class UserController: + pass + + @Controller("public") + class PublicController: + + @Get("/protected") + @UseGuards(BearerTokenGuard) # Route-level guard + def protected_route(self): + return {"message": "This is protected"} + ``` + + **Error Handling:** + + Guards automatically raise ``HTTPException`` with status 403 (Forbidden) + when ``can_activate`` returns ``False``. You can customize error handling + by overriding the ``__call__`` method. + + **OpenAPI Documentation:** + + When using security schemes, guards automatically: + - Add security requirements to OpenAPI schema + - Display "Authorize" button in Swagger UI + - Show required headers/parameters in API documentation + - Enable interactive authentication testing + """ + + security_scheme: Optional[SecurityBase] = None + + def can_activate(self, request: Request, credentials=None) -> bool: + """Determine if the request should be allowed to proceed. + + **Override this method** with your custom authorization logic. + + Args: + request: The FastAPI Request object containing request information + credentials: Credentials extracted from the security scheme (if any). + Type depends on the security scheme used: + - APIKey schemes: str (the key value) + - HTTPBasic: HTTPBasicCredentials object + - HTTPBearer: HTTPAuthorizationCredentials object + - OAuth2: str (the token) + + Returns: + bool: True if the request should be allowed, False to deny with 403 + + Note: + This method can be async. If it returns an awaitable, it will be + automatically awaited by the guard system. + + Examples: + ```python + def can_activate(self, request: Request, credentials=None) -> bool: + # Simple token check + return credentials == "secret-token" + + async def can_activate(self, request: Request, credentials=None) -> bool: + # Async validation + user = await database.get_user_by_token(credentials) + return user.is_active + ``` + """ + raise NotImplementedError("Subclasses must implement can_activate method") + + async def __call__(self, request: Request, credentials=None): + """Internal method that executes the guard logic. + + This method: + 1. Calls can_activate() with request and credentials + 2. Handles both sync and async can_activate implementations + 3. Raises HTTPException(403) if access is denied + + You typically don't need to override this method unless you want + custom error handling or logging. + """ + result = self.can_activate(request, credentials) + if inspect.isawaitable(result): + result = await result + if not result: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied: insufficient permissions" + ) + + @classmethod + def as_dependency(cls): + """Convert the guard class to a FastAPI dependency function. + + This method is used internally by PyNest to integrate guards with + FastAPI's dependency system. It creates the appropriate dependency + function based on whether a security scheme is configured. + + Returns: + Callable: A dependency function that FastAPI can use + + **Internal Implementation Details:** + + - If no security_scheme: Creates a simple dependency that validates the request + - If security_scheme exists: Creates a dependency with Security parameter for OpenAPI + + The returned dependency will: + - Appear in OpenAPI schema (if security_scheme is set) + - Extract credentials automatically (if security_scheme is set) + - Execute guard logic and raise 403 on failure + """ + if cls.security_scheme is None: + # No security scheme - simple request validation + async def dependency(request: Request): + guard = cls() + await guard(request) + + return Depends(dependency) + + # Security scheme configured - create function with Security parameter + # This allows FastAPI to detect the security requirement for OpenAPI + security_scheme = cls.security_scheme + + async def security_dependency( + request: Request, + credentials=Security(security_scheme) + ): + guard = cls() + await guard(request, credentials) + + return Depends(security_dependency) + + +def UseGuards(*guards): + """Decorator to apply guards to controllers or individual routes. + + Guards provide authentication and authorization for your API endpoints. + This decorator can be applied at the controller level (protecting all routes) + or at individual route methods. + + Args: + *guards: One or more guard classes (not instances) to apply + + **Usage Examples:** + + **Controller-level protection:** + ```python + @Controller("admin") + @UseGuards(AdminGuard, RateLimitGuard) + class AdminController: + # All routes in this controller are protected + pass + ``` + + **Route-level protection:** + ```python + @Controller("users") + class UserController: + + @Get("/public") + def public_endpoint(self): + return {"message": "No auth required"} + + @Get("/private") + @UseGuards(AuthGuard) + def private_endpoint(self): + return {"message": "Auth required"} + + @Post("/admin-only") + @UseGuards(AuthGuard, AdminGuard) # Multiple guards + def admin_endpoint(self): + return {"message": "Admin access required"} + ``` + + **Guard Execution Order:** + + Guards are executed in the order they are specified. If any guard fails, + subsequent guards are not executed and a 403 error is returned. + + **Combining with FastAPI Dependencies:** + + Guards work alongside FastAPI's native dependency system and can be + combined with other dependencies: + + ```python + @Get("/users/{user_id}") + @UseGuards(AuthGuard) + def get_user(user_id: int, db: Session = Depends(get_db)): + # Both guard and dependency are executed + pass + ``` + + **Security in OpenAPI:** + + When guards use security schemes, they automatically appear in: + - OpenAPI/Swagger documentation + - Interactive API testing interface + - Generated client code + + Returns: + Callable: Decorator function that applies guards to the target + """ + def decorator(obj): + # Get existing guards (if any) and append new ones + existing_guards = list(getattr(obj, "__guards__", [])) + existing_guards.extend(guards) + setattr(obj, "__guards__", existing_guards) + return obj + + return decorator diff --git a/tests/test_core/test_decorators/test_guard.py b/tests/test_core/test_decorators/test_guard.py new file mode 100644 index 0000000..ccda088 --- /dev/null +++ b/tests/test_core/test_decorators/test_guard.py @@ -0,0 +1,81 @@ +import inspect + +from fastapi import Request + +from fastapi.security import HTTPBearer + +from nest.core import Controller, Get, UseGuards, BaseGuard + + +class SimpleGuard(BaseGuard): + def __init__(self): + self.called = False + + def can_activate(self, request: Request) -> bool: + self.called = True + return True + + +class BearerGuard(BaseGuard): + security_scheme = HTTPBearer() + + def can_activate(self, request: Request, credentials) -> bool: + return True + + +class JWTGuard(BaseGuard): + security_scheme = HTTPBearer() + + def can_activate(self, request: Request, credentials=None) -> bool: + if credentials and credentials.scheme == "Bearer": + return self.validate_jwt(credentials.credentials) + return False + + +@Controller("/guard") +class GuardController: + @Get("/") + @UseGuards(SimpleGuard) + def root(self): + return {"ok": True} + + +def test_use_guards_sets_attribute(): + assert hasattr(GuardController.root, "__guards__") + assert SimpleGuard in GuardController.root.__guards__ + + +def test_guard_added_to_route_dependencies(): + router = GuardController.get_router() + route = router.routes[0] + deps = route.dependencies + assert len(deps) == 1 + assert callable(deps[0].dependency) + + +def _has_security_requirements(dependant): + """Recursively check if a dependant or its dependencies have security requirements.""" + if dependant.security_requirements: + return True + + for dep in dependant.dependencies: + if _has_security_requirements(dep): + return True + + return False + + +def test_openapi_security_requirement(): + @Controller("/bearer") + class BearerController: + @Get("/") + @UseGuards(BearerGuard) + def root(self): + return {"ok": True} + + router = BearerController.get_router() + route = router.routes[0] + + # Check if security requirements exist anywhere in the dependency tree + assert _has_security_requirements(route.dependant), \ + "Security requirements should be present in the dependency tree for OpenAPI integration"