diff --git a/python_anvil/api.py b/python_anvil/api.py index 54ef847f..0e69c1f6 100644 --- a/python_anvil/api.py +++ b/python_anvil/api.py @@ -15,7 +15,7 @@ ForgeSubmitPayload, GeneratePDFPayload, ) -from .api_resources.requests import PlainRequest, RestRequest +from .api_resources.requests import FullyQualifiedRequest, PlainRequest, RestRequest from .http import GQLClient, HTTPClient @@ -117,6 +117,10 @@ def request_rest(self, options: Optional[dict] = None): api = RestRequest(self.client, options=options) return api + def request_fully_qualified(self, options: Optional[dict] = None): + api = FullyQualifiedRequest(self.client, options=options) + return api + def fill_pdf( self, template_id: str, payload: Union[dict, AnyStr, FillPDFPayload], **kwargs ): diff --git a/python_anvil/api_resources/requests.py b/python_anvil/api_resources/requests.py index 296f980d..395eaa4a 100644 --- a/python_anvil/api_resources/requests.py +++ b/python_anvil/api_resources/requests.py @@ -1,5 +1,6 @@ from typing import Any, Dict +from python_anvil.constants import VALID_HOSTS from python_anvil.http import HTTPClient @@ -161,3 +162,22 @@ class PlainRequest(BaseAnvilHttpRequest): def get_url(self): return f"{self.API_HOST}/{self.API_BASE}" + + +class FullyQualifiedRequest(BaseAnvilHttpRequest): + """A request class that validates URLs point to Anvil domains.""" + + def get_url(self): + return "" # Not used since we expect full URLs + + def _validate_url(self, url): + if not any(url.startswith(host) for host in VALID_HOSTS): + raise ValueError(f"URL must start with one of: {', '.join(VALID_HOSTS)}") + + def get(self, url, params=None, **kwargs): + self._validate_url(url) + return super().get(url, params, **kwargs) + + def post(self, url, data=None, **kwargs): + self._validate_url(url) + return super().post(url, data, **kwargs) diff --git a/python_anvil/constants.py b/python_anvil/constants.py index 981eb74a..92e68d32 100644 --- a/python_anvil/constants.py +++ b/python_anvil/constants.py @@ -1,7 +1,14 @@ """Basic constants used in the library.""" GRAPHQL_ENDPOINT: str = "https://graphql.useanvil.com" -REST_ENDPOINT = "https://app.useanvil.com/api/v1/" +REST_ENDPOINT = "https://app.useanvil.com/api/v1" +ANVIL_HOST = "https://app.useanvil.com" + +VALID_HOSTS = [ + ANVIL_HOST, + REST_ENDPOINT, + GRAPHQL_ENDPOINT, +] RETRIES_LIMIT = 5 REQUESTS_LIMIT = { diff --git a/python_anvil/tests/test_api.py b/python_anvil/tests/test_api.py index 7d1cdadc..cdbe51bd 100644 --- a/python_anvil/tests/test_api.py +++ b/python_anvil/tests/test_api.py @@ -9,6 +9,7 @@ CreateEtchPacketPayload, ForgeSubmitPayload, ) +from python_anvil.constants import VALID_HOSTS from ..api_resources.payload import FillPDFPayload from . import payloads @@ -396,3 +397,61 @@ def test_minimum_valid_data_submission(m_request_post, anvil): anvil.forge_submit(payload=payload) assert m_request_post.call_count == 1 assert _expected_data in m_request_post.call_args + + def describe_rest_request_absolute_url_behavior(): + @pytest.mark.parametrize( + "url, should_raise", + [ + ("some/relative/path", True), + ("https://external.example.com/full/path/file.pdf", True), + *[(host + "/some-endpoint", False) for host in VALID_HOSTS], + ], + ) + @mock.patch("python_anvil.api_resources.requests.AnvilRequest._request") + def test_get_behavior(mock_request, anvil, url, should_raise): + mock_request.return_value = (b"fake_content", 200, {}) + rest_client = anvil.request_fully_qualified() + + if should_raise: + with pytest.raises( + ValueError, + match="URL must start with one of: https://app.useanvil.com", + ): + rest_client.get(url) + else: + rest_client.get(url) + mock_request.assert_called_once_with( + "GET", + url, + params=None, + retry=True, + ) + + @pytest.mark.parametrize( + "url, should_raise", + [ + ("some/relative/path", True), + ("https://external.example.com/full/path/file.pdf", True), + *[(host + "/some-endpoint", False) for host in VALID_HOSTS], + ], + ) + @mock.patch("python_anvil.api_resources.requests.AnvilRequest._request") + def test_post_behavior(mock_request, anvil, url, should_raise): + mock_request.return_value = (b"fake_content", 200, {}) + rest_client = anvil.request_fully_qualified() + + if should_raise: + with pytest.raises( + ValueError, + match="URL must start with one of: https://app.useanvil.com", + ): + rest_client.post(url, data={}) + else: + rest_client.post(url, data={}) + mock_request.assert_called_once_with( + "POST", + url, + json={}, + retry=True, + params=None, + )