diff --git a/README.md b/README.md index f6b31be..53662a7 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ EPR Python is a python client for the Event Provenance Registry server. ## Development ```bash -python3 -m venv ~/.virtualenvs/epr-python -source ~/.virtualenvs/epr-python/bin/activate +python3 -m venv .venv/epr-python +source .venv/epr-python/bin/activate git clone git@github.com:xbcsmith/epr-python.git cd epr-python diff --git a/docs/README.md b/docs/README.md index 090deb8..b03833d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,7 +2,9 @@ ## Overview -EPR Python is a python client for the Event Provenance Registry server. +EPR Python is an async python client for the Event Provenance Registry server. The client uses `httpx` for HTTP requests and requires Python 3.8+ with async/await support. + +**Note:** All client operations are async and must be used with `await`. The client should be used as an async context manager to ensure proper resource cleanup. ## Model Example @@ -45,173 +47,190 @@ print(f"Fingerprint: {erf_obj.compute_fingerprint()}") ## Client Create Example -A Client create example using the models in the `epr` package. +An async Client create example using the models in the `epr` package. ```python +import asyncio import time from epr.client import Client from epr.models import Event, EventReceiver, EventReceiverGroup -url = "http://localhost:8042" -headers = {} -client = Client(url, headers=headers) - -# Create an event receiver - -event_receiver_foo = EventReceiver() -event_receiver_foo.name = "foo-receiver-2" -event_receiver_foo.type = "dev.events.foo" -event_receiver_foo.version = "1.0.0" -event_receiver_foo.description = "The Event Receiver for the Foo of Brixton" -event_receiver_foo.schema = "{}" - -event_receiver_foo_res = client.create_event_receiver(params=event_receiver_foo) -event_receiver_foo_id = event_receiver_foo_res["data"]["create_event_receiver"] - - -event_receiver_bar = EventReceiver() -event_receiver_bar.name = "bar-receiver-1" -event_receiver_bar.type = "dev.events.bar" -event_receiver_bar.version = "1.0.0" -event_receiver_bar.description = "The Event Receiver for the Bar of Brixton" -event_receiver_bar.schema = "{}" - -event_receiver_bar_res = client.create_event_receiver(params=event_receiver_bar) -event_receiver_bar_id = event_receiver_bar_res["data"]["create_event_receiver"] - -# Create an event receiver group - -event_receiver_group_foo = EventReceiverGroup() -event_receiver_group_foo.name = "foo-bar-receiver-group-1" -event_receiver_group_foo.type = "dev.events.foo.bar.complete" -event_receiver_group_foo.version = "1.0.0" -event_receiver_group_foo.description = "The Event Receiver Group for the Foo and Bar of Brixton" -event_receiver_group_foo.event_receiver_ids = [event_receiver_foo_id, event_receiver_bar_id] -event_receiver_group_foo.enabled = True - -event_receiver_group_foo_res = client.create_event_receiver_group(params=event_receiver_group_foo) -event_receiver_group_foo_id = event_receiver_group_foo_res["data"]["create_event_receiver_group"] - -# Create an event - -event_foo = Event() -event_foo.name = "foo" -event_foo.version = "1.0.0" -event_foo.release = str(time.time()) -event_foo.platform_id = "x86_64-gnu-linux-9" -event_foo.package = "rpm" -event_foo.description = "The Foo of Brixton" -event_foo.payload = '{"name": "foo"}' -event_foo.success = True -event_foo.event_receiver_id = event_receiver_foo_id - -event_foo_res = client.create_event(params=event_foo) -event_foo_id = event_foo_res["data"]["create_event"] - -event_bar = Event() -event_bar.name = "bar" -event_bar.version = "1.0.0" -event_bar.release = str(time.time()) -event_bar.platform_id = "x86_64-gnu-linux-9" -event_bar.package = "rpm" -event_bar.description = "The Bar of Brixton" -event_bar.payload = '{"name": "bar"}' -event_bar.success = True -event_bar.event_receiver_id = event_receiver_bar_id - -event_bar_res = client.create_event(params=event_bar) -event_bar_id = event_bar_res["data"]["create_event"] - -results = { - "events": [event_foo_id, event_bar_id], - "event_receivers": [event_receiver_foo_id, event_receiver_bar_id], - "event_receiver_groups": [event_receiver_group_foo_id], -} - -print(f"{results}") +async def main(): + url = "http://localhost:8042" + headers = {} + + async with Client(url, headers=headers) as client: + + # Create an event receiver + + event_receiver_foo = EventReceiver() + event_receiver_foo.name = "foo-receiver-2" + event_receiver_foo.type = "dev.events.foo" + event_receiver_foo.version = "1.0.0" + event_receiver_foo.description = "The Event Receiver for the Foo of Brixton" + event_receiver_foo.schema = "{}" + + event_receiver_foo_res = await client.create_event_receiver(params=event_receiver_foo) + event_receiver_foo_id = event_receiver_foo_res["data"]["create_event_receiver"] + + + event_receiver_bar = EventReceiver() + event_receiver_bar.name = "bar-receiver-1" + event_receiver_bar.type = "dev.events.bar" + event_receiver_bar.version = "1.0.0" + event_receiver_bar.description = "The Event Receiver for the Bar of Brixton" + event_receiver_bar.schema = "{}" + + event_receiver_bar_res = await client.create_event_receiver(params=event_receiver_bar) + event_receiver_bar_id = event_receiver_bar_res["data"]["create_event_receiver"] + + # Create an event receiver group + + event_receiver_group_foo = EventReceiverGroup() + event_receiver_group_foo.name = "foo-bar-receiver-group-1" + event_receiver_group_foo.type = "dev.events.foo.bar.complete" + event_receiver_group_foo.version = "1.0.0" + event_receiver_group_foo.description = "The Event Receiver Group for the Foo and Bar of Brixton" + event_receiver_group_foo.event_receiver_ids = [event_receiver_foo_id, event_receiver_bar_id] + event_receiver_group_foo.enabled = True + + event_receiver_group_foo_res = await client.create_event_receiver_group(params=event_receiver_group_foo) + event_receiver_group_foo_id = event_receiver_group_foo_res["data"]["create_event_receiver_group"] + + # Create an event + + event_foo = Event() + event_foo.name = "foo" + event_foo.version = "1.0.0" + event_foo.release = str(time.time()) + event_foo.platform_id = "x86_64-gnu-linux-9" + event_foo.package = "rpm" + event_foo.description = "The Foo of Brixton" + event_foo.payload = '{"name": "foo"}' + event_foo.success = True + event_foo.event_receiver_id = event_receiver_foo_id + + event_foo_res = await client.create_event(params=event_foo) + event_foo_id = event_foo_res["data"]["create_event"] + + event_bar = Event() + event_bar.name = "bar" + event_bar.version = "1.0.0" + event_bar.release = str(time.time()) + event_bar.platform_id = "x86_64-gnu-linux-9" + event_bar.package = "rpm" + event_bar.description = "The Bar of Brixton" + event_bar.payload = '{"name": "bar"}' + event_bar.success = True + event_bar.event_receiver_id = event_receiver_bar_id + + event_bar_res = await client.create_event(params=event_bar) + event_bar_id = event_bar_res["data"]["create_event"] + + results = { + "events": [event_foo_id, event_bar_id], + "event_receivers": [event_receiver_foo_id, event_receiver_bar_id], + "event_receiver_groups": [event_receiver_group_foo_id], + } + + print(f"{results}") + return results + +# Run the async function +if __name__ == "__main__": + asyncio.run(main()) ``` -### TL;DR Client Example +### TL;DR Async Client Example Create an event receiver and event receiver group. Then send events. ```python +import asyncio +import time from epr.client import Client from epr.models import Event, EventReceiver, EventReceiverGroup -url = "http://localhost:8042" -headers = {} -client = Client(url, headers=headers) - -erf = dict( - name="foo-receiver-3", - type="dev.events.foo", - version="1.0.0", - description="The Event Receiver for the Foo of Brixton", - schema="{}", -) - -erf_res = client.create_event_receiver(params=erf) -erf_id = erf_res["data"]["create_event_receiver"] - - -erb = dict( - name="bar-receiver-4", - type="dev.events.bar", - version="1.0.0", - description="The Event Receiver for the Bar of Brixton", - schema="{}", -) - -erb_res = client.create_event_receiver(params=erb) -erb_id = erb_res["data"]["create_event_receiver"] - -erg = dict( - name="foo-bar-receiver-group-2", - type="dev.events.foo.bar.complete", - version="1.0.0", - description="The Event Receiver Group for the Foo and Bar of Brixton", - enabled=True, - event_receiver_ids=[erf_id, erb_id], -) - -erg_res = client.create_event_receiver_group(params=erg) -erg_id = erg_res["data"]["create_event_receiver_group"] - -ef = dict( - name="foo", - version="1.0.0", - release=str(time.time()), - platform_id="x86_64-gnu-linux-9", - package="rpm", - description="The Foo of Brixton", - payload="{}", - success=True, - event_receiver_id=erf_id, -) - -ef_res = client.create_event(params=ef) -ef_id = ef_res["data"]["create_event"] - -eb = dict( - name="bar", - version="1.0.0", - release=str(time.time()), - platform_id="x86_64-gnu-linux-9", - package="rpm", - description="The Bar of Brixton", - payload="{}", - success=True, - event_receiver_id=erb_id, -) - -eb_res = client.create_event(params=eb) -eb_id = eb_res["data"]["create_event"] - -results = {"events": [ef_id, eb_id], "event_receivers": [erf_id, erb_id], "event_receiver_groups": [erg_id]} - -print(f"{results}") +async def main(): + url = "http://localhost:8042" + headers = {} + + async with Client(url, headers=headers) as client: + + erf = dict( + name="foo-receiver-3", + type="dev.events.foo", + version="1.0.0", + description="The Event Receiver for the Foo of Brixton", + schema="{}", + ) + + erf_res = await client.create_event_receiver(params=erf) + erf_id = erf_res["data"]["create_event_receiver"] + + + erb = dict( + name="bar-receiver-4", + type="dev.events.bar", + version="1.0.0", + description="The Event Receiver for the Bar of Brixton", + schema="{}", + ) + + erb_res = await client.create_event_receiver(params=erb) + erb_id = erb_res["data"]["create_event_receiver"] + + erg = dict( + name="foo-bar-receiver-group-2", + type="dev.events.foo.bar.complete", + version="1.0.0", + description="The Event Receiver Group for the Foo and Bar of Brixton", + enabled=True, + event_receiver_ids=[erf_id, erb_id], + ) + + erg_res = await client.create_event_receiver_group(params=erg) + erg_id = erg_res["data"]["create_event_receiver_group"] + + ef = dict( + name="foo", + version="1.0.0", + release=str(time.time()), + platform_id="x86_64-gnu-linux-9", + package="rpm", + description="The Foo of Brixton", + payload="{}", + success=True, + event_receiver_id=erf_id, + ) + + ef_res = await client.create_event(params=ef) + ef_id = ef_res["data"]["create_event"] + + eb = dict( + name="bar", + version="1.0.0", + release=str(time.time()), + platform_id="x86_64-gnu-linux-9", + package="rpm", + description="The Bar of Brixton", + payload="{}", + success=True, + event_receiver_id=erb_id, + ) + + eb_res = await client.create_event(params=eb) + eb_id = eb_res["data"]["create_event"] + + results = {"events": [ef_id, eb_id], "event_receivers": [erf_id, erb_id], "event_receiver_groups": [erg_id]} + + print(f"{results}") + return results + +# Run the async function +if __name__ == "__main__": + asyncio.run(main()) ``` ## Client Search Example @@ -221,74 +240,136 @@ version, and release. We will use the release values from the Events we created above. We will also set the fields we want to return. ```python +import asyncio from epr.client import Client from epr.models import Event, EventReceiver, EventReceiverGroup -url = "http://localhost:8042" -headers = {} -client = Client(url, headers=headers) - -e_fields = [ - "id", - "name", - "version", - "release", - "platform_id", - "package", - "description", - "success", - "event_receiver_id", -] - -efs = Event() -efs.name = "foo" -efs.version = "1.0.0" -efs.release = ef["release"] - -event_results = client.search_events(params=efs.as_dict_query(), fields=e_fields) -print(f"{event_results}") - -ebs = Event() -ebs.name = "bar" -ebs.version = "1.0.0" -ebs.release = eb["release"] - -event_results = client.search_events(params=ebs.as_dict_query(), fields=e_fields) -print(f"{event_results}") +async def search_example(): + url = "http://localhost:8042" + headers = {} + + async with Client(url, headers=headers) as client: + + e_fields = [ + "id", + "name", + "version", + "release", + "platform_id", + "package", + "description", + "success", + "event_receiver_id", + ] + + efs = Event() + efs.name = "foo" + efs.version = "1.0.0" + efs.release = "1234567890" # Use your actual release value + + event_results = await client.search_events(params=efs.as_dict_query(), fields=e_fields) + print(f"{event_results}") + + ebs = Event() + ebs.name = "bar" + ebs.version = "1.0.0" + ebs.release = "1234567891" # Use your actual release value + + event_results = await client.search_events(params=ebs.as_dict_query(), fields=e_fields) + print(f"{event_results}") + +# Run the async function +if __name__ == "__main__": + asyncio.run(search_example()) ``` We will now search for event receivers using the name, version, and type. ```python -er_fields = ["id", "name", "type", "version", "description", "schema", "fingerprint", "created_at"] +import asyncio +from epr.client import Client +from epr.models import EventReceiver + +async def search_event_receivers(): + url = "http://localhost:8042" + headers = {} + + async with Client(url, headers=headers) as client: + er_fields = ["id", "name", "type", "version", "description", "schema", "fingerprint", "created_at"] -erfs = EventReceiver() -erfs.name = "foo-receiver-3" -erfs.version = "1.0.0" -erfs.type = "dev.events.foo" + erfs = EventReceiver() + erfs.name = "foo-receiver-3" + erfs.version = "1.0.0" + erfs.type = "dev.events.foo" -event_receiver_results = client.search_event_receivers(params=erfs.as_dict_query(), fields=er_fields) -print(f"{event_receiver_results}") + event_receiver_results = await client.search_event_receivers(params=erfs.as_dict_query(), fields=er_fields) + print(f"{event_receiver_results}") -erbs = EventReceiver() -erbs.name = "bar-receiver-4" -erbs.version = "1.0.0" -erbs.type = "dev.events.bar" + erbs = EventReceiver() + erbs.name = "bar-receiver-4" + erbs.version = "1.0.0" + erbs.type = "dev.events.bar" -event_receiver_results = client.search_event_receivers(params=erbs.as_dict_query(), fields=er_fields) -print(f"{event_receiver_results}") + event_receiver_results = await client.search_event_receivers(params=erbs.as_dict_query(), fields=er_fields) + print(f"{event_receiver_results}") + +# Run the async function +if __name__ == "__main__": + asyncio.run(search_event_receivers()) ``` Last we will search for event receiver groups using the name, version, and type. ```python -erg_fields = ["id", "name", "type", "version", "description", "enabled", "created_at", "event_receiver_ids"] -ergs = EventReceiverGroup() -ergs.name = "foo-bar-receiver-group-2" -ergs.version = "1.0.0" -ergs.type = "dev.events.foo.bar.complete" - -event_receiver_group_results = client.search_event_receiver_groups(params=ergs.as_dict_query(), fields=erg_fields) -print(f"{event_receiver_group_results}") +import asyncio +from epr.client import Client +from epr.models import EventReceiverGroup + +async def search_event_receiver_groups(): + url = "http://localhost:8042" + headers = {} + + async with Client(url, headers=headers) as client: + erg_fields = ["id", "name", "type", "version", "description", "enabled", "created_at", "event_receiver_ids"] + ergs = EventReceiverGroup() + ergs.name = "foo-bar-receiver-group-2" + ergs.version = "1.0.0" + ergs.type = "dev.events.foo.bar.complete" + + event_receiver_group_results = await client.search_event_receiver_groups(params=ergs.as_dict_query(), fields=erg_fields) + print(f"{event_receiver_group_results}") + +# Run the async function +if __name__ == "__main__": + asyncio.run(search_event_receiver_groups()) +``` + +## Using with Model Context Protocol (MCP) Servers + +The async client is designed to work seamlessly with MCP servers. Here's an example of how to use it in an MCP server context: + +```python +import asyncio +from epr.client import Client +from epr.models import Event + +class EPRMCPServer: + def __init__(self, epr_url: str): + self.epr_url = epr_url + + async def create_event_handler(self, event_data: dict) -> dict: + """MCP tool handler for creating events""" + async with Client(self.epr_url) as client: + event = Event(**event_data) + result = await client.create_event(params=event) + return result + + async def search_events_handler(self, search_params: dict, fields: list = None) -> dict: + """MCP tool handler for searching events""" + async with Client(self.epr_url) as client: + result = await client.search_events(params=search_params, fields=fields) + return result ``` +The async context manager ensures that HTTP connections are properly managed and closed, which is essential for long-running MCP servers. + diff --git a/examples/async_client_example.py b/examples/async_client_example.py new file mode 100644 index 0000000..8903263 --- /dev/null +++ b/examples/async_client_example.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: © 2024 Brett Smith +# SPDX-License-Identifier: Apache-2.0 + +""" +Example demonstrating how to use the async EPR client. +This is especially useful for MCP servers and other async contexts. +""" + +import asyncio +import json +from typing import Optional + +from epr.client import Client + + +async def search_events_example(): + """Example of searching for events using the async client.""" + + # Use the async context manager - this ensures proper cleanup + async with Client("http://localhost:8042") as client: + # Search for events with specific parameters + search_params = {"name": "test-event", "version": "1.0.0"} + + # Specify which fields to return + fields = ["id", "name", "version", "release", "description", "success"] + + try: + result = await client.search_events(params=search_params, fields=fields) + print("Search Results:") + print(json.dumps(result, indent=2)) + return result + + except Exception as e: + print(f"Error searching for events: {e}") + return None + + +async def create_event_example(): + """Example of creating an event using the async client.""" + + async with Client("http://localhost:8042") as client: + # Event data to create + event_data = { + "name": "example-event", + "version": "1.0.1", + "release": "stable", + "platform_id": "linux-x64", + "package": "my-package", + "description": "An example event created via async client", + "payload": {"key": "value", "number": 42}, + "success": True, + } + + try: + result = await client.create_event(params=event_data) + print("Create Result:") + print(json.dumps(result, indent=2)) + return result + + except Exception as e: + print(f"Error creating event: {e}") + return None + + +async def batch_operations_example(): + """Example of performing multiple operations efficiently.""" + + async with Client("http://localhost:8042") as client: + # Perform multiple operations concurrently + search_task = client.search_events(fields=["id", "name", "version"]) + + receiver_search_task = client.search_event_receivers(fields=["id", "name", "type", "version"]) + + group_search_task = client.search_event_receiver_groups(fields=["id", "name", "type", "enabled"]) + + try: + # Wait for all operations to complete + search_result, receiver_result, group_result = await asyncio.gather( + search_task, receiver_search_task, group_search_task, return_exceptions=True + ) + + results = { + "events": search_result, + "event_receivers": receiver_result, + "event_receiver_groups": group_result, + } + + print("Batch Results:") + print(json.dumps(results, indent=2)) + return results + + except Exception as e: + print(f"Error in batch operations: {e}") + return None + + +class EPRAsyncService: + """ + Example service class showing how to use the async client in a class. + This pattern is useful for MCP servers and other service applications. + """ + + def __init__(self, url: str = "http://localhost:8042", headers: Optional[dict] = None): + self.url = url + self.headers = headers or {} + self._client: Optional[Client] = None + + async def __aenter__(self): + """Async context manager entry.""" + self._client = Client(self.url, headers=self.headers) + await self._client.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._client: + await self._client.__aexit__(exc_type, exc_val, exc_tb) + + async def get_event_by_name(self, name: str): + """Get an event by name.""" + if not self._client: + raise RuntimeError("Service not initialized. Use 'async with EPRAsyncService()' context manager.") + + search_params = {"name": name} + fields = ["id", "name", "version", "release", "description", "success"] + + result = await self._client.search_events(params=search_params, fields=fields) + events = result.get("data", {}).get("events", []) + return events[0] if events else None + + async def create_event(self, event_data: dict): + """Create a new event.""" + if not self._client: + raise RuntimeError("Service not initialized. Use 'async with EPRAsyncService()' context manager.") + + return await self._client.create_event(params=event_data) + + async def health_check(self): + """Simple health check by attempting to search for events.""" + if not self._client: + raise RuntimeError("Service not initialized. Use 'async with EPRAsyncService()' context manager.") + + try: + await self._client.search_events(fields=["id"]) + return {"status": "healthy", "service": "epr"} + except Exception as e: + return {"status": "unhealthy", "service": "epr", "error": str(e)} + + +async def service_example(): + """Example using the service class.""" + + async with EPRAsyncService() as service: + # Health check + health = await service.health_check() + print("Health Check:") + print(json.dumps(health, indent=2)) + + # Get event by name + event = await service.get_event_by_name("test-event") + if event: + print("Found Event:") + print(json.dumps(event, indent=2)) + else: + print("Event not found") + + +async def main(): + """Main example function.""" + print("EPR Async Client Examples") + print("=" * 50) + + # Note: These examples assume an EPR server is running at localhost:8042 + # In a real scenario, you would handle connection errors appropriately + + print("\n1. Search Events Example:") + await search_events_example() + + print("\n2. Create Event Example:") + await create_event_example() + + print("\n3. Batch Operations Example:") + await batch_operations_example() + + print("\n4. Service Class Example:") + await service_example() + + +if __name__ == "__main__": + # Run the examples + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index f978f56..30817c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "urllib3", + "httpx", "jsonpath-ng", ] diff --git a/src/epr/client.py b/src/epr/client.py index e1703e1..5b295f2 100644 --- a/src/epr/client.py +++ b/src/epr/client.py @@ -7,17 +7,15 @@ from typing import Any, Optional from urllib.parse import urljoin -import urllib3 +import httpx from .common import EnhancedJSONEncoder from .models import GraphQLQuery -urllib3.disable_warnings() - logger = logging.getLogger(__name__) -class Client(object): +class Client: def __init__(self, url, headers=None): self.url = url self.api_version = "v1" @@ -51,6 +49,24 @@ def __init__(self, url, headers=None): "create_event_receiver_group": "event_receiver_group", }, } + self._client = None + + async def __aenter__(self): + """Async context manager entry""" + self._client = httpx.AsyncClient(timeout=httpx.Timeout(connect=2.0, read=10.0, write=10.0, pool=10.0)) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit""" + if self._client: + await self._client.aclose() + self._client = None + + async def close(self): + """Manually close the client""" + if self._client: + await self._client.aclose() + self._client = None def _new_graphql_search_query( self, operation: str, params: Optional[dict] = None, fields: Optional[list] = None @@ -108,21 +124,20 @@ def _new_graphql_mutation_query(self, operation: str, params: Optional[dict] = N query = f"""mutation ($obj: {method}){{{operation}({op}: $obj)}}""" return GraphQLQuery(query=query, variables=variables) - def _query(self, query: GraphQLQuery) -> Any: + async def _query(self, query: GraphQLQuery) -> Any: """ Sends a GraphQL query to the server. Args: - query (str): The GraphQL query string. - variables (dict, optional): The variables to be used in the query. Defaults to None. + query (GraphQLQuery): The GraphQL query object containing query and variables. Returns: Any: The response data from the server. """ - response = self._post(self.target, data=query.as_dict()) - return json.loads(response.decode("utf-8")) + response = await self._post(self.target, data=query.as_dict()) + return response.json() - def _post(self, url: str, data: dict) -> bytes: + async def _post(self, url: str, data: dict) -> httpx.Response: """ Sends a POST request to the specified URL with the provided data. @@ -131,15 +146,17 @@ def _post(self, url: str, data: dict) -> bytes: data (dict): The data to be sent in the POST request. Returns: - bytes: The data received in the response to the POST request. + httpx.Response: The response from the server. """ - timeout = urllib3.Timeout(connect=2.0, read=10.0) - http = urllib3.PoolManager(timeout=timeout) - encoded_data = json.dumps(data, cls=EnhancedJSONEncoder).encode("utf-8") - response = http.request("POST", url, body=encoded_data, headers=self.headers) - return response.data + if not self._client: + self._client = httpx.AsyncClient(timeout=httpx.Timeout(connect=2.0, read=10.0, write=10.0, pool=10.0)) + + json_data = json.dumps(data, cls=EnhancedJSONEncoder) + response = await self._client.post(url, content=json_data, headers=self.headers) + response.raise_for_status() + return response - def _search(self, operation: str, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: + async def _search(self, operation: str, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: """ Sends a GraphQL search query to the server. @@ -152,10 +169,9 @@ def _search(self, operation: str, params: Optional[dict] = None, fields: Optiona Any: The response data from the server. """ query = self._new_graphql_search_query(operation, params, fields) + return await self._query(query=query) - return self._query(query=query) - - def _mutation(self, operation: str, params: Optional[dict] = None) -> Any: + async def _mutation(self, operation: str, params: Optional[dict] = None) -> Any: """ Sends a GraphQL mutation query to the server. @@ -170,9 +186,9 @@ def _mutation(self, operation: str, params: Optional[dict] = None) -> Any: It then sends the query to the server using the `query` method and returns the response data. """ query = self._new_graphql_mutation_query(operation, params) - return self._query(query) + return await self._query(query) - def search_events(self, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: + async def search_events(self, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: """ Searches for events based on the provided parameters and fields. @@ -185,9 +201,9 @@ def search_events(self, params: Optional[dict] = None, fields: Optional[list] = This function performs a search for events based on the provided parameters and fields. """ - return self._search("events", params, fields) + return await self._search("events", params, fields) - def search_event_receivers(self, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: + async def search_event_receivers(self, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: """ Search for event receivers based on the given parameters and fields. @@ -200,9 +216,9 @@ def search_event_receivers(self, params: Optional[dict] = None, fields: Optional This function performs a search for event receivers based on the given parameters and fields. """ - return self._search("event_receivers", params, fields) + return await self._search("event_receivers", params, fields) - def search_event_receiver_groups(self, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: + async def search_event_receiver_groups(self, params: Optional[dict] = None, fields: Optional[list] = None) -> Any: """ Search for event receiver groups based on the given parameters and fields. @@ -215,9 +231,9 @@ def search_event_receiver_groups(self, params: Optional[dict] = None, fields: Op This function performs a search for event receiver groups based on the given parameters and fields. """ - return self._search("event_receiver_groups", params, fields) + return await self._search("event_receiver_groups", params, fields) - def create_event(self, params: Optional[dict] = None) -> Any: + async def create_event(self, params: Optional[dict] = None) -> Any: """ Creates an event using the provided parameters. @@ -229,9 +245,9 @@ def create_event(self, params: Optional[dict] = None) -> Any: This function sends a mutation query to create an event using the provided parameters. """ - return self._mutation("create_event", params) + return await self._mutation("create_event", params) - def create_event_receiver(self, params: Optional[dict] = None) -> Any: + async def create_event_receiver(self, params: Optional[dict] = None) -> Any: """ Creates an event receiver using the provided parameters. @@ -243,9 +259,9 @@ def create_event_receiver(self, params: Optional[dict] = None) -> Any: This function sends a mutation query to create an event receiver using the provided parameters. """ - return self._mutation("create_event_receiver", params) + return await self._mutation("create_event_receiver", params) - def create_event_receiver_group(self, params: Optional[dict] = None) -> Any: + async def create_event_receiver_group(self, params: Optional[dict] = None) -> Any: """ Creates an event receiver group using the provided parameters. @@ -257,4 +273,4 @@ def create_event_receiver_group(self, params: Optional[dict] = None) -> Any: This function sends a mutation query to create an event receiver group using the provided parameters. """ - return self._mutation("create_event_receiver_group", params) + return await self._mutation("create_event_receiver_group", params) diff --git a/src/epr/create.py b/src/epr/create.py index 483c50b..f148c29 100644 --- a/src/epr/create.py +++ b/src/epr/create.py @@ -8,7 +8,7 @@ from .config import Config -def create(config: Config): +async def create(config: Config): """Create an event provenance registry object""" url = config.url @@ -16,23 +16,25 @@ def create(config: Config): url = "http://localhost:8042" # headers = {"Authorization": "Bearer " + config.token} headers = {} - client = Client(url, headers=headers) - - events = [] - for e in config.events: - event = client.create_event(params=e.as_dict()) - event_id = event["data"]["create_event"] - events.append(event_id) - event_receivers = [] - for er in config.event_receivers: - event_receiver = client.create_event_receiver(params=er.as_dict()) - event_receiver_id = event_receiver["data"]["create_event_receiver"] - event_receivers.append(event_receiver_id) - event_receiver_groups = [] - for erg in config.event_receiver_groups: - event_receiver_group = client.create_event_receiver_group(params=erg.as_dict()) - event_receiver_group_id = event_receiver_group["data"]["create_event_receiver_group"] - event_receiver_groups.append(event_receiver_group_id) + + async with Client(url, headers=headers) as client: + events = [] + for e in config.events: + event = await client.create_event(params=e.as_dict()) + event_id = event["data"]["create_event"] + events.append(event_id) + + event_receivers = [] + for er in config.event_receivers: + event_receiver = await client.create_event_receiver(params=er.as_dict()) + event_receiver_id = event_receiver["data"]["create_event_receiver"] + event_receivers.append(event_receiver_id) + + event_receiver_groups = [] + for erg in config.event_receiver_groups: + event_receiver_group = await client.create_event_receiver_group(params=erg.as_dict()) + event_receiver_group_id = event_receiver_group["data"]["create_event_receiver_group"] + event_receiver_groups.append(event_receiver_group_id) results = {"events": events, "event_receivers": event_receivers, "event_receiver_groups": event_receiver_groups} stdout = json.dumps(results) diff --git a/src/epr/main.py b/src/epr/main.py index 30deec2..bb7da1a 100644 --- a/src/epr/main.py +++ b/src/epr/main.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import asyncio import logging import os import sys @@ -46,10 +47,9 @@ def __init__(self): getattr(self, args.command)() def _handle_fields(self, value): - fields = None ulid_length = 26 # a valid ULID length is 26 characters if value is None: - return fields + return [] if "," in value: fields = [x.strip() for x in value.split(",") if x] elif " " in value: @@ -298,7 +298,7 @@ def create(self): event_receiver_group.event_receiver_ids = self._handle_fields(args["event_receiver_ids"]) event_receiver_group.enabled = True if not args["disable"] else False cfg.event_receiver_groups.append(event_receiver_group) - return create.create(cfg) + return asyncio.run(create.create(cfg)) def search(self): """ @@ -567,7 +567,7 @@ def search(self): event_receiver_group.event_receiver_ids = self._handle_fields(args["event_receiver_ids"]) cfg.event_receiver_groups.append(event_receiver_group) cfg.event_receiver_group_fields = fields - return search.search(cfg) + return asyncio.run(search.search(cfg)) def version(self): """ diff --git a/src/epr/models.py b/src/epr/models.py index 4df487b..634fe71 100644 --- a/src/epr/models.py +++ b/src/epr/models.py @@ -118,4 +118,4 @@ class Message(Model): @dataclass class GraphQLQuery(Model): query: str - variables: str + variables: Dict[str, Any] = field(default_factory=dict) diff --git a/src/epr/search.py b/src/epr/search.py index bc92878..56aaf51 100644 --- a/src/epr/search.py +++ b/src/epr/search.py @@ -8,7 +8,7 @@ from .config import Config -def search(config: Config): +async def search(config: Config): """Search for events and event receivers""" url = config.url @@ -16,42 +16,44 @@ def search(config: Config): url = "http://localhost:8042" # headers = {"Authorization": "Bearer " + config.token} headers = {} - client = Client(url, headers=headers) - - events = [] - for e in config.events: - fields = config.event_fields - if fields is None: - fields = [ - "id", - "name", - "version", - "release", - "platform_id", - "package", - "description", - "success", - "event_receiver_id", - ] - event = client.search_events(params=e.as_dict_query(), fields=fields) - event_result = event["data"]["events"][-1] - events.append(event_result) - event_receivers = [] - for er in config.event_receivers: - fields = config.event_receiver_fields - if fields is None: - fields = ["id", "name", "type", "version", "description", "schema", "fingerprint", "created_at"] - event_receiver = client.search_event_receivers(params=er.as_dict_query(), fields=fields) - event_receiver_result = event_receiver["data"]["event_receivers"][-1] - event_receivers.append(event_receiver_result) - event_receiver_groups = [] - for erg in config.event_receiver_groups: - fields = config.event_receiver_group_fields - if fields is None: - fields = ["id", "name", "type", "version", "description", "enabled", "created_at"] - event_receiver_group = client.search_event_receiver_groups(params=erg.as_dict_query(), fields=fields) - event_receiver_group_result = event_receiver_group["data"]["event_receiver_groups"][-1] - event_receiver_groups.append(event_receiver_group_result) + + async with Client(url, headers=headers) as client: + events = [] + for e in config.events: + fields = config.event_fields + if fields is None: + fields = [ + "id", + "name", + "version", + "release", + "platform_id", + "package", + "description", + "success", + "event_receiver_id", + ] + event = await client.search_events(params=e.as_dict_query(), fields=fields) + event_result = event["data"]["events"][-1] + events.append(event_result) + + event_receivers = [] + for er in config.event_receivers: + fields = config.event_receiver_fields + if fields is None: + fields = ["id", "name", "type", "version", "description", "schema", "fingerprint", "created_at"] + event_receiver = await client.search_event_receivers(params=er.as_dict_query(), fields=fields) + event_receiver_result = event_receiver["data"]["event_receivers"][-1] + event_receivers.append(event_receiver_result) + + event_receiver_groups = [] + for erg in config.event_receiver_groups: + fields = config.event_receiver_group_fields + if fields is None: + fields = ["id", "name", "type", "version", "description", "enabled", "created_at"] + event_receiver_group = await client.search_event_receiver_groups(params=erg.as_dict_query(), fields=fields) + event_receiver_group_result = event_receiver_group["data"]["event_receiver_groups"][-1] + event_receiver_groups.append(event_receiver_group_result) results = {"events": events, "event_receivers": event_receivers, "event_receiver_groups": event_receiver_groups} stdout = json.dumps(results) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 57b87e9..901e2ec 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: © 2024 Brett Smith # SPDX-License-Identifier: Apache-2.0 +import asyncio + def test_import(): """Validate epr is importable""" @@ -10,3 +12,51 @@ def test_import(): print(epr) assert True + + +def test_async_client_import(): + """Validate async client functionality is importable and works""" + from epr.client import Client + from epr.models import GraphQLQuery + + # Test that we can create a client + client = Client("http://localhost:8042") + assert client is not None + assert client.url == "http://localhost:8042" + + # Test that we can create GraphQL queries with dict variables + query = GraphQLQuery(query="query test { events { id } }", variables={"name": "test", "version": "1.0"}) + assert query.query == "query test { events { id } }" + assert query.variables["name"] == "test" + + print("✓ Async client imports and basic functionality work") + + +def test_async_client_context_manager(): + """Test that async client context manager works""" + + async def async_test(): + from epr.client import Client + + async with Client("http://localhost:8042") as client: + assert client is not None + assert client._client is not None # HTTP client should be created + return True + + # Run the async test + result = asyncio.run(async_test()) + assert result is True + + print("✓ Async client context manager works") + + +def test_async_search_create_imports(): + """Test that async search and create functions are importable""" + from epr.create import create + from epr.search import search + + # These should be async functions now + assert asyncio.iscoroutinefunction(search) + assert asyncio.iscoroutinefunction(create) + + print("✓ Async search and create functions are properly async") diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py new file mode 100644 index 0000000..e2add4f --- /dev/null +++ b/tests/unit/test_async_client.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: © 2024 Brett Smith +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +import httpx + +from epr.client import Client +from epr.models import GraphQLQuery +from tests import base + + +class AsyncClientTestCase(base.BaseTestCase): + """Test cases for the async Client class.""" + + def setUp(self): + super().setUp() + self.test_url = "http://test-server:8042" + self.test_headers = {"Authorization": "Bearer test-token"} + + def test_client_initialization(self): + """Test client initialization with various parameters.""" + # Test default initialization + client = Client(None) + self.assertEqual(client.url, "http://localhost:8042") + self.assertEqual(client.target, "http://localhost:8042/api/v1/graphql/query") + self.assertEqual(client.headers["Content-Type"], "application/json") + + # Test with custom URL and headers + client = Client(self.test_url, headers=self.test_headers) + self.assertEqual(client.url, self.test_url) + self.assertEqual(client.target, f"{self.test_url}/api/v1/graphql/query") + self.assertIn("Authorization", client.headers) + self.assertEqual(client.headers["Authorization"], "Bearer test-token") + + def test_graphql_search_query_generation(self): + """Test GraphQL search query generation.""" + client = Client(self.test_url) + + # Test basic search query + query = client._new_graphql_search_query("events", {"name": "test"}, ["id", "name"]) + self.assertIsInstance(query, GraphQLQuery) + self.assertIn("query", query.query) + self.assertIn("FindEventInput!", query.query) + self.assertIn("events", query.query) + self.assertEqual(query.variables["obj"]["name"], "test") + + # Test with no parameters + query = client._new_graphql_search_query("events", None, ["id"]) + self.assertEqual(query.variables["obj"], None) + + # Test with no fields (should default to "id") + query = client._new_graphql_search_query("events", {"name": "test"}, None) + self.assertIn("id", query.query) + + def test_graphql_mutation_query_generation(self): + """Test GraphQL mutation query generation.""" + client = Client(self.test_url) + + # Test mutation query + params = {"name": "test-event", "version": "1.0.0"} + query = client._new_graphql_mutation_query("create_event", params) + + self.assertIsInstance(query, GraphQLQuery) + self.assertIn("mutation", query.query) + self.assertIn("CreateEventInput!", query.query) + self.assertIn("create_event", query.query) + self.assertEqual(query.variables["obj"], params) + + async def test_async_context_manager(self): + """Test async context manager functionality.""" + client = Client(self.test_url) + + # Test that client starts with no HTTP client + self.assertIsNone(client._client) + + async with client as c: + # Should be the same client instance + self.assertIs(c, client) + # Should have an HTTP client now + self.assertIsNotNone(client._client) + self.assertIsInstance(client._client, httpx.AsyncClient) + + # Client should be closed after context exit + self.assertIsNone(client._client) + + async def test_manual_client_management(self): + """Test manual client lifecycle management.""" + client = Client(self.test_url) + + # Manually enter + result = await client.__aenter__() + self.assertIs(result, client) + self.assertIsNotNone(client._client) + + # Manually exit + await client.__aexit__(None, None, None) + self.assertIsNone(client._client) + + @patch("httpx.AsyncClient.post") + async def test_post_method(self, mock_post): + """Test the async _post method.""" + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = {"data": {"test": "result"}} + mock_post.return_value = mock_response + + async with Client(self.test_url) as client: + data = {"query": "test query", "variables": {}} + response = await client._post(client.target, data) + + # Verify the response + self.assertEqual(response, mock_response) + + # Verify the post was called correctly + mock_post.assert_called_once() + call_args = mock_post.call_args + + # Check URL + self.assertEqual(call_args[1]["url"] if "url" in call_args[1] else call_args[0][0], client.target) + + # Check headers + self.assertEqual(call_args[1]["headers"], client.headers) + + # Check that response.raise_for_status was called + mock_response.raise_for_status.assert_called_once() + + @patch("httpx.AsyncClient.post") + async def test_query_method(self, mock_post): + """Test the async _query method.""" + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = {"data": {"events": [{"id": "test-id"}]}} + mock_post.return_value = mock_response + + async with Client(self.test_url) as client: + query = GraphQLQuery(query="query test { events { id } }", variables={"obj": {"name": "test"}}) + + result = await client._query(query) + + # Should return the JSON response + self.assertEqual(result, {"data": {"events": [{"id": "test-id"}]}}) + + # Verify the underlying post call + mock_post.assert_called_once() + + @patch("httpx.AsyncClient.post") + async def test_search_events(self, mock_post): + """Test async search_events method.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "data": {"events": [{"id": "test-id", "name": "test-event", "version": "1.0.0"}]} + } + mock_post.return_value = mock_response + + async with Client(self.test_url) as client: + result = await client.search_events(params={"name": "test-event"}, fields=["id", "name", "version"]) + + self.assertEqual(result["data"]["events"][0]["name"], "test-event") + mock_post.assert_called_once() + + @patch("httpx.AsyncClient.post") + async def test_create_event(self, mock_post): + """Test async create_event method.""" + mock_response = MagicMock() + mock_response.json.return_value = {"data": {"create_event": "new-event-id"}} + mock_post.return_value = mock_response + + async with Client(self.test_url) as client: + event_data = {"name": "new-event", "version": "1.0.0", "description": "Test event"} + + result = await client.create_event(params=event_data) + + self.assertEqual(result["data"]["create_event"], "new-event-id") + mock_post.assert_called_once() + + @patch("httpx.AsyncClient.post") + async def test_batch_operations(self, mock_post): + """Test concurrent batch operations.""" + # Setup different responses for different calls + responses = [MagicMock(), MagicMock(), MagicMock()] + + responses[0].json.return_value = {"data": {"events": [{"id": "event-1"}]}} + responses[1].json.return_value = {"data": {"event_receivers": [{"id": "receiver-1"}]}} + responses[2].json.return_value = {"data": {"event_receiver_groups": [{"id": "group-1"}]}} + + mock_post.side_effect = responses + + async with Client(self.test_url) as client: + # Run multiple operations concurrently + results = await asyncio.gather( + client.search_events(fields=["id"]), + client.search_event_receivers(fields=["id"]), + client.search_event_receiver_groups(fields=["id"]), + ) + + # Verify all operations completed + self.assertEqual(len(results), 3) + self.assertEqual(results[0]["data"]["events"][0]["id"], "event-1") + self.assertEqual(results[1]["data"]["event_receivers"][0]["id"], "receiver-1") + self.assertEqual(results[2]["data"]["event_receiver_groups"][0]["id"], "group-1") + + # Verify all HTTP calls were made + self.assertEqual(mock_post.call_count, 3) + + @patch("httpx.AsyncClient.post") + async def test_http_error_handling(self, mock_post): + """Test HTTP error handling.""" + # Setup mock to raise an HTTP error + mock_post.side_effect = httpx.HTTPStatusError("404 Not Found", request=MagicMock(), response=MagicMock()) + + async with Client(self.test_url) as client: + with self.assertRaises(httpx.HTTPStatusError): + await client.search_events() + + async def test_client_without_context_manager(self): + """Test that client works without context manager by auto-creating HTTP client.""" + client = Client(self.test_url) + + with patch("httpx.AsyncClient.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = {"data": {"events": []}} + mock_post.return_value = mock_response + + # This should auto-create the HTTP client + await client.search_events() + + # Verify HTTP client was created + self.assertIsNotNone(client._client) + mock_post.assert_called_once() + + # Clean up + await client.close() + + def test_sync_utility_methods(self): + """Test synchronous utility methods that don't require async.""" + client = Client(self.test_url) + + # Test operation mapping + self.assertIn("events", client._operation_map["search"]) + self.assertIn("create_event", client._operation_map["mutation"]) + + # Test that all expected operations are mapped + expected_search_ops = ["events", "event_receivers", "event_receiver_groups"] + expected_mutation_ops = ["create_event", "create_event_receiver", "create_event_receiver_group"] + + for op in expected_search_ops: + self.assertIn(op, client._operation_map["search"]) + + for op in expected_mutation_ops: + self.assertIn(op, client._operation_map["mutation"]) + + +# Run async tests using asyncio +def run_async_test(coro): + """Helper to run async test methods.""" + + def wrapper(self): + return asyncio.run(coro(self)) + + return wrapper + + +# Apply the async wrapper to all async test methods +for attr_name in dir(AsyncClientTestCase): + attr = getattr(AsyncClientTestCase, attr_name) + if callable(attr) and asyncio.iscoroutinefunction(attr) and attr_name.startswith("test_"): + setattr(AsyncClientTestCase, attr_name, run_async_test(attr)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_async_main.py b/tests/unit/test_async_main.py new file mode 100644 index 0000000..efae905 --- /dev/null +++ b/tests/unit/test_async_main.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: © 2024 Brett Smith +# SPDX-License-Identifier: Apache-2.0 + +import sys +import unittest +from io import StringIO +from unittest.mock import AsyncMock, patch + +from epr import main +from tests import base + + +class AsyncMainTestCase(base.BaseTestCase): + """Test cases for main module with async functionality.""" + + def setUp(self): + super().setUp() + # Capture stdout for testing CLI output + self.old_stdout = sys.stdout + self.captured_output = StringIO() + sys.stdout = self.captured_output + + # Store original sys.argv + self.original_argv = sys.argv + + def tearDown(self): + super().tearDown() + # Restore stdout + sys.stdout = self.old_stdout + # Restore sys.argv + sys.argv = self.original_argv + + @patch("epr.create.create") + def test_cli_create_calls_async_function(self, mock_create): + """Test that CLI create command properly calls async create function.""" + # Setup mock to return a result + mock_create.return_value = {"events": ["event-id-1"], "event_receivers": [], "event_receiver_groups": []} + + # Simulate CLI arguments for create command + test_args = [ + "eprcli", + "create", + "event", + "--name", + "test-event", + "--version", + "1.0.0", + "--release", + "stable", + "--platform-id", + "linux-x64", + "--package", + "test-package", + "--description", + "Test event", + "--success", + ] + + with patch.object(sys, "argv", test_args): + try: + # This should work without raising an exception + cmd = main.CmdLine() + self.assertIsNotNone(cmd) + # The create method should have been called with asyncio.run() + mock_create.assert_called_once() + except SystemExit: + # CLI may exit after completion, that's expected + pass + + @patch("epr.search.search") + def test_cli_search_calls_async_function(self, mock_search): + """Test that CLI search command properly calls async search function.""" + # Setup mock to return a result + mock_search.return_value = { + "events": [{"id": "event-1", "name": "test-event"}], + "event_receivers": [], + "event_receiver_groups": [], + } + + # Simulate CLI arguments for search command + test_args = ["eprcli", "search", "event", "--name", "test-event"] + + with patch.object(sys, "argv", test_args): + try: + # This should work without raising an exception + cmd = main.CmdLine() + self.assertIsNotNone(cmd) + # The search method should have been called with asyncio.run() + mock_search.assert_called_once() + except SystemExit: + # CLI may exit after completion, that's expected + pass + + def test_cli_help_still_works(self): + """Test that CLI help functionality still works after async changes.""" + test_args = ["eprcli", "--help"] + + with patch.object(sys, "argv", test_args): + with self.assertRaises(SystemExit) as context: + main.CmdLine() + + # Help should exit with code 0 + self.assertEqual(context.exception.code, 0) + + @patch("epr.create.create") + def test_create_command_error_handling(self, mock_create): + """Test error handling in create command with async function.""" + # Setup mock to raise an exception + mock_create.side_effect = Exception("Async operation failed") + + test_args = [ + "eprcli", + "create", + "event", + "--name", + "test-event", + "--version", + "1.0.0", + "--release", + "stable", + "--platform-id", + "linux-x64", + "--package", + "test-package", + "--description", + "Test event", + ] + + with patch.object(sys, "argv", test_args): + # The exception should propagate (or be handled gracefully by CLI) + with self.assertRaises((Exception, SystemExit)): + main.CmdLine() + + @patch("epr.search.search") + def test_search_command_error_handling(self, mock_search): + """Test error handling in search command with async function.""" + # Setup mock to raise an exception + mock_search.side_effect = Exception("Async search failed") + + test_args = ["eprcli", "search", "event", "--name", "test-event"] + + with patch.object(sys, "argv", test_args): + # The exception should propagate (or be handled gracefully by CLI) + with self.assertRaises((Exception, SystemExit)): + main.CmdLine() + + def test_field_handling_utility(self): + """Test the _handle_fields utility method works correctly.""" + cmd = main.CmdLine.__new__(main.CmdLine) # Create instance without __init__ + + # Test comma-separated values + result = cmd._handle_fields("field1,field2,field3") + self.assertEqual(result, ["field1", "field2", "field3"]) + + # Test space-separated values + result = cmd._handle_fields("field1 field2 field3") + self.assertEqual(result, ["field1", "field2", "field3"]) + + # Test single ULID-like value (26 characters) + ulid = "01ARZ3NDEKTSV4RRFFQ69G5FAV" + result = cmd._handle_fields(ulid) + self.assertEqual(result, [ulid]) + + # Test None input + result = cmd._handle_fields(None) + self.assertEqual(result, []) + + # Test invalid input should raise ValueError + with self.assertRaises(ValueError): + cmd._handle_fields("invalid-input") + + def test_import_compatibility(self): + """Test that all imports still work after async refactoring.""" + # These imports should work without errors + from epr.client import Client + from epr.config import Config + from epr.models import Event + + # Basic smoke test - creating instances should work + client = Client("http://localhost:8042") + self.assertIsNotNone(client) + + config = Config(url="http://localhost:8042", token="test-token") + self.assertIsNotNone(config) + + event = Event() + self.assertIsNotNone(event) + + @patch("epr.create.create", new_callable=AsyncMock) + @patch("asyncio.run") + def test_asyncio_run_usage(self, mock_asyncio_run, mock_create): + """Test that asyncio.run is properly used for async function calls.""" + mock_asyncio_run.return_value = {"events": [], "event_receivers": [], "event_receiver_groups": []} + mock_create.return_value = {"events": [], "event_receivers": [], "event_receiver_groups": []} + + test_args = [ + "eprcli", + "create", + "event", + "--name", + "test-event", + "--version", + "1.0.0", + "--release", + "stable", + "--platform-id", + "linux-x64", + "--package", + "test-package", + "--description", + "Test event", + "--payload", + '{"key": "value"}', + "--event-receiver-id", + "test-receiver-id", + ] + + with patch.object(sys, "argv", test_args): + try: + main.CmdLine() + except SystemExit: + pass + + # asyncio.run should have been called + mock_asyncio_run.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_async_search_create.py b/tests/unit/test_async_search_create.py new file mode 100644 index 0000000..14059ea --- /dev/null +++ b/tests/unit/test_async_search_create.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: © 2024 Brett Smith +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import unittest +from unittest.mock import AsyncMock, patch + +from epr import create, search +from epr.config import Config +from epr.models import Event, EventReceiver, EventReceiverGroup +from tests import base + + +class AsyncSearchCreateTestCase(base.BaseTestCase): + """Test cases for the async search and create functions.""" + + def setUp(self): + super().setUp() + self.test_url = "http://test-server:8042" + + # Create test config + self.config = Config(url=self.test_url, token="test-token") + + # Create test data + self.test_event = Event( + name="test-event", + version="1.0.0", + release="stable", + platform_id="linux-x64", + package="test-package", + description="Test event", + ) + + self.test_receiver = EventReceiver( + name="test-receiver", + type="webhook", + version="1.0.0", + description="Test receiver", + schema={"type": "object"}, + ) + + self.test_group = EventReceiverGroup( + name="test-group", + type="group", + version="1.0.0", + description="Test group", + event_receiver_ids=["receiver-1", "receiver-2"], + enabled=True, + ) + + @patch("epr.search.Client") + async def test_search_function(self, mock_client_class): + """Test the async search function.""" + # Setup mock client and responses + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Mock search responses + mock_client.search_events.return_value = {"data": {"events": [{"id": "event-1", "name": "test-event"}]}} + mock_client.search_event_receivers.return_value = { + "data": {"event_receivers": [{"id": "receiver-1", "name": "test-receiver"}]} + } + mock_client.search_event_receiver_groups.return_value = { + "data": {"event_receiver_groups": [{"id": "group-1", "name": "test-group"}]} + } + + # Add test data to config + self.config.events = [self.test_event] + self.config.event_receivers = [self.test_receiver] + self.config.event_receiver_groups = [self.test_group] + + # Run search function + result = await search.search(self.config) + + # Verify results + self.assertIn("events", result) + self.assertIn("event_receivers", result) + self.assertIn("event_receiver_groups", result) + + # Verify client was used correctly + mock_client.search_events.assert_called_once() + mock_client.search_event_receivers.assert_called_once() + mock_client.search_event_receiver_groups.assert_called_once() + + # Verify client context manager was used + mock_client_class.return_value.__aenter__.assert_called_once() + mock_client_class.return_value.__aexit__.assert_called_once() + + @patch("epr.create.Client") + async def test_create_function(self, mock_client_class): + """Test the async create function.""" + # Setup mock client and responses + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Mock create responses + mock_client.create_event.return_value = {"data": {"create_event": "new-event-id"}} + mock_client.create_event_receiver.return_value = {"data": {"create_event_receiver": "new-receiver-id"}} + mock_client.create_event_receiver_group.return_value = {"data": {"create_event_receiver_group": "new-group-id"}} + + # Add test data to config + self.config.events = [self.test_event] + self.config.event_receivers = [self.test_receiver] + self.config.event_receiver_groups = [self.test_group] + + # Run create function + result = await create.create(self.config) + + # Verify results + self.assertIn("events", result) + self.assertIn("event_receivers", result) + self.assertIn("event_receiver_groups", result) + + self.assertEqual(result["events"], ["new-event-id"]) + self.assertEqual(result["event_receivers"], ["new-receiver-id"]) + self.assertEqual(result["event_receiver_groups"], ["new-group-id"]) + + # Verify client methods were called + mock_client.create_event.assert_called_once() + mock_client.create_event_receiver.assert_called_once() + mock_client.create_event_receiver_group.assert_called_once() + + @patch("epr.search.Client") + async def test_search_with_custom_fields(self, mock_client_class): + """Test search function with custom field specifications.""" + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_client.search_events.return_value = {"data": {"events": [{"id": "event-1"}]}} + + # Set custom fields in config + self.config.events = [self.test_event] + self.config.event_fields = ["id", "name", "custom_field"] + + await search.search(self.config) + + # Verify custom fields were passed + call_args = mock_client.search_events.call_args + self.assertIn("fields", call_args.kwargs) + self.assertEqual(call_args.kwargs["fields"], ["id", "name", "custom_field"]) + + @patch("epr.search.Client") + async def test_search_empty_config(self, mock_client_class): + """Test search function with empty config.""" + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Empty config - no events, receivers, or groups + empty_config = Config(url=self.test_url, token="test-token") + + result = await search.search(empty_config) + + # Should return empty results + self.assertEqual(result["events"], []) + self.assertEqual(result["event_receivers"], []) + self.assertEqual(result["event_receiver_groups"], []) + + # Client methods should not be called + mock_client.search_events.assert_not_called() + mock_client.search_event_receivers.assert_not_called() + mock_client.search_event_receiver_groups.assert_not_called() + + @patch("epr.create.Client") + async def test_create_empty_config(self, mock_client_class): + """Test create function with empty config.""" + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Empty config + empty_config = Config(url=self.test_url, token="test-token") + + result = await create.create(empty_config) + + # Should return empty results + self.assertEqual(result["events"], []) + self.assertEqual(result["event_receivers"], []) + self.assertEqual(result["event_receiver_groups"], []) + + # Client methods should not be called + mock_client.create_event.assert_not_called() + mock_client.create_event_receiver.assert_not_called() + mock_client.create_event_receiver_group.assert_not_called() + + @patch("epr.search.Client") + async def test_search_error_handling(self, mock_client_class): + """Test error handling in search function.""" + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Setup client to raise an exception + mock_client.search_events.side_effect = Exception("Network error") + + self.config.events = [self.test_event] + + # The search function should propagate the exception + with self.assertRaises(Exception) as context: + await search.search(self.config) + + self.assertIn("Network error", str(context.exception)) + + @patch("epr.create.Client") + async def test_create_error_handling(self, mock_client_class): + """Test error handling in create function.""" + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Setup client to raise an exception + mock_client.create_event.side_effect = Exception("Server error") + + self.config.events = [self.test_event] + + # The create function should propagate the exception + with self.assertRaises(Exception) as context: + await create.create(self.config) + + self.assertIn("Server error", str(context.exception)) + + @patch("epr.create.Client") + async def test_multiple_items_handling(self, mock_client_class): + """Test handling of multiple items in config.""" + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + # Setup responses for multiple items + mock_client.create_event.side_effect = [ + {"data": {"create_event": "event-id-1"}}, + {"data": {"create_event": "event-id-2"}}, + ] + + # Add multiple events to config + event2 = Event(name="event-2", version="2.0.0") + self.config.events = [self.test_event, event2] + + result = await create.create(self.config) + + # Should have created both events + self.assertEqual(len(result["events"]), 2) + self.assertEqual(result["events"], ["event-id-1", "event-id-2"]) + + # Client method should be called twice + self.assertEqual(mock_client.create_event.call_count, 2) + + def test_data_serialization(self): + """Test that model data is properly serialized for API calls.""" + # Test as_dict method + event_dict = self.test_event.as_dict() + self.assertIsInstance(event_dict, dict) + self.assertEqual(event_dict["name"], "test-event") + self.assertEqual(event_dict["version"], "1.0.0") + + # Test as_dict_query method (filters out empty values) + event_query = self.test_event.as_dict_query() + self.assertIsInstance(event_query, dict) + # Should only include non-empty values + self.assertNotIn("id", event_query) # id is empty by default + self.assertIn("name", event_query) + + +# Run async tests using asyncio +def run_async_test(coro): + """Helper to run async test methods.""" + + def wrapper(self): + return asyncio.run(coro(self)) + + return wrapper + + +# Apply the async wrapper to all async test methods +for attr_name in dir(AsyncSearchCreateTestCase): + attr = getattr(AsyncSearchCreateTestCase, attr_name) + if callable(attr) and asyncio.iscoroutinefunction(attr) and attr_name.startswith("test_"): + setattr(AsyncSearchCreateTestCase, attr_name, run_async_test(attr)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_errors.py b/tests/unit/test_errors.py index 67a5d8a..613e7a3 100644 --- a/tests/unit/test_errors.py +++ b/tests/unit/test_errors.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: © 2024 Brett Smith # SPDX-License-Identifier: Apache-2.0 -import mock +from unittest import mock from epr import errors from tests import base @@ -54,7 +54,7 @@ def fake_exception(): frame = FakeFrame(code, {}) tb = FakeTraceback([frame], [1]) exc = FakeException("foo").with_traceback(tb) - return FakeException, exc, tb + return FakeException, exc, exc._fake_traceback class FakeCode(object): @@ -86,17 +86,9 @@ def tb_next(self): class FakeException(Exception): def __init__(self, *args, **kwargs): - self._tb = None super(FakeException, self).__init__(*args, **kwargs) - - @property - def __traceback__(self): - return self._tb - - @__traceback__.setter - def __traceback__(self, value): - self._tb = value + self._fake_traceback = None def with_traceback(self, value): - self._tb = value + self._fake_traceback = value return self diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py new file mode 100644 index 0000000..ae38be8 --- /dev/null +++ b/tests/unit/test_models.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: © 2024 Brett Smith +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +from epr.models import Event, EventReceiver, EventReceiverGroup, GraphQLQuery +from tests import base + + +class ModelsTestCase(base.BaseTestCase): + """Test cases for updated model classes.""" + + def setUp(self): + super().setUp() + + def test_graphql_query_model(self): + """Test the updated GraphQLQuery model with dict variables.""" + # Test basic creation + query = GraphQLQuery(query="query test { events { id } }", variables={"name": "test", "version": "1.0"}) + + self.assertEqual(query.query, "query test { events { id } }") + self.assertIsInstance(query.variables, dict) + self.assertEqual(query.variables["name"], "test") + self.assertEqual(query.variables["version"], "1.0") + + def test_graphql_query_as_dict(self): + """Test GraphQLQuery as_dict method works correctly.""" + query = GraphQLQuery( + query="query test { events { id } }", variables={"name": "test", "nested": {"key": "value"}} + ) + + query_dict = query.as_dict() + + self.assertIsInstance(query_dict, dict) + self.assertIn("query", query_dict) + self.assertIn("variables", query_dict) + self.assertEqual(query_dict["query"], "query test { events { id } }") + self.assertEqual(query_dict["variables"]["name"], "test") + self.assertEqual(query_dict["variables"]["nested"]["key"], "value") + + def test_graphql_query_default_variables(self): + """Test GraphQLQuery with default empty variables.""" + query = GraphQLQuery(query="query test { events { id } }") + + self.assertEqual(query.query, "query test { events { id } }") + self.assertIsInstance(query.variables, dict) + self.assertEqual(len(query.variables), 0) + + def test_event_model_unchanged(self): + """Test that Event model still works as expected.""" + event = Event(name="test-event", version="1.0.0", description="Test event") + + self.assertEqual(event.name, "test-event") + self.assertEqual(event.version, "1.0.0") + self.assertEqual(event.description, "Test event") + + # Test as_dict + event_dict = event.as_dict() + self.assertIn("name", event_dict) + self.assertEqual(event_dict["name"], "test-event") + + # Test as_dict_query (filters empty values) + event_query = event.as_dict_query() + self.assertIn("name", event_query) + self.assertNotIn("id", event_query) # Should be filtered out if empty + + def test_event_receiver_model_unchanged(self): + """Test that EventReceiver model still works as expected.""" + receiver = EventReceiver(name="test-receiver", type="webhook", version="1.0.0", description="Test receiver") + + self.assertEqual(receiver.name, "test-receiver") + self.assertEqual(receiver.type, "webhook") + + # Test serialization + receiver_dict = receiver.as_dict() + self.assertIn("name", receiver_dict) + self.assertIn("type", receiver_dict) + + def test_event_receiver_group_model_unchanged(self): + """Test that EventReceiverGroup model still works as expected.""" + group = EventReceiverGroup( + name="test-group", + type="group", + version="1.0.0", + event_receiver_ids=["receiver-1", "receiver-2"], + enabled=True, + ) + + self.assertEqual(group.name, "test-group") + self.assertEqual(group.enabled, True) + self.assertEqual(len(group.event_receiver_ids), 2) + + # Test serialization + group_dict = group.as_dict() + self.assertIn("name", group_dict) + self.assertIn("enabled", group_dict) + self.assertIn("event_receiver_ids", group_dict) + + def test_model_inheritance(self): + """Test that all models properly inherit from Model base class.""" + event = Event() + receiver = EventReceiver() + group = EventReceiverGroup() + query = GraphQLQuery(query="test") + + # All should have as_dict method + self.assertTrue(hasattr(event, "as_dict")) + self.assertTrue(hasattr(receiver, "as_dict")) + self.assertTrue(hasattr(group, "as_dict")) + self.assertTrue(hasattr(query, "as_dict")) + + # All should have as_dict_query method + self.assertTrue(hasattr(event, "as_dict_query")) + self.assertTrue(hasattr(receiver, "as_dict_query")) + self.assertTrue(hasattr(group, "as_dict_query")) + self.assertTrue(hasattr(query, "as_dict_query")) + + def test_complex_variables_serialization(self): + """Test that complex variables in GraphQLQuery serialize correctly.""" + complex_vars = { + "string": "test", + "number": 42, + "boolean": True, + "list": ["a", "b", "c"], + "nested_dict": {"inner": "value", "inner_list": [1, 2, 3]}, + "null_value": None, + } + + query = GraphQLQuery( + query="mutation complex($input: ComplexInput!) { create(input: $input) }", variables=complex_vars + ) + + # Test that all variable types are preserved + self.assertEqual(query.variables["string"], "test") + self.assertEqual(query.variables["number"], 42) + self.assertEqual(query.variables["boolean"], True) + self.assertEqual(query.variables["list"], ["a", "b", "c"]) + self.assertEqual(query.variables["nested_dict"]["inner"], "value") + self.assertIsNone(query.variables["null_value"]) + + # Test serialization + query_dict = query.as_dict() + self.assertEqual(query_dict["variables"]["nested_dict"]["inner_list"], [1, 2, 3]) + + +if __name__ == "__main__": + unittest.main()