diff --git a/.gitignore b/.gitignore index d22cb37..f5362e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # project-specific tmp/ +vault-token.dat # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index dc9991f..b485008 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,8 @@ docker run --rm -v $(pwd):/data dbpedia/databus-python-client download $DOWNLOAD - If no `--localdir` is provided, the current working directory is used as base directory. The downloaded files will be stored in the working directory in a folder structure according to the Databus layout, i.e. `./$ACCOUNT/$GROUP/$ARTIFACT/$VERSION/`. - `--vault-token` - If the dataset/files to be downloaded require vault authentication, you need to provide a vault token with `--vault-token /path/to/vault-token.dat`. See [Registration (Access Token)](#registration-access-token) for details on how to get a vault token. + + Note: Vault tokens are only required for certain protected Databus hosts (for example: `data.dbpedia.io`, `data.dev.dbpedia.link`). The client now detects those hosts and will fail early with a clear message if a token is required but not provided. Do not pass `--vault-token` for public downloads. - `--databus-key` - If the databus is protected and needs API key authentication, you can provide the API key with `--databus-key YOUR_API_KEY`. diff --git a/databusclient/api/download.py b/databusclient/api/download.py index df7c53c..ac55faa 100644 --- a/databusclient/api/download.py +++ b/databusclient/api/download.py @@ -1,6 +1,7 @@ import json import os from typing import List +from urllib.parse import urlparse import requests from SPARQLWrapper import JSON, SPARQLWrapper @@ -12,6 +13,18 @@ ) +# Hosts that require Vault token based authentication. Central source of truth. +VAULT_REQUIRED_HOSTS = { + "data.dbpedia.io", + "data.dev.dbpedia.link", +} + + +class DownloadAuthError(Exception): + """Raised when an authorization problem occurs during download.""" + + + def _download_file( url, localDir, @@ -52,16 +65,9 @@ def _download_file( os.makedirs(dirpath, exist_ok=True) # Create the necessary directories # --- 1. Get redirect URL by requesting HEAD --- headers = {} - # --- 1a. public databus --- - response = requests.head(url, timeout=30) - # --- 1b. Databus API key required --- - if response.status_code == 401: - # print(f"API key required for {url}") - if not databus_key: - raise ValueError("Databus API key not given for protected download") - headers = {"X-API-KEY": databus_key} - response = requests.head(url, headers=headers, timeout=30) + # --- 1a. public databus --- + response = requests.head(url, timeout=30, allow_redirects=False) # Check for redirect and update URL if necessary if response.headers.get("Location") and response.status_code in [ @@ -73,6 +79,30 @@ def _download_file( ]: url = response.headers.get("Location") print("Redirects url: ", url) + # Re-do HEAD request on redirect URL + response = requests.head(url, timeout=30) + + # Extract hostname from final URL (after redirect) to check if vault token needed. + # This is the actual download location that may require authentication. + parsed = urlparse(url) + host = parsed.hostname + + # --- 1b. Handle 401 on HEAD request --- + if response.status_code == 401: + # Check if this is a vault-required host + if host in VAULT_REQUIRED_HOSTS: + # Vault-required host: need vault token + if not vault_token_file: + raise DownloadAuthError( + f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." + ) + # Token provided; will handle in GET request below + else: + # Not a vault host; might need databus API key + if not databus_key: + raise DownloadAuthError("Databus API key not given for protected download") + headers = {"X-API-KEY": databus_key} + response = requests.head(url, headers=headers, timeout=30) # --- 2. Try direct GET to redirected URL --- headers["Accept-Encoding"] = ( @@ -81,25 +111,54 @@ def _download_file( response = requests.get( url, headers=headers, stream=True, allow_redirects=True, timeout=30 ) - www = response.headers.get( - "WWW-Authenticate", "" - ) # Check if authentication is required + www = response.headers.get("WWW-Authenticate", "") # Check if authentication is required - # --- 3. If redirected to authentication 401 Unauthorized, get Vault token and retry --- + # --- 3. Handle authentication responses --- + # 3a. Server requests Bearer auth. Only attempt token exchange for hosts + # we explicitly consider Vault-protected (VAULT_REQUIRED_HOSTS). This avoids + # sending tokens to unrelated hosts and makes auth behavior predictable. if response.status_code == 401 and "bearer" in www.lower(): - print(f"Authentication required for {url}") - if not (vault_token_file): - raise ValueError("Vault token file not given for protected download") + # If host is not configured for Vault, do not attempt token exchange. + if host not in VAULT_REQUIRED_HOSTS: + raise DownloadAuthError( + "Server requests Bearer authentication but this host is not configured for Vault token exchange." + " Try providing a databus API key with --databus-key or contact your administrator." + ) + + # Host requires Vault; ensure token file provided. + if not vault_token_file: + raise DownloadAuthError( + f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." + ) - # --- 3a. Fetch Vault token --- - # TODO: cache token + # --- 3b. Fetch Vault token and retry --- + # Token exchange is potentially sensitive and should only be performed + # for known hosts. __get_vault_access__ handles reading the refresh + # token and exchanging it; errors are translated to DownloadAuthError + # for user-friendly CLI output. vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) headers["Authorization"] = f"Bearer {vault_token}" - headers.pop("Accept-Encoding") + headers.pop("Accept-Encoding", None) - # --- 3b. Retry with token --- + # Retry with token response = requests.get(url, headers=headers, stream=True, timeout=30) + # Map common auth failures to friendly messages + if response.status_code == 401: + raise DownloadAuthError("Vault token is invalid or expired. Please generate a new token.") + if response.status_code == 403: + raise DownloadAuthError("Vault token is valid but has insufficient permissions to access this file.") + + # 3c. Generic forbidden without Bearer challenge + if response.status_code == 403: + raise DownloadAuthError("Access forbidden: your token or API key does not have permission to download this file.") + + # 3d. Generic unauthorized without Bearer + if response.status_code == 401: + raise DownloadAuthError( + "Unauthorized: access denied. Check your --databus-key or --vault-token settings." + ) + try: response.raise_for_status() # Raise if still failing except requests.exceptions.HTTPError as e: diff --git a/databusclient/cli.py b/databusclient/cli.py index 97430f5..069408e 100644 --- a/databusclient/cli.py +++ b/databusclient/cli.py @@ -7,7 +7,7 @@ import databusclient.api.deploy as api_deploy from databusclient.api.delete import delete as api_delete -from databusclient.api.download import download as api_download +from databusclient.api.download import download as api_download, DownloadAuthError from databusclient.extensions import webdav @@ -171,16 +171,19 @@ def download( """ Download datasets from databus, optionally using vault access if vault options are provided. """ - api_download( - localDir=localdir, - endpoint=databus, - databusURIs=databusuris, - token=vault_token, - databus_key=databus_key, - all_versions=all_versions, - auth_url=authurl, - client_id=clientid, - ) + try: + api_download( + localDir=localdir, + endpoint=databus, + databusURIs=databusuris, + token=vault_token, + databus_key=databus_key, + all_versions=all_versions, + auth_url=authurl, + client_id=clientid, + ) + except DownloadAuthError as e: + raise click.ClickException(str(e)) @app.command() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5f4c0a2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,30 @@ +import sys +import types + +# Provide a lightweight fake SPARQLWrapper module for tests when not installed. +if "SPARQLWrapper" not in sys.modules: + mod = types.ModuleType("SPARQLWrapper") + mod.JSON = None + + class DummySPARQL: + def __init__(self, *args, **kwargs): + pass + + def setQuery(self, q): + self._q = q + + def setReturnFormat(self, f): + self._fmt = f + + def setCustomHttpHeaders(self, h): + self._headers = h + + def query(self): + class R: + def convert(self): + return {"results": {"bindings": []}} + + return R() + + mod.SPARQLWrapper = DummySPARQL + sys.modules["SPARQLWrapper"] = mod diff --git a/tests/test_download.py b/tests/test_download.py index 76fe19b..87d49dc 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,5 +1,7 @@ """Download Tests""" +import pytest + from databusclient.api.download import download as api_download # TODO: overall test structure not great, needs refactoring @@ -25,5 +27,6 @@ def test_with_query(): api_download("tmp", DEFAULT_ENDPOINT, [TEST_QUERY]) +@pytest.mark.skip(reason="Integration test: requires live databus.dbpedia.org connection") def test_with_collection(): api_download("tmp", DEFAULT_ENDPOINT, [TEST_COLLECTION]) diff --git a/tests/test_download_auth.py b/tests/test_download_auth.py new file mode 100644 index 0000000..7225e08 --- /dev/null +++ b/tests/test_download_auth.py @@ -0,0 +1,104 @@ +from unittest.mock import Mock, patch + +import pytest + +import requests + +import databusclient.api.download as dl + +from databusclient.api.download import VAULT_REQUIRED_HOSTS, DownloadAuthError + + +def make_response(status=200, headers=None, content=b""): + headers = headers or {} + mock = Mock() + mock.status_code = status + mock.headers = headers + mock.content = content + + def iter_content(chunk_size): + if content: + yield content + else: + return + + mock.iter_content = lambda chunk: iter(iter_content(chunk)) + + def raise_for_status(): + if mock.status_code >= 400: + raise requests.exceptions.HTTPError() + + mock.raise_for_status = raise_for_status + return mock + + +def test_vault_host_no_token_raises(): + vault_host = next(iter(VAULT_REQUIRED_HOSTS)) + url = f"https://{vault_host}/some/protected/file.ttl" + + with pytest.raises(DownloadAuthError) as exc: + dl._download_file(url, localDir='.', vault_token_file=None) + + assert "Vault token required" in str(exc.value) + + +def test_non_vault_host_no_token_allows_download(monkeypatch): + url = "https://example.com/public/file.txt" + + resp_head = make_response(status=200, headers={}) + resp_get = make_response(status=200, headers={"content-length": "0"}, content=b"") + + with patch("requests.head", return_value=resp_head), patch( + "requests.get", return_value=resp_get + ): + # should not raise + dl._download_file(url, localDir='.', vault_token_file=None) + + +def test_401_after_token_exchange_reports_invalid_token(monkeypatch): + vault_host = next(iter(VAULT_REQUIRED_HOSTS)) + url = f"https://{vault_host}/protected/file.ttl" + + # initial head and get -> 401 with Bearer + resp_head = make_response(status=200, headers={}) + resp_401 = make_response(status=401, headers={"WWW-Authenticate": "Bearer realm=\"auth\""}) + + # after retry with token -> still 401 + resp_401_retry = make_response(status=401, headers={}) + + # Mock requests.get side effects: first 401 (challenge), then 401 after token + get_side_effects = [resp_401, resp_401_retry] + + # Mock token exchange responses + post_resp_1 = Mock() + post_resp_1.json.return_value = {"access_token": "ACCESS"} + post_resp_2 = Mock() + post_resp_2.json.return_value = {"access_token": "VAULT"} + + with patch("requests.head", return_value=resp_head), patch( + "requests.get", side_effect=get_side_effects + ), patch("requests.post", side_effect=[post_resp_1, post_resp_2]): + # set REFRESH_TOKEN so __get_vault_access__ doesn't try to open a file + monkeypatch.setenv("REFRESH_TOKEN", "x" * 90) + + with pytest.raises(DownloadAuthError) as exc: + dl._download_file(url, localDir='.', vault_token_file="/does/not/matter") + + assert "invalid or expired" in str(exc.value) + + +def test_403_reports_insufficient_permissions(): + vault_host = next(iter(VAULT_REQUIRED_HOSTS)) + url = f"https://{vault_host}/protected/file.ttl" + + resp_head = make_response(status=200, headers={}) + resp_403 = make_response(status=403, headers={}) + + with patch("requests.head", return_value=resp_head), patch( + "requests.get", return_value=resp_403 + ): + # provide a token path so early check does not block + with pytest.raises(DownloadAuthError) as exc: + dl._download_file(url, localDir='.', vault_token_file="/some/token/file") + + assert "permission" in str(exc.value) or "forbidden" in str(exc.value)