diff --git a/pyproject.toml b/pyproject.toml
index 28ffad42..905c1d6a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,12 @@ dependencies = [
]
requires-python = ">= 3.10"
+[project.optional-dependencies]
+async = [
+ "pytest-asyncio",
+ "aiofiles"
+]
+
[project.urls]
"Source" = "https://github.com/dbfixtures/pytest-postgresql"
"Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues"
diff --git a/pytest_postgresql/factories/__init__.py b/pytest_postgresql/factories/__init__.py
index d6bd2f64..002304cb 100644
--- a/pytest_postgresql/factories/__init__.py
+++ b/pytest_postgresql/factories/__init__.py
@@ -17,8 +17,8 @@
# along with pytest-postgresql. If not, see .
"""Fixture factories for postgresql fixtures."""
-from pytest_postgresql.factories.client import postgresql
+from pytest_postgresql.factories.client import postgresql, postgresql_async
from pytest_postgresql.factories.noprocess import postgresql_noproc
from pytest_postgresql.factories.process import PortType, postgresql_proc
-__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "PortType")
+__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "postgresql_async", "PortType")
diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py
index 76e7afce..3e015bc9 100644
--- a/pytest_postgresql/factories/client.py
+++ b/pytest_postgresql/factories/client.py
@@ -17,23 +17,25 @@
# along with pytest-postgresql. If not, see .
"""Fixture factory for postgresql client."""
-from typing import Callable, Iterator
+from typing import AsyncIterator, Callable, Iterator
import psycopg
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
from pytest import FixtureRequest
from pytest_postgresql.config import get_config
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.executor_noop import NoopExecutor
-from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor
+from pytest_postgresql.types import FixtureScopeT
def postgresql(
process_fixture_name: str,
dbname: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
+ scope: FixtureScopeT = "function",
) -> Callable[[FixtureRequest], Iterator[Connection]]:
"""Return connection fixture factory for PostgreSQL.
@@ -41,12 +43,13 @@ def postgresql(
:param dbname: database name
:param isolation_level: optional postgresql isolation level
defaults to server's default
+ :param scope: fixture scope; by default "function" which is recommended.
:returns: function which makes a connection to postgresql
"""
- @pytest.fixture
+ @pytest.fixture(scope=scope)
def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
- """Fixture factory for PostgreSQL.
+ """Fixture connection factory for PostgreSQL.
:param request: fixture request object
:returns: postgresql client
@@ -85,3 +88,63 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
db_connection.close()
return postgresql_factory
+
+
+def postgresql_async(
+ process_fixture_name: str,
+ dbname: str | None = None,
+ isolation_level: "psycopg.IsolationLevel | None" = None,
+ scope: FixtureScopeT = "function",
+) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]:
+ """Return async connection fixture factory for PostgreSQL.
+
+ :param process_fixture_name: name of the process fixture
+ :param dbname: database name
+ :param isolation_level: optional postgresql isolation level
+ defaults to server's default
+ :param scope: fixture scope; by default "function" which is recommended.
+ :returns: function which makes a connection to postgresql
+ """
+ import pytest_asyncio
+
+ @pytest_asyncio.fixture(scope=scope)
+ async def postgresql_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]:
+ """Async connection fixture factory for PostgreSQL.
+
+ :param request: fixture request object
+ :returns: postgresql client
+ """
+ proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name)
+ config = get_config(request)
+
+ pg_host = proc_fixture.host
+ pg_port = proc_fixture.port
+ pg_user = proc_fixture.user
+ pg_password = proc_fixture.password
+ pg_options = proc_fixture.options
+ pg_db = dbname or proc_fixture.dbname
+ janitor = AsyncDatabaseJanitor(
+ user=pg_user,
+ host=pg_host,
+ port=pg_port,
+ dbname=pg_db,
+ template_dbname=proc_fixture.template_dbname,
+ version=proc_fixture.version,
+ password=pg_password,
+ isolation_level=isolation_level,
+ )
+ if config["drop_test_database"]:
+ await janitor.drop()
+ async with janitor:
+ db_connection: AsyncConnection = await AsyncConnection.connect(
+ dbname=pg_db,
+ user=pg_user,
+ password=pg_password,
+ host=pg_host,
+ port=pg_port,
+ options=pg_options,
+ )
+ yield db_connection
+ await db_connection.close()
+
+ return postgresql_factory
diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py
index fd3b96ba..3e5e9fc4 100644
--- a/pytest_postgresql/factories/noprocess.py
+++ b/pytest_postgresql/factories/noprocess.py
@@ -27,6 +27,7 @@
from pytest_postgresql.config import get_config
from pytest_postgresql.executor_noop import NoopExecutor
from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.types import FixtureScopeT
def xdistify_dbname(dbname: str) -> str:
@@ -45,6 +46,7 @@ def postgresql_noproc(
dbname: str | None = None,
options: str = "",
load: list[Callable | str | Path] | None = None,
+ scope: FixtureScopeT = "session",
) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]:
"""Postgresql noprocess factory.
@@ -55,10 +57,11 @@ def postgresql_noproc(
:param dbname: postgresql database name
:param options: Postgresql connection options
:param load: List of functions used to initialize database's template.
+ :param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""
- @pytest.fixture(scope="session")
+ @pytest.fixture(scope=scope)
def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]:
"""Noop Process fixture for PostgreSQL.
diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py
index 0a663fc0..25b494ee 100644
--- a/pytest_postgresql/factories/process.py
+++ b/pytest_postgresql/factories/process.py
@@ -32,6 +32,7 @@
from pytest_postgresql.exceptions import ExecutableMissingException
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.types import FixtureScopeT
PortType = port_for.PortType # mypy requires explicit export
@@ -81,6 +82,7 @@ def postgresql_proc(
unixsocketdir: str | None = None,
postgres_options: str | None = None,
load: list[Callable | str | Path] | None = None,
+ scope: FixtureScopeT = "session",
) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]:
"""Postgresql process factory.
@@ -101,10 +103,11 @@ def postgresql_proc(
:param unixsocketdir: directory to create postgresql's unixsockets
:param postgres_options: Postgres executable options for use by pg_ctl
:param load: List of functions used to initialize database's template.
+ :param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""
- @pytest.fixture(scope="session")
+ @pytest.fixture(scope=scope)
def postgresql_proc_fixture(
request: FixtureRequest, tmp_path_factory: TempPathFactory
) -> Iterator[PostgreSQLExecutor]:
diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py
index 247ade22..6dc3d2f6 100644
--- a/pytest_postgresql/janitor.py
+++ b/pytest_postgresql/janitor.py
@@ -1,6 +1,7 @@
"""Database Janitor."""
-from contextlib import contextmanager
+import inspect
+from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
from types import TracebackType
from typing import Callable, Iterator, Type, TypeVar
@@ -9,8 +10,8 @@
from packaging.version import parse
from psycopg import Connection, Cursor
-from pytest_postgresql.loader import build_loader
-from pytest_postgresql.retry import retry
+from pytest_postgresql.loader import build_loader, build_loader_async
+from pytest_postgresql.retry import retry, retry_async
Version = type(parse("1"))
@@ -40,7 +41,7 @@ def __init__(
:param host: postgresql host
:param port: postgresql port
:param dbname: database name
- :param dbname: template database name
+ :param template_dbname: template database name
:param version: postgresql version number
:param password: optional postgresql password
:param isolation_level: optional postgresql isolation level
@@ -52,8 +53,8 @@ def __init__(
self.password = password
self.host = host
self.port = port
- # At least one of the dbname or template_dbname has to be filled.
- assert any([dbname, template_dbname])
+ if not (dbname or template_dbname):
+ raise ValueError("At least one of the dbname or template_dbname has to be filled.")
self.dbname = dbname
self.template_dbname = template_dbname
self._connection_timeout = connection_timeout
@@ -163,3 +164,145 @@ def __exit__(
) -> None:
"""Exit from Database janitor context cleaning after itself."""
self.drop()
+
+
+class AsyncDatabaseJanitor:
+ """Manage database state for specific tasks."""
+
+ def __init__(
+ self,
+ *,
+ user: str,
+ host: str,
+ port: str | int,
+ version: str | float | Version, # type: ignore[valid-type]
+ dbname: str | None = None,
+ template_dbname: str | None = None,
+ password: str | None = None,
+ isolation_level: "psycopg.IsolationLevel | None" = None,
+ connection_timeout: int = 60,
+ ) -> None:
+ """Initialize janitor.
+
+ :param user: postgresql username
+ :param host: postgresql host
+ :param port: postgresql port
+ :param dbname: database name
+ :param template_dbname: template database name
+ :param version: postgresql version number
+ :param password: optional postgresql password
+ :param isolation_level: optional postgresql isolation level
+ defaults to server's default
+ :param connection_timeout: how long to retry connection before
+ raising a TimeoutError
+ """
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+ if not (dbname or template_dbname):
+ raise ValueError("At least one of the dbname or template_dbname has to be filled.")
+ self.dbname = dbname
+ self.template_dbname = template_dbname
+ self._connection_timeout = connection_timeout
+ self.isolation_level = isolation_level
+ if not isinstance(version, Version):
+ self.version = parse(str(version))
+ else:
+ self.version = version
+
+ async def init(self) -> None:
+ """Create database in postgresql."""
+ async with self.cursor() as cur:
+ if self.is_template():
+ await cur.execute(f'CREATE DATABASE "{self.template_dbname}" WITH is_template = true;')
+ elif self.template_dbname is None:
+ await cur.execute(f'CREATE DATABASE "{self.dbname}";')
+ else:
+ # And make sure no-one is left connected to the template database.
+ # Otherwise, Creating database from template will fail
+ await self._terminate_connection(cur, self.template_dbname)
+ await cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";')
+
+ def is_template(self) -> bool:
+ """Determine whether the DatabaseJanitor maintains template or database."""
+ return self.dbname is None
+
+ async def drop(self) -> None:
+ """Drop database in postgresql (async)."""
+ db_to_drop = self.template_dbname if self.is_template() else self.dbname
+ assert db_to_drop
+ async with self.cursor() as cur:
+ await self._dont_datallowconn(cur, db_to_drop)
+ await self._terminate_connection(cur, db_to_drop)
+ if self.is_template():
+ await cur.execute(f'ALTER DATABASE "{db_to_drop}" with is_template false;')
+ await cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";')
+
+ @staticmethod
+ async def _dont_datallowconn(cur, dbname: str) -> None:
+ await cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;')
+
+ @staticmethod
+ async def _terminate_connection(cur, dbname: str) -> None:
+ await cur.execute(
+ "SELECT pg_terminate_backend(pg_stat_activity.pid)"
+ "FROM pg_stat_activity "
+ "WHERE pg_stat_activity.datname = %s;",
+ (dbname,),
+ )
+
+ async def load(self, load: Callable | str | Path) -> None:
+ """Load data into a database (async).
+
+ Expects:
+
+ * a Path to sql file, that'll be loaded
+ * an import path to import callable
+ * a callable that expects: host, port, user, dbname and password arguments.
+
+ """
+ db_to_load = self.template_dbname if self.is_template() else self.dbname
+ _loader = build_loader_async(load)
+ cor = _loader(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ dbname=db_to_load,
+ password=self.password,
+ )
+ if inspect.isawaitable(cor):
+ await cor
+
+ @asynccontextmanager
+ async def cursor(self, dbname: str = "postgres"):
+ """Async context manager for postgresql cursor."""
+
+ async def connect() -> psycopg.AsyncConnection:
+ """Return postgresql async cursor."""
+ return await psycopg.AsyncConnection.connect(
+ dbname=dbname,
+ user=self.user,
+ password=self.password,
+ host=self.host,
+ port=self.port,
+ )
+
+ conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError)
+ try:
+ await conn.set_isolation_level(self.isolation_level)
+ await conn.set_autocommit(True)
+ # We must not run a transaction since we create a database.
+ async with conn.cursor() as cur:
+ yield cur
+ finally:
+ await conn.close()
+
+ async def __aenter__(self):
+ """Initialize Database Janitor."""
+ await self.init()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """Exit from Database janitor context cleaning after itself."""
+ await self.drop()
diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py
index c9b28cbd..36f37055 100644
--- a/pytest_postgresql/loader.py
+++ b/pytest_postgresql/loader.py
@@ -1,5 +1,6 @@
"""Loader helper functions."""
+import importlib
import re
from functools import partial
from pathlib import Path
@@ -16,7 +17,7 @@ def build_loader(load: Callable | str | Path) -> Callable:
loader_parts = re.split("[.:]", load, maxsplit=2)
import_path = ".".join(loader_parts[:-1])
loader_name = loader_parts[-1]
- _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
+ _temp_import = importlib.import_module(import_path)
_loader: Callable = getattr(_temp_import, loader_name)
return _loader
else:
@@ -30,3 +31,29 @@ def sql(sql_filename: Path, **kwargs: Any) -> None:
with db_connection.cursor() as cur:
cur.execute(_fd.read())
db_connection.commit()
+
+
+def build_loader_async(load: Callable | str | Path) -> Callable:
+ """Build a loader callable."""
+ if isinstance(load, Path):
+ return partial(sql_async, load)
+ elif isinstance(load, str):
+ loader_parts = re.split("[.:]", load, maxsplit=2)
+ import_path = ".".join(loader_parts[:-1])
+ loader_name = loader_parts[-1]
+ _temp_import = importlib.import_module(import_path)
+ _loader: Callable = getattr(_temp_import, loader_name)
+ return _loader
+ else:
+ return load
+
+
+async def sql_async(sql_filename: Path, **kwargs: Any) -> None:
+ """Async database loader for sql files."""
+ import aiofiles
+
+ async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection:
+ async with db_connection.cursor() as cur:
+ async with aiofiles.open(sql_filename, "r") as _fd:
+ await cur.execute(await _fd.read())
+ await db_connection.commit()
diff --git a/pytest_postgresql/retry.py b/pytest_postgresql/retry.py
index ea25fa2e..98c00daa 100644
--- a/pytest_postgresql/retry.py
+++ b/pytest_postgresql/retry.py
@@ -1,9 +1,10 @@
"""Small retry callable in case of specific error occurred."""
+import asyncio
import datetime
import sys
from time import sleep
-from typing import Callable, Type, TypeVar
+from typing import Awaitable, Callable, Type, TypeVar
T = TypeVar("T")
@@ -36,6 +37,34 @@ def retry(
sleep(1)
+async def retry_async(
+ func: Callable[[], Awaitable[T]],
+ timeout: int = 60,
+ possible_exception: Type[Exception] = Exception,
+) -> T:
+ """Attempt to retry the async function for timeout time.
+
+ Most often used for connecting to postgresql database as,
+ especially on macos on github-actions, first few tries fails
+ with this message:
+
+ ... ::
+ FATAL: the database system is starting up
+ """
+ time: datetime.datetime = get_current_datetime()
+ timeout_diff: datetime.timedelta = datetime.timedelta(seconds=timeout)
+ i = 0
+ while True:
+ i += 1
+ try:
+ res = await func()
+ return res
+ except possible_exception as e:
+ if time + timeout_diff < get_current_datetime():
+ raise TimeoutError(f"Failed after {i} attempts") from e
+ await asyncio.sleep(1)
+
+
def get_current_datetime() -> datetime.datetime:
"""Get the current datetime."""
# To ensure the current datetime retrieval is adjusted with the latest
diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py
new file mode 100644
index 00000000..c5447d0f
--- /dev/null
+++ b/pytest_postgresql/types.py
@@ -0,0 +1,5 @@
+"""Pytest Postgresql Types."""
+
+from typing import Literal
+
+FixtureScopeT = Literal["session", "package", "module", "class", "function"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 784b8905..483437af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,3 +17,5 @@
postgresql_proc2 = factories.postgresql_proc(port=None, load=[TEST_SQL_FILE, TEST_SQL_FILE2])
postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db")
postgresql_load_1 = factories.postgresql("postgresql_proc2")
+postgresql2_async = factories.postgresql_async("postgresql_proc2", dbname="test-db")
+postgresql_load_1_async = factories.postgresql_async("postgresql_proc2")
diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py
index ae25307a..b06dbc58 100644
--- a/tests/docker/test_noproc_docker.py
+++ b/tests/docker/test_noproc_docker.py
@@ -3,7 +3,7 @@
import pathlib
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
import pytest_postgresql.factories.client
import pytest_postgresql.factories.noprocess
@@ -14,12 +14,17 @@
)
postgres_with_schema = pytest_postgresql.factories.client.postgresql("postgresql_my_proc")
+async_postgres_with_schema = pytest_postgresql.factories.client.postgresql_async("postgresql_my_proc")
+
postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc(
dbname="stories_templated", load=[load_database]
)
postgres_with_template = pytest_postgresql.factories.client.postgresql(
"postgresql_my_proc_template", dbname="stories_templated"
)
+async_postgres_with_template = pytest_postgresql.factories.client.postgresql_async(
+ "postgresql_my_proc_template", dbname="stories_templated"
+)
def test_postgres_docker_load(postgres_with_schema: Connection) -> None:
@@ -32,6 +37,14 @@ def test_postgres_docker_load(postgres_with_schema: Connection) -> None:
print(cur.fetchall())
+@pytest.mark.asyncio
+async def test_postgres_docker_load_async(async_postgres_with_schema: AsyncConnection) -> None:
+ """Async check main postgres fixture."""
+ async with async_postgres_with_schema.cursor() as cur:
+ await cur.execute("select * from public.tokens")
+ print(await cur.fetchall())
+
+
@pytest.mark.parametrize("_", range(5))
def test_template_database(postgres_with_template: Connection, _: int) -> None:
"""Check that the database structure gets recreated out of a template."""
@@ -43,3 +56,18 @@ def test_template_database(postgres_with_template: Connection, _: int) -> None:
cur.execute("SELECT * FROM stories")
res = cur.fetchall()
assert len(res) == 0
+
+
+# Async version of test_template_database
+@pytest.mark.asyncio
+@pytest.mark.parametrize("_", range(5))
+async def test_template_database_async(async_postgres_with_template, _: int) -> None:
+ """Async check that the database structure gets recreated out of a template."""
+ async with async_postgres_with_template.cursor() as cur:
+ await cur.execute("SELECT * FROM stories")
+ rows = await cur.fetchall()
+ assert len(rows) == 4
+ await cur.execute("TRUNCATE stories")
+ await cur.execute("SELECT * FROM stories")
+ rows = await cur.fetchall()
+ assert len(rows) == 0
diff --git a/tests/test_janitor.py b/tests/test_janitor.py
index fd1fca2a..02240cb6 100644
--- a/tests/test_janitor.py
+++ b/tests/test_janitor.py
@@ -7,7 +7,7 @@
import pytest
from packaging.version import parse
-from pytest_postgresql.janitor import DatabaseJanitor
+from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor
VERSION = parse("10")
@@ -19,6 +19,14 @@ def test_version_cast(version: Any) -> None:
assert janitor.version == VERSION
+@pytest.mark.parametrize("version", (VERSION, 10, "10"))
+@pytest.mark.asyncio
+async def test_version_cast_async(version: Any) -> None:
+ """Async test that version is cast to Version object."""
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version)
+ assert janitor.version == VERSION
+
+
@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
@@ -27,6 +35,15 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234")
+@patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect")
+@pytest.mark.asyncio
+async def test_cursor_selects_postgres_database_async(connect_mock: MagicMock) -> None:
+ """Async test that the cursor requests the postgres database."""
+ janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10)
+ async with janitor.cursor():
+ connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234")
+
+
@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
@@ -44,6 +61,24 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
)
+@patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect")
+@pytest.mark.asyncio
+async def test_cursor_connects_with_password_async(connect_mock: MagicMock) -> None:
+ """Async test that the cursor requests the postgres database with password."""
+ janitor = AsyncDatabaseJanitor(
+ user="user",
+ host="host",
+ port="1234",
+ dbname="database_name",
+ version=10,
+ password="some_password",
+ )
+ async with janitor.cursor():
+ connect_mock.assert_called_once_with(
+ dbname="postgres", user="user", password="some_password", host="host", port="1234"
+ )
+
+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8")
@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database"))
@patch("pytest_postgresql.janitor.psycopg.connect")
@@ -63,3 +98,22 @@ def test_janitor_populate(connect_mock: MagicMock, load_database: str) -> None:
janitor.load(load_database)
assert connect_mock.called
assert connect_mock.call_args.kwargs == call_kwargs
+
+
+@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8")
+@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database"))
+@patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect")
+@pytest.mark.asyncio
+async def test_janitor_populate_async(connect_mock: MagicMock, load_database: str) -> None:
+ """Async test that the cursor requests the postgres database and populates."""
+ call_kwargs = {
+ "host": "host",
+ "port": "1234",
+ "user": "user",
+ "dbname": "database_name",
+ "password": "some_password",
+ }
+ janitor = AsyncDatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type]
+ await janitor.load(load_database)
+ assert connect_mock.called
+ assert connect_mock.call_args.kwargs == call_kwargs
diff --git a/tests/test_loader.py b/tests/test_loader.py
index c03f8a55..9e72f4c4 100644
--- a/tests/test_loader.py
+++ b/tests/test_loader.py
@@ -2,7 +2,9 @@
from pathlib import Path
-from pytest_postgresql.loader import build_loader, sql
+import pytest
+
+from pytest_postgresql.loader import build_loader, build_loader_async, sql, sql_async
from tests.loader import load_database
@@ -12,9 +14,30 @@ def test_loader_callables() -> None:
assert load_database == build_loader("tests.loader:load_database")
+@pytest.mark.asyncio
+async def test_loader_callables_async() -> None:
+ """Async test handling callables in build_loader_async."""
+ assert load_database == build_loader_async(load_database)
+ assert load_database == build_loader_async("tests.loader:load_database")
+
+ async def afun(*_args, **_kwargs):
+ return 0
+
+ assert afun == build_loader_async(afun)
+
+
def test_loader_sql() -> None:
"""Test returning partial running sql for the sql file path."""
sql_path = Path("test_sql/eidastats.sql")
loader_func = build_loader(sql_path)
assert loader_func.args == (sql_path,) # type: ignore
assert loader_func.func == sql # type: ignore
+
+
+@pytest.mark.asyncio
+async def test_loader_sql_async() -> None:
+ """Async test returning partial running sql_async for the sql file path."""
+ sql_path = Path("test_sql/eidastats.sql")
+ loader_func = build_loader_async(sql_path)
+ assert loader_func.args == (sql_path,) # type: ignore
+ assert loader_func.func == sql_async # type: ignore
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index dec3ba8c..f4396caf 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -3,11 +3,11 @@
import decimal
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
from psycopg.pq import ConnStatus
from pytest_postgresql.executor import PostgreSQLExecutor
-from pytest_postgresql.retry import retry
+from pytest_postgresql.retry import retry, retry_async
from tests.conftest import POSTGRESQL_VERSION
MAKE_Q = "CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);"
@@ -72,3 +72,59 @@ def check_if_one_connection() -> None:
assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}"
retry(check_if_one_connection, timeout=120, possible_exception=AssertionError)
+
+
+@pytest.mark.asyncio
+async def test_main_postgres_async(postgresql_async: AsyncConnection) -> None:
+ """Async check main postgresql fixture."""
+ async with postgresql_async.cursor() as cur:
+ await cur.execute(MAKE_Q)
+ await postgresql_async.commit()
+
+
+@pytest.mark.asyncio
+async def test_two_postgreses_async(postgresql_async: AsyncConnection, postgresql2_async: AsyncConnection) -> None:
+ """Async check two postgresql fixtures on one test (async)."""
+ async with postgresql_async.cursor() as cur:
+ await cur.execute(MAKE_Q)
+ await postgresql_async.commit()
+
+ async with postgresql2_async.cursor() as cur:
+ await cur.execute(MAKE_Q)
+ await postgresql2_async.commit()
+
+
+@pytest.mark.asyncio
+async def test_postgres_load_two_files_async(postgresql_load_1_async: AsyncConnection) -> None:
+ """Async check postgresql fixture can load two files."""
+ async with postgresql_load_1_async.cursor() as cur:
+ await cur.execute(SELECT_Q)
+ results = await cur.fetchall()
+ assert len(results) == 2
+
+
+@pytest.mark.asyncio
+async def test_rand_postgres_port_async(postgresql2_async: AsyncConnection) -> None:
+ """Async check if postgres fixture can be started on random port."""
+ assert postgresql2_async.info.status == ConnStatus.OK
+
+
+@pytest.mark.skipif(
+ decimal.Decimal(POSTGRESQL_VERSION) < 10,
+ reason="Test query not supported in those postgresql versions, and soon will not be supported.",
+)
+@pytest.mark.asyncio
+@pytest.mark.parametrize("_", range(2))
+async def test_postgres_terminate_connection_async(postgresql2_async: AsyncConnection, _: int) -> None:
+ """Async test that connections are terminated between tests.
+
+ And check that only one exists at a time.
+ """
+ async with postgresql2_async.cursor() as cur:
+
+ async def check_if_one_connection() -> None:
+ await cur.execute("SELECT * FROM pg_stat_activity WHERE backend_type = 'client backend';")
+ existing_connections = await cur.fetchall()
+ assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}"
+
+ await retry_async(check_if_one_connection, timeout=120, possible_exception=AssertionError)
diff --git a/tests/test_template_database.py b/tests/test_template_database.py
index 64631779..f2882159 100644
--- a/tests/test_template_database.py
+++ b/tests/test_template_database.py
@@ -1,9 +1,9 @@
"""Template database tests."""
import pytest
-from psycopg import Connection
+from psycopg import AsyncConnection, Connection
-from pytest_postgresql.factories import postgresql, postgresql_proc
+from pytest_postgresql.factories import postgresql, postgresql_async, postgresql_proc
from tests.loader import load_database
postgresql_proc_with_template = postgresql_proc(
@@ -17,6 +17,11 @@
dbname="stories_templated",
)
+async_postgresql_template = postgresql_async(
+ "postgresql_proc_with_template",
+ dbname="stories_templated",
+)
+
@pytest.mark.xdist_group(name="template_database")
@pytest.mark.parametrize("_", range(5))
@@ -30,3 +35,18 @@ def test_template_database(postgresql_template: Connection, _: int) -> None:
cur.execute("SELECT * FROM stories")
res = cur.fetchall()
assert len(res) == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.xdist_group(name="template_database_async")
+@pytest.mark.parametrize("_", range(5))
+async def test_template_database_async(async_postgresql_template: AsyncConnection, _: int) -> None:
+ """Check that the database structure gets recreated out of a template."""
+ async with async_postgresql_template.cursor() as cur:
+ await cur.execute("SELECT * FROM stories")
+ res = await cur.fetchall()
+ assert len(res) == 4
+ await cur.execute("TRUNCATE stories")
+ await cur.execute("SELECT * FROM stories")
+ res = await cur.fetchall()
+ assert len(res) == 0