-
Notifications
You must be signed in to change notification settings - Fork 9
Restrict Vault token exchange to specific hosts; improve auth errors; (Issue #19) #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # project-specific | ||
| tmp/ | ||
| vault-token.dat | ||
|
|
||
| # Byte-compiled / optimized / DLL files | ||
| __pycache__/ | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||
| import json | ||||||
| import os | ||||||
| from typing import List | ||||||
| from urllib.parse import urlparse | ||||||
|
|
||||||
| import requests | ||||||
| from SPARQLWrapper import JSON, SPARQLWrapper | ||||||
|
|
@@ -12,6 +13,18 @@ | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| # Hosts that require Vault token based authentication. Central source of truth. | ||||||
| VAULT_REQUIRED_HOSTS = { | ||||||
| "data.dbpedia.io", | ||||||
| "data.dev.dbpedia.link", | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| class DownloadAuthError(Exception): | ||||||
| """Raised when an authorization problem occurs during download.""" | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| def _download_file( | ||||||
| url, | ||||||
| localDir, | ||||||
|
|
@@ -52,16 +65,9 @@ def _download_file( | |||||
| os.makedirs(dirpath, exist_ok=True) # Create the necessary directories | ||||||
| # --- 1. Get redirect URL by requesting HEAD --- | ||||||
| headers = {} | ||||||
| # --- 1a. public databus --- | ||||||
| response = requests.head(url, timeout=30) | ||||||
| # --- 1b. Databus API key required --- | ||||||
| if response.status_code == 401: | ||||||
| # print(f"API key required for {url}") | ||||||
| if not databus_key: | ||||||
| raise ValueError("Databus API key not given for protected download") | ||||||
|
|
||||||
| headers = {"X-API-KEY": databus_key} | ||||||
| response = requests.head(url, headers=headers, timeout=30) | ||||||
| # --- 1a. public databus --- | ||||||
| response = requests.head(url, timeout=30, allow_redirects=False) | ||||||
|
|
||||||
| # Check for redirect and update URL if necessary | ||||||
| if response.headers.get("Location") and response.status_code in [ | ||||||
|
|
@@ -73,6 +79,30 @@ def _download_file( | |||||
| ]: | ||||||
| url = response.headers.get("Location") | ||||||
| print("Redirects url: ", url) | ||||||
| # Re-do HEAD request on redirect URL | ||||||
| response = requests.head(url, timeout=30) | ||||||
|
|
||||||
| # Extract hostname from final URL (after redirect) to check if vault token needed. | ||||||
| # This is the actual download location that may require authentication. | ||||||
| parsed = urlparse(url) | ||||||
| host = parsed.hostname | ||||||
|
|
||||||
| # --- 1b. Handle 401 on HEAD request --- | ||||||
| if response.status_code == 401: | ||||||
| # Check if this is a vault-required host | ||||||
| if host in VAULT_REQUIRED_HOSTS: | ||||||
| # Vault-required host: need vault token | ||||||
| if not vault_token_file: | ||||||
| raise DownloadAuthError( | ||||||
| f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." | ||||||
| ) | ||||||
| # Token provided; will handle in GET request below | ||||||
| else: | ||||||
| # Not a vault host; might need databus API key | ||||||
| if not databus_key: | ||||||
| raise DownloadAuthError("Databus API key not given for protected download") | ||||||
| headers = {"X-API-KEY": databus_key} | ||||||
| response = requests.head(url, headers=headers, timeout=30) | ||||||
|
|
||||||
| # --- 2. Try direct GET to redirected URL --- | ||||||
| headers["Accept-Encoding"] = ( | ||||||
|
|
@@ -81,25 +111,54 @@ def _download_file( | |||||
| response = requests.get( | ||||||
| url, headers=headers, stream=True, allow_redirects=True, timeout=30 | ||||||
| ) | ||||||
| www = response.headers.get( | ||||||
| "WWW-Authenticate", "" | ||||||
| ) # Check if authentication is required | ||||||
| www = response.headers.get("WWW-Authenticate", "") # Check if authentication is required | ||||||
|
|
||||||
| # --- 3. If redirected to authentication 401 Unauthorized, get Vault token and retry --- | ||||||
| # --- 3. Handle authentication responses --- | ||||||
| # 3a. Server requests Bearer auth. Only attempt token exchange for hosts | ||||||
| # we explicitly consider Vault-protected (VAULT_REQUIRED_HOSTS). This avoids | ||||||
| # sending tokens to unrelated hosts and makes auth behavior predictable. | ||||||
| if response.status_code == 401 and "bearer" in www.lower(): | ||||||
| print(f"Authentication required for {url}") | ||||||
| if not (vault_token_file): | ||||||
| raise ValueError("Vault token file not given for protected download") | ||||||
| # If host is not configured for Vault, do not attempt token exchange. | ||||||
| if host not in VAULT_REQUIRED_HOSTS: | ||||||
| raise DownloadAuthError( | ||||||
| "Server requests Bearer authentication but this host is not configured for Vault token exchange." | ||||||
| " Try providing a databus API key with --databus-key or contact your administrator." | ||||||
| ) | ||||||
|
|
||||||
| # Host requires Vault; ensure token file provided. | ||||||
| if not vault_token_file: | ||||||
| raise DownloadAuthError( | ||||||
| f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." | ||||||
| ) | ||||||
|
|
||||||
| # --- 3a. Fetch Vault token --- | ||||||
| # TODO: cache token | ||||||
| # --- 3b. Fetch Vault token and retry --- | ||||||
| # Token exchange is potentially sensitive and should only be performed | ||||||
| # for known hosts. __get_vault_access__ handles reading the refresh | ||||||
| # token and exchanging it; errors are translated to DownloadAuthError | ||||||
| # for user-friendly CLI output. | ||||||
| vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Critical: Pass final URL to token exchange function
Line 139 passes the `url` variable to `__get_vault_access__`. The proposed fix passes `response.url` (the final URL of the response, after any redirects) instead.
🔎 Proposed fix
- vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id)
+ vault_token = __get_vault_access__(response.url, vault_token_file, auth_url, client_id)
📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
| headers["Authorization"] = f"Bearer {vault_token}" | ||||||
| headers.pop("Accept-Encoding") | ||||||
| headers.pop("Accept-Encoding", None) | ||||||
|
|
||||||
| # --- 3b. Retry with token --- | ||||||
| # Retry with token | ||||||
| response = requests.get(url, headers=headers, stream=True, timeout=30) | ||||||
|
|
||||||
| # Map common auth failures to friendly messages | ||||||
| if response.status_code == 401: | ||||||
| raise DownloadAuthError("Vault token is invalid or expired. Please generate a new token.") | ||||||
| if response.status_code == 403: | ||||||
| raise DownloadAuthError("Vault token is valid but has insufficient permissions to access this file.") | ||||||
|
|
||||||
| # 3c. Generic forbidden without Bearer challenge | ||||||
| if response.status_code == 403: | ||||||
| raise DownloadAuthError("Access forbidden: your token or API key does not have permission to download this file.") | ||||||
|
|
||||||
| # 3d. Generic unauthorized without Bearer | ||||||
| if response.status_code == 401: | ||||||
| raise DownloadAuthError( | ||||||
| "Unauthorized: access denied. Check your --databus-key or --vault-token settings." | ||||||
| ) | ||||||
|
|
||||||
| try: | ||||||
| response.raise_for_status() # Raise if still failing | ||||||
| except requests.exceptions.HTTPError as e: | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| import sys | ||
| import types | ||
|
|
||
| # Provide a lightweight fake SPARQLWrapper module for tests when not installed. | ||
| if "SPARQLWrapper" not in sys.modules: | ||
| mod = types.ModuleType("SPARQLWrapper") | ||
| mod.JSON = None | ||
|
|
||
| class DummySPARQL: | ||
| def __init__(self, *args, **kwargs): | ||
| pass | ||
|
|
||
| def setQuery(self, q): | ||
| self._q = q | ||
|
|
||
| def setReturnFormat(self, f): | ||
| self._fmt = f | ||
|
|
||
| def setCustomHttpHeaders(self, h): | ||
| self._headers = h | ||
|
|
||
| def query(self): | ||
| class R: | ||
| def convert(self): | ||
| return {"results": {"bindings": []}} | ||
|
|
||
| return R() | ||
|
|
||
| mod.SPARQLWrapper = DummySPARQL | ||
| sys.modules["SPARQLWrapper"] = mod |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| from unittest.mock import Mock, patch | ||
|
|
||
| import pytest | ||
|
|
||
| import requests | ||
|
|
||
| import databusclient.api.download as dl | ||
|
|
||
| from databusclient.api.download import VAULT_REQUIRED_HOSTS, DownloadAuthError | ||
|
|
||
|
|
||
| def make_response(status=200, headers=None, content=b""): | ||
| headers = headers or {} | ||
| mock = Mock() | ||
| mock.status_code = status | ||
| mock.headers = headers | ||
| mock.content = content | ||
|
|
||
| def iter_content(chunk_size): | ||
| if content: | ||
| yield content | ||
| else: | ||
| return | ||
|
|
||
| mock.iter_content = lambda chunk: iter(iter_content(chunk)) | ||
|
|
||
| def raise_for_status(): | ||
| if mock.status_code >= 400: | ||
| raise requests.exceptions.HTTPError() | ||
|
|
||
| mock.raise_for_status = raise_for_status | ||
| return mock | ||
|
|
||
|
|
||
| def test_vault_host_no_token_raises(): | ||
| vault_host = next(iter(VAULT_REQUIRED_HOSTS)) | ||
| url = f"https://{vault_host}/some/protected/file.ttl" | ||
|
|
||
| with pytest.raises(DownloadAuthError) as exc: | ||
| dl._download_file(url, localDir='.', vault_token_file=None) | ||
|
|
||
| assert "Vault token required" in str(exc.value) | ||
|
|
||
|
|
||
| def test_non_vault_host_no_token_allows_download(monkeypatch): | ||
| url = "https://example.com/public/file.txt" | ||
|
|
||
| resp_head = make_response(status=200, headers={}) | ||
| resp_get = make_response(status=200, headers={"content-length": "0"}, content=b"") | ||
|
|
||
| with patch("requests.head", return_value=resp_head), patch( | ||
| "requests.get", return_value=resp_get | ||
| ): | ||
| # should not raise | ||
| dl._download_file(url, localDir='.', vault_token_file=None) | ||
|
|
||
|
|
||
| def test_401_after_token_exchange_reports_invalid_token(monkeypatch): | ||
| vault_host = next(iter(VAULT_REQUIRED_HOSTS)) | ||
| url = f"https://{vault_host}/protected/file.ttl" | ||
|
|
||
| # initial head and get -> 401 with Bearer | ||
| resp_head = make_response(status=200, headers={}) | ||
| resp_401 = make_response(status=401, headers={"WWW-Authenticate": "Bearer realm=\"auth\""}) | ||
|
|
||
| # after retry with token -> still 401 | ||
| resp_401_retry = make_response(status=401, headers={}) | ||
|
|
||
| # Mock requests.get side effects: first 401 (challenge), then 401 after token | ||
| get_side_effects = [resp_401, resp_401_retry] | ||
|
|
||
| # Mock token exchange responses | ||
| post_resp_1 = Mock() | ||
| post_resp_1.json.return_value = {"access_token": "ACCESS"} | ||
| post_resp_2 = Mock() | ||
| post_resp_2.json.return_value = {"access_token": "VAULT"} | ||
|
|
||
| with patch("requests.head", return_value=resp_head), patch( | ||
| "requests.get", side_effect=get_side_effects | ||
| ), patch("requests.post", side_effect=[post_resp_1, post_resp_2]): | ||
| # set REFRESH_TOKEN so __get_vault_access__ doesn't try to open a file | ||
| monkeypatch.setenv("REFRESH_TOKEN", "x" * 90) | ||
|
|
||
| with pytest.raises(DownloadAuthError) as exc: | ||
| dl._download_file(url, localDir='.', vault_token_file="/does/not/matter") | ||
|
|
||
| assert "invalid or expired" in str(exc.value) | ||
|
|
||
|
|
||
| def test_403_reports_insufficient_permissions(): | ||
| vault_host = next(iter(VAULT_REQUIRED_HOSTS)) | ||
| url = f"https://{vault_host}/protected/file.ttl" | ||
|
|
||
| resp_head = make_response(status=200, headers={}) | ||
| resp_403 = make_response(status=403, headers={}) | ||
|
|
||
| with patch("requests.head", return_value=resp_head), patch( | ||
| "requests.get", return_value=resp_403 | ||
| ): | ||
| # provide a token path so early check does not block | ||
| with pytest.raises(DownloadAuthError) as exc: | ||
| dl._download_file(url, localDir='.', vault_token_file="/some/token/file") | ||
|
|
||
| assert "permission" in str(exc.value) or "forbidden" in str(exc.value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical: Extract hostname from final URL after all redirects
Line 87 parses the `url` variable, which contains only the first redirect location (from line 80). However, the HEAD request on line 83 uses the default `allow_redirects=True`, meaning it may follow multiple redirect hops. The `response.url` attribute contains the final URL after all redirects, which is the actual download location that should be checked for authentication requirements.

Using the intermediate redirect URL instead of the final URL will cause authentication checks (lines 93, 122) to evaluate against the wrong host, potentially failing authentication for vault-required hosts or incorrectly attempting vault auth for non-vault hosts.
🔎 Proposed fix
Based on past review comments indicating host should be extracted from the redirect URL, not the original URL.
📝 Committable suggestion
🤖 Prompt for AI Agents