From 071985853d9913722917fd40b0598fcc93883cb7 Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Fri, 12 Dec 2025 16:43:42 -0800 Subject: [PATCH 01/11] Add async support and unlock scope --- pytest_postgresql/factories/client.py | 77 +++++++++++- pytest_postgresql/factories/noprocess.py | 5 +- pytest_postgresql/factories/process.py | 5 +- pytest_postgresql/janitor.py | 146 ++++++++++++++++++++++- pytest_postgresql/loader.py | 25 +++- pytest_postgresql/retry.py | 31 ++++- 6 files changed, 281 insertions(+), 8 deletions(-) diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index 76e7afce..eb636664 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -23,6 +23,7 @@ import pytest from psycopg import Connection from pytest import FixtureRequest +from _pytest.scope import _ScopeName from pytest_postgresql.config import get_config from pytest_postgresql.executor import PostgreSQLExecutor @@ -34,6 +35,7 @@ def postgresql( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, + scope: _ScopeName="function" ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return connection fixture factory for PostgreSQL. @@ -41,10 +43,11 @@ def postgresql( :param dbname: database name :param isolation_level: optional postgresql isolation level defaults to server's default + :param scope: fixture scope; by default "function" which is recommended. :returns: function which makes a connection to postgresql """ - @pytest.fixture + @pytest.fixture(scope=scope) def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: """Fixture factory for PostgreSQL. @@ -85,3 +88,75 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: db_connection.close() return postgresql_factory + + + +def postgresql_async( + process_fixture_name: str, + dbname: str | None = None, + isolation_level: "psycopg.IsolationLevel | None" = None, + scope: _ScopeName="function" +) -> Callable[[FixtureRequest], Iterator[Connection]]: + """Return async connection fixture factory for PostgreSQL. + + :param process_fixture_name: name of the process fixture + :param dbname: database name + :param isolation_level: optional postgresql isolation level + defaults to server's default + :param scope: fixture scope; by default "function" which is recommended. + :returns: function which makes a connection to postgresql + """ + + import pytest_asyncio + from psycopg import AsyncConnection + + from pytest_postgresql.janitor import AsyncDatabaseJanitor + + @pytest_asyncio.fixture(scope=scope) + async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnection]: + """ + Async fixture factory for PostgreSQL. + + :param request: fixture request object + :returns: postgresql client + """ + proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name) + config = get_config(request) + + pg_host = proc_fixture.host + pg_port = proc_fixture.port + pg_user = proc_fixture.user + pg_password = proc_fixture.password + pg_options = proc_fixture.options + pg_db = dbname or proc_fixture.dbname + janitor = DatabaseJanitor( + user=pg_user, + host=pg_host, + port=pg_port, + dbname=pg_db, + template_dbname=proc_fixture.template_dbname, + version=proc_fixture.version, + password=pg_password, + isolation_level=isolation_level, + ) + if config["drop_test_database"]: + janitor.drop() + with AsyncDatabaseJanitor( + pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level + ) as janitor: + # Line modified here + db_connection: AsyncConnection = await AsyncConnection.connect( + dbname=pg_db, + user=pg_user, + password=pg_password, + host=pg_host, + port=pg_port, + options=pg_options, + ) + for load_element in pg_load: + janitor.load(load_element) + yield db_connection + # And here + await db_connection.close() + + return postgresql_factory \ No newline at end of file diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index fd3b96ba..8038f700 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -23,6 +23,7 @@ import pytest from pytest import FixtureRequest +from _pytest.scope import _ScopeName from pytest_postgresql.config import get_config from pytest_postgresql.executor_noop import NoopExecutor @@ -45,6 +46,7 @@ def postgresql_noproc( dbname: str | None = None, options: str = "", load: list[Callable | str | Path] | None = None, + scope: _ScopeName="session" ) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]: """Postgresql noprocess factory. @@ -55,10 +57,11 @@ def postgresql_noproc( :param dbname: postgresql database name :param options: Postgresql connection options :param load: List of functions used to initialize database's template. + :param scope: fixture scope; by default "session" which is recommended. :returns: function which makes a postgresql process """ - @pytest.fixture(scope="session") + @pytest.fixture(scope=scope) def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]: """Noop Process fixture for PostgreSQL. diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 0a663fc0..1fdbada8 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -27,6 +27,7 @@ import pytest from port_for import PortForException, get_port from pytest import FixtureRequest, TempPathFactory +from _pytest.scope import _ScopeName from pytest_postgresql.config import PostgresqlConfigDict, get_config from pytest_postgresql.exceptions import ExecutableMissingException @@ -81,6 +82,7 @@ def postgresql_proc( unixsocketdir: str | None = None, postgres_options: str | None = None, load: list[Callable | str | Path] | None = None, + scope: _ScopeName="session" ) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]: """Postgresql process factory. @@ -101,10 +103,11 @@ def postgresql_proc( :param unixsocketdir: directory to create postgresql's unixsockets :param postgres_options: Postgres executable options for use by pg_ctl :param load: List of functions used to initialize database's template. + :param scope: fixture scope; by default "session" which is recommended. :returns: function which makes a postgresql process """ - @pytest.fixture(scope="session") + @pytest.fixture(scope=scope) def postgresql_proc_fixture( request: FixtureRequest, tmp_path_factory: TempPathFactory ) -> Iterator[PostgreSQLExecutor]: diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 247ade22..fe631c51 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -1,6 +1,7 @@ """Database Janitor.""" -from contextlib import contextmanager +import inspect +from contextlib import contextmanager, asynccontextmanager from pathlib import Path from types import TracebackType from typing import Callable, Iterator, Type, TypeVar @@ -9,8 +10,8 @@ from packaging.version import parse from psycopg import Connection, Cursor -from pytest_postgresql.loader import build_loader -from pytest_postgresql.retry import retry +from pytest_postgresql.loader import build_loader, build_loader_async +from pytest_postgresql.retry import retry, retry_async Version = type(parse("1")) @@ -163,3 +164,142 @@ def __exit__( ) -> None: """Exit from Database janitor context cleaning after itself.""" self.drop() + + +class AsyncDatabaseJanitor: + """Manage database state for specific tasks.""" + + def __init__( + self, + *, + user: str, + host: str, + port: str | int, + version: str | float | Version, # type: ignore[valid-type] + dbname: str | None = None, + template_dbname: str | None = None, + password: str | None = None, + isolation_level: "psycopg.IsolationLevel | None" = None, + connection_timeout: int = 60, + ) -> None: + """Initialize janitor. + + :param user: postgresql username + :param host: postgresql host + :param port: postgresql port + :param dbname: database name + :param dbname: template database name + :param version: postgresql version number + :param password: optional postgresql password + :param isolation_level: optional postgresql isolation level + defaults to server's default + :param connection_timeout: how long to retry connection before + raising a TimeoutError + """ + self.user = user + self.password = password + self.host = host + self.port = port + # At least one of the dbname or template_dbname has to be filled. + assert any([dbname, template_dbname]) + self.dbname = dbname + self.template_dbname = template_dbname + self._connection_timeout = connection_timeout + self.isolation_level = isolation_level + if not isinstance(version, Version): + self.version = parse(str(version)) + else: + self.version = version + + async def init(self) -> None: + """Create database in postgresql.""" + async with self.cursor() as cur: + if self.is_template(): + await cur.execute(f'CREATE DATABASE "{self.template_dbname}" WITH is_template = true;') + elif self.template_dbname is None: + await cur.execute(f'CREATE DATABASE "{self.dbname}";') + else: + # And make sure no-one is left connected to the template database. + # Otherwise, Creating database from template will fail + await self._terminate_connection(cur, self.template_dbname) + await cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";') + + def is_template(self) -> bool: + """Determine whether the DatabaseJanitor maintains template or database.""" + return self.dbname is None + + async def drop(self) -> None: + """Drop database in postgresql (async).""" + db_to_drop = self.template_dbname if self.is_template() else self.dbname + assert db_to_drop + async with self.cursor() as cur: + await self._dont_datallowconn(cur, db_to_drop) + await self._terminate_connection(cur, db_to_drop) + if self.is_template(): + await cur.execute(f'ALTER DATABASE "{db_to_drop}" with is_template false;') + await cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";') + + @staticmethod + async def _dont_datallowconn(cur, dbname: str) -> None: + await cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;') + + @staticmethod + async def _terminate_connection(cur, dbname: str) -> None: + await cur.execute( + "SELECT pg_terminate_backend(pg_stat_activity.pid)" + "FROM pg_stat_activity " + "WHERE pg_stat_activity.datname = %s;", + (dbname,), + ) + + async def load(self, load: Callable | str | Path) -> None: + """Load data into a database (async). + + Expects: + + * a Path to sql file, that'll be loaded + * an import path to import callable + * a callable that expects: host, port, user, dbname and password arguments. + + """ + db_to_load = self.template_dbname if self.is_template() else self.dbname + _loader = build_loader_async(load) + cor = _loader( + host=self.host, + port=self.port, + user=self.user, + dbname=db_to_load, + password=self.password, + ) + if inspect.isawaitable(cor): + await cor + + @asynccontextmanager + async def cursor(self, dbname: str = "postgres"): + """Async context manager for postgresql cursor.""" + + async def connect() -> psycopg.AsyncConnection: + return await psycopg.AsyncConnection.connect( + dbname=dbname, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + ) + + conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError) + conn.isolation_level = self.isolation_level + # We must not run a transaction since we create a database. + conn.autocommit = True + async with conn.cursor() as cur: + try: + yield cur + finally: + await conn.close() + + async def __aenter__(self): + await self.init() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.drop() diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index c9b28cbd..04418c3c 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -1,5 +1,6 @@ """Loader helper functions.""" +import importlib import re from functools import partial from pathlib import Path @@ -16,7 +17,7 @@ def build_loader(load: Callable | str | Path) -> Callable: loader_parts = re.split("[.:]", load, maxsplit=2) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name]) + _temp_import = importlib.import_module(import_path, globals(), locals(), fromlist=[loader_name]) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: @@ -30,3 +31,25 @@ def sql(sql_filename: Path, **kwargs: Any) -> None: with db_connection.cursor() as cur: cur.execute(_fd.read()) db_connection.commit() + +def build_loader_async(load: Callable | str | Path) -> Callable: + """Build a loader callable.""" + if isinstance(load, Path): + return partial(sql_async, load) + elif isinstance(load, str): + loader_parts = re.split("[.:]", load, maxsplit=2) + import_path = ".".join(loader_parts[:-1]) + loader_name = loader_parts[-1] + _temp_import = importlib.import_module(import_path, globals(), locals(), fromlist=[loader_name]) + _loader: Callable = getattr(_temp_import, loader_name) + return _loader + else: + return load + +async def sql_async(sql_filename: Path, **kwargs: Any) -> None: + """Async database loader for sql files.""" + async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection: + async with await db_connection.cursor() as cur: + with open(sql_filename, "r") as _fd: + await cur.execute(_fd.read()) + await db_connection.commit() diff --git a/pytest_postgresql/retry.py b/pytest_postgresql/retry.py index ea25fa2e..98c00daa 100644 --- a/pytest_postgresql/retry.py +++ b/pytest_postgresql/retry.py @@ -1,9 +1,10 @@ """Small retry callable in case of specific error occurred.""" +import asyncio import datetime import sys from time import sleep -from typing import Callable, Type, TypeVar +from typing import Awaitable, Callable, Type, TypeVar T = TypeVar("T") @@ -36,6 +37,34 @@ def retry( sleep(1) +async def retry_async( + func: Callable[[], Awaitable[T]], + timeout: int = 60, + possible_exception: Type[Exception] = Exception, +) -> T: + """Attempt to retry the async function for timeout time. + + Most often used for connecting to postgresql database as, + especially on macos on github-actions, first few tries fails + with this message: + + ... :: + FATAL: the database system is starting up + """ + time: datetime.datetime = get_current_datetime() + timeout_diff: datetime.timedelta = datetime.timedelta(seconds=timeout) + i = 0 + while True: + i += 1 + try: + res = await func() + return res + except possible_exception as e: + if time + timeout_diff < get_current_datetime(): + raise TimeoutError(f"Failed after {i} attempts") from e + await asyncio.sleep(1) + + def get_current_datetime() -> datetime.datetime: """Get the current datetime.""" # To ensure the current datetime retrieval is adjusted with the latest From db2fcdb7a105c59dc9ebb37cd882650cee36f51a Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Fri, 12 Dec 2025 17:27:15 -0800 Subject: [PATCH 02/11] Add some async tests based on sync tests --- pyproject.toml | 5 +++ pytest_postgresql/factories/__init__.py | 2 +- tests/conftest.py | 2 + tests/docker/test_noproc_docker.py | 30 ++++++++++++- tests/test_janitor.py | 56 ++++++++++++++++++++++++- tests/test_loader.py | 21 +++++++++- tests/test_postgresql.py | 54 +++++++++++++++++++++++- tests/test_template_database.py | 14 ++++++- 8 files changed, 178 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 28ffad42..cd7b97be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,11 @@ dependencies = [ ] requires-python = ">= 3.10" +[project.optional-dependencies] +async = [ + "pytest-asyncio" +] + [project.urls] "Source" = "https://github.com/dbfixtures/pytest-postgresql" "Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues" diff --git a/pytest_postgresql/factories/__init__.py b/pytest_postgresql/factories/__init__.py index d6bd2f64..15e84490 100644 --- a/pytest_postgresql/factories/__init__.py +++ b/pytest_postgresql/factories/__init__.py @@ -17,7 +17,7 @@ # along with pytest-postgresql. If not, see . """Fixture factories for postgresql fixtures.""" -from pytest_postgresql.factories.client import postgresql +from pytest_postgresql.factories.client import postgresql, postgresql_async from pytest_postgresql.factories.noprocess import postgresql_noproc from pytest_postgresql.factories.process import PortType, postgresql_proc diff --git a/tests/conftest.py b/tests/conftest.py index 784b8905..483437af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,3 +17,5 @@ postgresql_proc2 = factories.postgresql_proc(port=None, load=[TEST_SQL_FILE, TEST_SQL_FILE2]) postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db") postgresql_load_1 = factories.postgresql("postgresql_proc2") +postgresql2_async = factories.postgresql_async("postgresql_proc2", dbname="test-db") +postgresql_load_1_async = factories.postgresql_async("postgresql_proc2") diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py index ae25307a..0a73e0e7 100644 --- a/tests/docker/test_noproc_docker.py +++ b/tests/docker/test_noproc_docker.py @@ -3,7 +3,7 @@ import pathlib import pytest -from psycopg import Connection +from psycopg import Connection, AsyncConnection import pytest_postgresql.factories.client import pytest_postgresql.factories.noprocess @@ -14,12 +14,17 @@ ) postgres_with_schema = pytest_postgresql.factories.client.postgresql("postgresql_my_proc") +async_postgres_with_schema = pytest_postgresql.factories.client.postgresql_async("postgresql_my_proc") + postgresql_my_proc_template = pytest_postgresql.factories.noprocess.postgresql_noproc( dbname="stories_templated", load=[load_database] ) postgres_with_template = pytest_postgresql.factories.client.postgresql( "postgresql_my_proc_template", dbname="stories_templated" ) +async_postgres_with_template = pytest_postgresql.factories.client.postgresql_async( + "postgresql_my_proc_template", dbname="stories_templated" +) def test_postgres_docker_load(postgres_with_schema: Connection) -> None: @@ -32,6 +37,14 @@ def test_postgres_docker_load(postgres_with_schema: Connection) -> None: print(cur.fetchall()) +@pytest.mark.asyncio +async def test_postgres_docker_load_async(async_postgres_with_schema: AsyncConnection) -> None: + """Async check main postgres fixture.""" + async with async_postgres_with_schema.cursor() as cur: + await cur.execute("select * from public.tokens") + print(await cur.fetchall()) + + @pytest.mark.parametrize("_", range(5)) def test_template_database(postgres_with_template: Connection, _: int) -> None: """Check that the database structure gets recreated out of a template.""" @@ -43,3 +56,18 @@ def test_template_database(postgres_with_template: Connection, _: int) -> None: cur.execute("SELECT * FROM stories") res = cur.fetchall() assert len(res) == 0 + + +# Async version of test_template_database +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(5)) +async def test_template_database_async(async_postgres_with_template, _: int) -> None: + """Async check that the database structure gets recreated out of a template.""" + async with async_postgres_with_template.cursor() as cur: + await cur.execute("SELECT * FROM stories") + rows = await cur.fetchall() + assert len(rows) == 4 + await cur.execute("TRUNCATE stories") + await cur.execute("SELECT * FROM stories") + rows = await cur.fetchall() + assert len(rows) == 0 diff --git a/tests/test_janitor.py b/tests/test_janitor.py index fd1fca2a..90847063 100644 --- a/tests/test_janitor.py +++ b/tests/test_janitor.py @@ -1,5 +1,6 @@ """Database Janitor tests.""" +import asyncio import sys from typing import Any from unittest.mock import MagicMock, patch @@ -7,7 +8,7 @@ import pytest from packaging.version import parse -from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.janitor import DatabaseJanitor, AsyncDatabaseJanitor VERSION = parse("10") @@ -18,6 +19,13 @@ def test_version_cast(version: Any) -> None: janitor = DatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version) assert janitor.version == VERSION +@pytest.mark.parametrize("version", (VERSION, 10, "10")) +@pytest.mark.asyncio +async def test_version_cast_async(version: Any) -> None: + """Async test that version is cast to Version object.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version) + assert janitor.version == VERSION + @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: @@ -27,6 +35,15 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234") +@patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect") +@pytest.mark.asyncio +async def test_cursor_selects_postgres_database_async(connect_mock: MagicMock) -> None: + """Async test that the cursor requests the postgres database.""" + janitor = AsyncDatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=10) + async with janitor.cursor(): + connect_mock.assert_called_once_with(dbname="postgres", user="user", password=None, host="host", port="1234") + + @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" @@ -44,6 +61,24 @@ def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: ) +@patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect") +@pytest.mark.asyncio +async def test_cursor_connects_with_password_async(connect_mock: MagicMock) -> None: + """Async test that the cursor requests the postgres database with password.""" + janitor = AsyncDatabaseJanitor( + user="user", + host="host", + port="1234", + dbname="database_name", + version=10, + password="some_password", + ) + async with janitor.cursor(): + connect_mock.assert_called_once_with( + dbname="postgres", user="user", password="some_password", host="host", port="1234" + ) + + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8") @pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database")) @patch("pytest_postgresql.janitor.psycopg.connect") @@ -63,3 +98,22 @@ def test_janitor_populate(connect_mock: MagicMock, load_database: str) -> None: janitor.load(load_database) assert connect_mock.called assert connect_mock.call_args.kwargs == call_kwargs + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Unittest call_args.kwargs was introduced since python 3.8") +@pytest.mark.parametrize("load_database", ("tests.loader.load_database", "tests.loader:load_database")) +@patch("pytest_postgresql.janitor.psycopg.AsyncConnection.connect") +@pytest.mark.asyncio +async def test_janitor_populate_async(connect_mock: MagicMock, load_database: str) -> None: + """Async test that the cursor requests the postgres database and populates.""" + call_kwargs = { + "host": "host", + "port": "1234", + "user": "user", + "dbname": "database_name", + "password": "some_password", + } + janitor = AsyncDatabaseJanitor(version=10, **call_kwargs) # type: ignore[arg-type] + await janitor.load(load_database) + assert connect_mock.called + assert connect_mock.call_args.kwargs == call_kwargs diff --git a/tests/test_loader.py b/tests/test_loader.py index c03f8a55..0c41e8b2 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -2,7 +2,9 @@ from pathlib import Path -from pytest_postgresql.loader import build_loader, sql +import pytest + +from pytest_postgresql.loader import build_loader, sql, build_loader_async, sql_async from tests.loader import load_database @@ -11,6 +13,15 @@ def test_loader_callables() -> None: assert load_database == build_loader(load_database) assert load_database == build_loader("tests.loader:load_database") +@pytest.mark.asyncio +async def test_loader_callables_async() -> None: + """Async test handling callables in build_loader_async.""" + assert load_database == build_loader_async(load_database) + assert load_database == build_loader_async("tests.loader:load_database") + + async def afun(*args, **kwargs): + return 0 + assert afun == build_loader_async(afun) def test_loader_sql() -> None: """Test returning partial running sql for the sql file path.""" @@ -18,3 +29,11 @@ def test_loader_sql() -> None: loader_func = build_loader(sql_path) assert loader_func.args == (sql_path,) # type: ignore assert loader_func.func == sql # type: ignore + +@pytest.mark.asyncio +async def test_loader_sql_async() -> None: + """Async test returning partial running sql_async for the sql file path.""" + sql_path = Path("test_sql/eidastats.sql") + loader_func = build_loader_async(sql_path) + assert loader_func.args == (sql_path,) # type: ignore + assert loader_func.func == sql_async # type: ignore diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index dec3ba8c..52990ea2 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -3,7 +3,7 @@ import decimal import pytest -from psycopg import Connection +from psycopg import Connection, AsyncConnection from psycopg.pq import ConnStatus from pytest_postgresql.executor import PostgreSQLExecutor @@ -72,3 +72,55 @@ def check_if_one_connection() -> None: assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) + + +@pytest.mark.asyncio +async def test_main_postgres_async(postgresql_async: AsyncConnection) -> None: + """Async check main postgresql fixture.""" + async with postgresql_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql_async.commit() + + +@pytest.mark.asyncio +async def test_two_postgreses_async(postgresql_async: AsyncConnection, postgresql2_async: AsyncConnection) -> None: + """Async check two postgresql fixtures on one test (async).""" + async with postgresql_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql_async.commit() + + async with postgresql2_async.cursor() as cur: + await cur.execute(MAKE_Q) + await postgresql2_async.commit() + + +@pytest.mark.asyncio +async def test_postgres_load_two_files_async(postgresql_load_1_async: AsyncConnection) -> None: + """Async check postgresql fixture can load two files.""" + async with postgresql_load_1_async.cursor() as cur: + await cur.execute(SELECT_Q) + results = await cur.fetchall() + assert len(results) == 2 + + +@pytest.mark.asyncio +async def test_rand_postgres_port_async(postgresql2_async: AsyncConnection) -> None: + """Async check if postgres fixture can be started on random port.""" + assert postgresql2_async.info.status == ConnStatus.OK + + +@pytest.mark.asyncio +@pytest.mark.parametrize("_", range(2)) +async def test_postgres_terminate_connection_async(postgresql2_async: AsyncConnection, _: int) -> None: + """Async test that connections are terminated between tests. + + And check that only one exists at a time. + """ + async with postgresql2_async.cursor() as cur: + + async def check_if_one_connection() -> None: + await cur.execute("SELECT * FROM pg_stat_activity WHERE backend_type = 'client backend';") + existing_connections = await cur.fetchall() + assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" + + await retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) \ No newline at end of file diff --git a/tests/test_template_database.py b/tests/test_template_database.py index 64631779..bf4963a7 100644 --- a/tests/test_template_database.py +++ b/tests/test_template_database.py @@ -1,7 +1,7 @@ """Template database tests.""" import pytest -from psycopg import Connection +from psycopg import Connection, AsyncConnection from pytest_postgresql.factories import postgresql, postgresql_proc from tests.loader import load_database @@ -30,3 +30,15 @@ def test_template_database(postgresql_template: Connection, _: int) -> None: cur.execute("SELECT * FROM stories") res = cur.fetchall() assert len(res) == 0 + +@pytest.mark.parametrize("_", range(5)) +def test_template_database(async_postgresql_template: AsyncConnection, _: int) -> None: + """Check that the database structure gets recreated out of a template.""" + async with postgresql_template.cursor() as cur: + await cur.execute("SELECT * FROM stories") + res = cur.fetchall() + assert len(res) == 4 + await cur.execute("TRUNCATE stories") + await cur.execute("SELECT * FROM stories") + res = await cur.fetchall() + assert len(res) == 0 From de78d6c5af1f4b4a58818000a79aefb2c1b8f017 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Dec 2025 01:31:03 +0000 Subject: [PATCH 03/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytest_postgresql/factories/client.py | 13 +++++-------- pytest_postgresql/factories/noprocess.py | 4 ++-- pytest_postgresql/factories/process.py | 4 ++-- pytest_postgresql/janitor.py | 4 ++-- pytest_postgresql/loader.py | 2 ++ tests/docker/test_noproc_docker.py | 2 +- tests/test_janitor.py | 4 ++-- tests/test_loader.py | 6 +++++- tests/test_postgresql.py | 4 ++-- tests/test_template_database.py | 3 ++- 10 files changed, 25 insertions(+), 21 deletions(-) diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index eb636664..88e07132 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -21,9 +21,9 @@ import psycopg import pytest +from _pytest.scope import _ScopeName from psycopg import Connection from pytest import FixtureRequest -from _pytest.scope import _ScopeName from pytest_postgresql.config import get_config from pytest_postgresql.executor import PostgreSQLExecutor @@ -35,7 +35,7 @@ def postgresql( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, - scope: _ScopeName="function" + scope: _ScopeName = "function", ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return connection fixture factory for PostgreSQL. @@ -90,12 +90,11 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: return postgresql_factory - def postgresql_async( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, - scope: _ScopeName="function" + scope: _ScopeName = "function", ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return async connection fixture factory for PostgreSQL. @@ -106,7 +105,6 @@ def postgresql_async( :param scope: fixture scope; by default "function" which is recommended. :returns: function which makes a connection to postgresql """ - import pytest_asyncio from psycopg import AsyncConnection @@ -114,8 +112,7 @@ def postgresql_async( @pytest_asyncio.fixture(scope=scope) async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnection]: - """ - Async fixture factory for PostgreSQL. + """Async fixture factory for PostgreSQL. :param request: fixture request object :returns: postgresql client @@ -159,4 +156,4 @@ async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnectio # And here await db_connection.close() - return postgresql_factory \ No newline at end of file + return postgresql_factory diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index 8038f700..94726710 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -22,8 +22,8 @@ from typing import Callable, Iterator import pytest -from pytest import FixtureRequest from _pytest.scope import _ScopeName +from pytest import FixtureRequest from pytest_postgresql.config import get_config from pytest_postgresql.executor_noop import NoopExecutor @@ -46,7 +46,7 @@ def postgresql_noproc( dbname: str | None = None, options: str = "", load: list[Callable | str | Path] | None = None, - scope: _ScopeName="session" + scope: _ScopeName = "session", ) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]: """Postgresql noprocess factory. diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 1fdbada8..f40d9cb1 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -25,9 +25,9 @@ import port_for import pytest +from _pytest.scope import _ScopeName from port_for import PortForException, get_port from pytest import FixtureRequest, TempPathFactory -from _pytest.scope import _ScopeName from pytest_postgresql.config import PostgresqlConfigDict, get_config from pytest_postgresql.exceptions import ExecutableMissingException @@ -82,7 +82,7 @@ def postgresql_proc( unixsocketdir: str | None = None, postgres_options: str | None = None, load: list[Callable | str | Path] | None = None, - scope: _ScopeName="session" + scope: _ScopeName = "session", ) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]: """Postgresql process factory. diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index fe631c51..7671ba6f 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -1,7 +1,7 @@ """Database Janitor.""" import inspect -from contextlib import contextmanager, asynccontextmanager +from contextlib import asynccontextmanager, contextmanager from pathlib import Path from types import TracebackType from typing import Callable, Iterator, Type, TypeVar @@ -279,7 +279,7 @@ async def cursor(self, dbname: str = "postgres"): """Async context manager for postgresql cursor.""" async def connect() -> psycopg.AsyncConnection: - return await psycopg.AsyncConnection.connect( + return await psycopg.AsyncConnection.connect( dbname=dbname, user=self.user, password=self.password, diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index 04418c3c..e1193cb6 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -32,6 +32,7 @@ def sql(sql_filename: Path, **kwargs: Any) -> None: cur.execute(_fd.read()) db_connection.commit() + def build_loader_async(load: Callable | str | Path) -> Callable: """Build a loader callable.""" if isinstance(load, Path): @@ -46,6 +47,7 @@ def build_loader_async(load: Callable | str | Path) -> Callable: else: return load + async def sql_async(sql_filename: Path, **kwargs: Any) -> None: """Async database loader for sql files.""" async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection: diff --git a/tests/docker/test_noproc_docker.py b/tests/docker/test_noproc_docker.py index 0a73e0e7..b06dbc58 100644 --- a/tests/docker/test_noproc_docker.py +++ b/tests/docker/test_noproc_docker.py @@ -3,7 +3,7 @@ import pathlib import pytest -from psycopg import Connection, AsyncConnection +from psycopg import AsyncConnection, Connection import pytest_postgresql.factories.client import pytest_postgresql.factories.noprocess diff --git a/tests/test_janitor.py b/tests/test_janitor.py index 90847063..02240cb6 100644 --- a/tests/test_janitor.py +++ b/tests/test_janitor.py @@ -1,6 +1,5 @@ """Database Janitor tests.""" -import asyncio import sys from typing import Any from unittest.mock import MagicMock, patch @@ -8,7 +7,7 @@ import pytest from packaging.version import parse -from pytest_postgresql.janitor import DatabaseJanitor, AsyncDatabaseJanitor +from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor VERSION = parse("10") @@ -19,6 +18,7 @@ def test_version_cast(version: Any) -> None: janitor = DatabaseJanitor(user="user", host="host", port="1234", dbname="database_name", version=version) assert janitor.version == VERSION + @pytest.mark.parametrize("version", (VERSION, 10, "10")) @pytest.mark.asyncio async def test_version_cast_async(version: Any) -> None: diff --git a/tests/test_loader.py b/tests/test_loader.py index 0c41e8b2..a4319a6b 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -4,7 +4,7 @@ import pytest -from pytest_postgresql.loader import build_loader, sql, build_loader_async, sql_async +from pytest_postgresql.loader import build_loader, build_loader_async, sql, sql_async from tests.loader import load_database @@ -13,6 +13,7 @@ def test_loader_callables() -> None: assert load_database == build_loader(load_database) assert load_database == build_loader("tests.loader:load_database") + @pytest.mark.asyncio async def test_loader_callables_async() -> None: """Async test handling callables in build_loader_async.""" @@ -21,8 +22,10 @@ async def test_loader_callables_async() -> None: async def afun(*args, **kwargs): return 0 + assert afun == build_loader_async(afun) + def test_loader_sql() -> None: """Test returning partial running sql for the sql file path.""" sql_path = Path("test_sql/eidastats.sql") @@ -30,6 +33,7 @@ def test_loader_sql() -> None: assert loader_func.args == (sql_path,) # type: ignore assert loader_func.func == sql # type: ignore + @pytest.mark.asyncio async def test_loader_sql_async() -> None: """Async test returning partial running sql_async for the sql file path.""" diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 52990ea2..f752d1fe 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -3,7 +3,7 @@ import decimal import pytest -from psycopg import Connection, AsyncConnection +from psycopg import AsyncConnection, Connection from psycopg.pq import ConnStatus from pytest_postgresql.executor import PostgreSQLExecutor @@ -123,4 +123,4 @@ async def check_if_one_connection() -> None: existing_connections = await cur.fetchall() assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" - await retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) \ No newline at end of file + await retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) diff --git a/tests/test_template_database.py b/tests/test_template_database.py index bf4963a7..30568eaa 100644 --- a/tests/test_template_database.py +++ b/tests/test_template_database.py @@ -1,7 +1,7 @@ """Template database tests.""" import pytest -from psycopg import Connection, AsyncConnection +from psycopg import AsyncConnection, Connection from pytest_postgresql.factories import postgresql, postgresql_proc from tests.loader import load_database @@ -31,6 +31,7 @@ def test_template_database(postgresql_template: Connection, _: int) -> None: res = cur.fetchall() assert len(res) == 0 + @pytest.mark.parametrize("_", range(5)) def test_template_database(async_postgresql_template: AsyncConnection, _: int) -> None: """Check that the database structure gets recreated out of a template.""" From 614a3c938726acea3d65824acae35878dda8945c Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Mon, 15 Dec 2025 12:02:21 -0800 Subject: [PATCH 04/11] Fix up code to address feedback --- pyproject.toml | 3 +- pytest_postgresql/factories/client.py | 36 +++++++++--------------- pytest_postgresql/factories/noprocess.py | 4 +-- pytest_postgresql/factories/process.py | 4 +-- pytest_postgresql/janitor.py | 12 ++++---- pytest_postgresql/loader.py | 13 +++++---- pytest_postgresql/types.py | 3 ++ tests/test_postgresql.py | 4 +-- tests/test_template_database.py | 15 +++++++--- 9 files changed, 49 insertions(+), 45 deletions(-) create mode 100644 pytest_postgresql/types.py diff --git a/pyproject.toml b/pyproject.toml index cd7b97be..905c1d6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ requires-python = ">= 3.10" [project.optional-dependencies] async = [ - "pytest-asyncio" + "pytest-asyncio", + "aiofiles" ] [project.urls] diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index eb636664..c9572edd 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -17,25 +17,24 @@ # along with pytest-postgresql. If not, see . """Fixture factory for postgresql client.""" -from typing import Callable, Iterator +from typing import AsyncIterator, Callable, Iterator import psycopg import pytest -from psycopg import Connection +from psycopg import AsyncConnection, Connection from pytest import FixtureRequest -from _pytest.scope import _ScopeName from pytest_postgresql.config import get_config from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.executor_noop import NoopExecutor -from pytest_postgresql.janitor import DatabaseJanitor - +from pytest_postgresql.types import FixtureScopeT +from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor def postgresql( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, - scope: _ScopeName="function" + scope: FixtureScopeT="function" ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return connection fixture factory for PostgreSQL. @@ -49,7 +48,7 @@ def postgresql( @pytest.fixture(scope=scope) def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: - """Fixture factory for PostgreSQL. + """Fixture connection factory for PostgreSQL. :param request: fixture request object :returns: postgresql client @@ -95,8 +94,8 @@ def postgresql_async( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, - scope: _ScopeName="function" -) -> Callable[[FixtureRequest], Iterator[Connection]]: + scope: FixtureScopeT="function" +) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]: """Return async connection fixture factory for PostgreSQL. :param process_fixture_name: name of the process fixture @@ -108,14 +107,11 @@ def postgresql_async( """ import pytest_asyncio - from psycopg import AsyncConnection - - from pytest_postgresql.janitor import AsyncDatabaseJanitor @pytest_asyncio.fixture(scope=scope) - async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnection]: + async def postgresql_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]: """ - Async fixture factory for PostgreSQL. + Async connection fixture factory for PostgreSQL. :param request: fixture request object :returns: postgresql client @@ -129,7 +125,7 @@ async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnectio pg_password = proc_fixture.password pg_options = proc_fixture.options pg_db = dbname or proc_fixture.dbname - janitor = DatabaseJanitor( + janitor = AsyncDatabaseJanitor( user=pg_user, host=pg_host, port=pg_port, @@ -140,11 +136,8 @@ async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnectio isolation_level=isolation_level, ) if config["drop_test_database"]: - janitor.drop() - with AsyncDatabaseJanitor( - pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level - ) as janitor: - # Line modified here + await janitor.drop() + async with janitor: db_connection: AsyncConnection = await AsyncConnection.connect( dbname=pg_db, user=pg_user, @@ -153,10 +146,7 @@ async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnectio port=pg_port, options=pg_options, ) - for load_element in pg_load: - janitor.load(load_element) yield db_connection - # And here await db_connection.close() return postgresql_factory \ No newline at end of file diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index 8038f700..ba285e4b 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -23,10 +23,10 @@ import pytest from pytest import FixtureRequest -from _pytest.scope import _ScopeName from pytest_postgresql.config import get_config from pytest_postgresql.executor_noop import NoopExecutor +from pytest_postgresql.types import FixtureScopeT from pytest_postgresql.janitor import DatabaseJanitor @@ -46,7 +46,7 @@ def postgresql_noproc( dbname: str | None = None, options: str = "", load: list[Callable | str | Path] | None = None, - scope: _ScopeName="session" + scope: FixtureScopeT="session" ) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]: """Postgresql noprocess factory. diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 1fdbada8..4091e505 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -27,11 +27,11 @@ import pytest from port_for import PortForException, get_port from pytest import FixtureRequest, TempPathFactory -from _pytest.scope import _ScopeName from pytest_postgresql.config import PostgresqlConfigDict, get_config from pytest_postgresql.exceptions import ExecutableMissingException from pytest_postgresql.executor import PostgreSQLExecutor +from pytest_postgresql.types import FixtureScopeT from pytest_postgresql.janitor import DatabaseJanitor PortType = port_for.PortType # mypy requires explicit export @@ -82,7 +82,7 @@ def postgresql_proc( unixsocketdir: str | None = None, postgres_options: str | None = None, load: list[Callable | str | Path] | None = None, - scope: _ScopeName="session" + scope: FixtureScopeT="session" ) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]: """Postgresql process factory. diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index fe631c51..39ef1ab2 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -53,8 +53,8 @@ def __init__( self.password = password self.host = host self.port = port - # At least one of the dbname or template_dbname has to be filled. - assert any([dbname, template_dbname]) + if not (dbname or template_dbname): + raise ValueError("At least one of the dbname or template_dbname has to be filled.") self.dbname = dbname self.template_dbname = template_dbname self._connection_timeout = connection_timeout @@ -200,8 +200,8 @@ def __init__( self.password = password self.host = host self.port = port - # At least one of the dbname or template_dbname has to be filled. - assert any([dbname, template_dbname]) + if not (dbname or template_dbname): + raise ValueError("At least one of the dbname or template_dbname has to be filled.") self.dbname = dbname self.template_dbname = template_dbname self._connection_timeout = connection_timeout @@ -288,9 +288,9 @@ async def connect() -> psycopg.AsyncConnection: ) conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError) - conn.isolation_level = self.isolation_level + await conn.set_isolation_level(self.isolation_level) + await conn.set_autocommit(True) # We must not run a transaction since we create a database. - conn.autocommit = True async with conn.cursor() as cur: try: yield cur diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index 04418c3c..83b4799f 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -17,7 +17,7 @@ def build_loader(load: Callable | str | Path) -> Callable: loader_parts = re.split("[.:]", load, maxsplit=2) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = importlib.import_module(import_path, globals(), locals(), fromlist=[loader_name]) + _temp_import = importlib.import_module(import_path, loader_name) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: @@ -40,7 +40,7 @@ def build_loader_async(load: Callable | str | Path) -> Callable: loader_parts = re.split("[.:]", load, maxsplit=2) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = importlib.import_module(import_path, globals(), locals(), fromlist=[loader_name]) + _temp_import = importlib.import_module(import_path, loader_name) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: @@ -48,8 +48,11 @@ def build_loader_async(load: Callable | str | Path) -> Callable: async def sql_async(sql_filename: Path, **kwargs: Any) -> None: """Async database loader for sql files.""" + + import aiofiles + async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection: - async with await db_connection.cursor() as cur: - with open(sql_filename, "r") as _fd: - await cur.execute(_fd.read()) + async with db_connection.cursor() as cur: + async with aiofiles.open(sql_filename, "r") as _fd: + await cur.execute(await _fd.read()) await db_connection.commit() diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py new file mode 100644 index 00000000..b116b6ee --- /dev/null +++ b/pytest_postgresql/types.py @@ -0,0 +1,3 @@ +from typing import Literal + +FixtureScopeT = Literal["session", "package", "module", "class", "function"] \ No newline at end of file diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 52990ea2..cc70e19c 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -7,7 +7,7 @@ from psycopg.pq import ConnStatus from pytest_postgresql.executor import PostgreSQLExecutor -from pytest_postgresql.retry import retry +from pytest_postgresql.retry import retry, retry_async from tests.conftest import POSTGRESQL_VERSION MAKE_Q = "CREATE TABLE test (id serial PRIMARY KEY, num integer, data varchar);" @@ -123,4 +123,4 @@ async def check_if_one_connection() -> None: existing_connections = await cur.fetchall() assert len(existing_connections) == 1, f"there is always only one connection, {existing_connections}" - await retry(check_if_one_connection, timeout=120, possible_exception=AssertionError) \ No newline at end of file + await retry_async(check_if_one_connection, timeout=120, possible_exception=AssertionError) \ No newline at end of file diff --git a/tests/test_template_database.py b/tests/test_template_database.py index bf4963a7..81d113d9 100644 --- a/tests/test_template_database.py +++ b/tests/test_template_database.py @@ -3,7 +3,7 @@ import pytest from psycopg import Connection, AsyncConnection -from pytest_postgresql.factories import postgresql, postgresql_proc +from pytest_postgresql.factories import postgresql, postgresql_async, postgresql_proc from tests.loader import load_database postgresql_proc_with_template = postgresql_proc( @@ -17,6 +17,11 @@ dbname="stories_templated", ) +async_postgresql_template = postgresql_async( + "postgresql_proc_with_template", + dbname="stories_templated", +) + @pytest.mark.xdist_group(name="template_database") @pytest.mark.parametrize("_", range(5)) @@ -31,12 +36,14 @@ def test_template_database(postgresql_template: Connection, _: int) -> None: res = cur.fetchall() assert len(res) == 0 + +@pytest.mark.asyncio @pytest.mark.parametrize("_", range(5)) -def test_template_database(async_postgresql_template: AsyncConnection, _: int) -> None: +async def test_template_database_async(async_postgresql_template: AsyncConnection, _: int) -> None: """Check that the database structure gets recreated out of a template.""" - async with postgresql_template.cursor() as cur: + async with async_postgresql_template.cursor() as cur: await cur.execute("SELECT * FROM stories") - res = cur.fetchall() + res = await cur.fetchall() assert len(res) == 4 await cur.execute("TRUNCATE stories") await cur.execute("SELECT * FROM stories") From fb3ce5c197940eab239f8eb6cd4454cdad52b29f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 20:09:55 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytest_postgresql/factories/client.py | 10 +++++----- pytest_postgresql/factories/noprocess.py | 4 ++-- pytest_postgresql/factories/process.py | 4 ++-- pytest_postgresql/loader.py | 1 - pytest_postgresql/types.py | 2 +- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index cd41ef30..3e015bc9 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -27,14 +27,15 @@ from pytest_postgresql.config import get_config from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.executor_noop import NoopExecutor -from pytest_postgresql.types import FixtureScopeT from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor +from pytest_postgresql.types import FixtureScopeT + def postgresql( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, - scope: FixtureScopeT="function" + scope: FixtureScopeT = "function", ) -> Callable[[FixtureRequest], Iterator[Connection]]: """Return connection fixture factory for PostgreSQL. @@ -93,7 +94,7 @@ def postgresql_async( process_fixture_name: str, dbname: str | None = None, isolation_level: "psycopg.IsolationLevel | None" = None, - scope: FixtureScopeT="function" + scope: FixtureScopeT = "function", ) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]: """Return async connection fixture factory for PostgreSQL. @@ -108,8 +109,7 @@ def postgresql_async( @pytest_asyncio.fixture(scope=scope) async def postgresql_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]: - """ - Async connection fixture factory for PostgreSQL. + """Async connection fixture factory for PostgreSQL. :param request: fixture request object :returns: postgresql client diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index ba285e4b..3e5e9fc4 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -26,8 +26,8 @@ from pytest_postgresql.config import get_config from pytest_postgresql.executor_noop import NoopExecutor -from pytest_postgresql.types import FixtureScopeT from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.types import FixtureScopeT def xdistify_dbname(dbname: str) -> str: @@ -46,7 +46,7 @@ def postgresql_noproc( dbname: str | None = None, options: str = "", load: list[Callable | str | Path] | None = None, - scope: FixtureScopeT="session" + scope: FixtureScopeT = "session", ) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]: """Postgresql noprocess factory. diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 4091e505..25b494ee 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -31,8 +31,8 @@ from pytest_postgresql.config import PostgresqlConfigDict, get_config from pytest_postgresql.exceptions import ExecutableMissingException from pytest_postgresql.executor import PostgreSQLExecutor -from pytest_postgresql.types import FixtureScopeT from pytest_postgresql.janitor import DatabaseJanitor +from pytest_postgresql.types import FixtureScopeT PortType = port_for.PortType # mypy requires explicit export @@ -82,7 +82,7 @@ def postgresql_proc( unixsocketdir: str | None = None, postgres_options: str | None = None, load: list[Callable | str | Path] | None = None, - scope: FixtureScopeT="session" + scope: FixtureScopeT = "session", ) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]: """Postgresql process factory. diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index 4565c37e..30665e2b 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -50,7 +50,6 @@ def build_loader_async(load: Callable | str | Path) -> Callable: async def sql_async(sql_filename: Path, **kwargs: Any) -> None: """Async database loader for sql files.""" - import aiofiles async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection: diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py index b116b6ee..508bf3ad 100644 --- a/pytest_postgresql/types.py +++ b/pytest_postgresql/types.py @@ -1,3 +1,3 @@ from typing import Literal -FixtureScopeT = Literal["session", "package", "module", "class", "function"] \ No newline at end of file +FixtureScopeT = Literal["session", "package", "module", "class", "function"] From d939c07381fd7c6cc5c637ced2a3a4f1e036d0de Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Mon, 15 Dec 2025 12:17:37 -0800 Subject: [PATCH 06/11] Address a little more feedback --- pytest_postgresql/janitor.py | 16 ++++++++-------- tests/test_loader.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 9938b5b0..6466fe50 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -288,14 +288,14 @@ async def connect() -> psycopg.AsyncConnection: ) conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError) - await conn.set_isolation_level(self.isolation_level) - await conn.set_autocommit(True) - # We must not run a transaction since we create a database. - async with conn.cursor() as cur: - try: - yield cur - finally: - await conn.close() + try: + await conn.set_isolation_level(self.isolation_level) + await conn.set_autocommit(True) + # We must not run a transaction since we create a database. + async with conn.cursor() as cur: + yield cur + finally: + await conn.close() async def __aenter__(self): await self.init() diff --git a/tests/test_loader.py b/tests/test_loader.py index a4319a6b..9e72f4c4 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -20,7 +20,7 @@ async def test_loader_callables_async() -> None: assert load_database == build_loader_async(load_database) assert load_database == build_loader_async("tests.loader:load_database") - async def afun(*args, **kwargs): + async def afun(*_args, **_kwargs): return 0 assert afun == build_loader_async(afun) From 7d837636b916b81805b952bd6414a4d944c6dc2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 20:17:47 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytest_postgresql/janitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 6466fe50..21d2b86c 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -293,7 +293,7 @@ async def connect() -> psycopg.AsyncConnection: await conn.set_autocommit(True) # We must not run a transaction since we create a database. async with conn.cursor() as cur: - yield cur + yield cur finally: await conn.close() From 1ee36daaa8661dfb8702f6de1934bd458334476c Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Mon, 15 Dec 2025 12:20:25 -0800 Subject: [PATCH 08/11] Fix importlib import_module call --- pytest_postgresql/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytest_postgresql/loader.py b/pytest_postgresql/loader.py index 30665e2b..36f37055 100644 --- a/pytest_postgresql/loader.py +++ b/pytest_postgresql/loader.py @@ -17,7 +17,7 @@ def build_loader(load: Callable | str | Path) -> Callable: loader_parts = re.split("[.:]", load, maxsplit=2) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = importlib.import_module(import_path, loader_name) + _temp_import = importlib.import_module(import_path) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: @@ -41,7 +41,7 @@ def build_loader_async(load: Callable | str | Path) -> Callable: loader_parts = re.split("[.:]", load, maxsplit=2) import_path = ".".join(loader_parts[:-1]) loader_name = loader_parts[-1] - _temp_import = importlib.import_module(import_path, loader_name) + _temp_import = importlib.import_module(import_path) _loader: Callable = getattr(_temp_import, loader_name) return _loader else: From f3546d746108cf9d046a2b54e0b2468da6ae701a Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Mon, 15 Dec 2025 12:25:12 -0800 Subject: [PATCH 09/11] Add version check for test --- tests/test_postgresql.py | 4 ++++ tests/test_template_database.py | 1 + 2 files changed, 5 insertions(+) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 88590882..f4396caf 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -109,6 +109,10 @@ async def test_rand_postgres_port_async(postgresql2_async: AsyncConnection) -> N assert postgresql2_async.info.status == ConnStatus.OK +@pytest.mark.skipif( + decimal.Decimal(POSTGRESQL_VERSION) < 10, + reason="Test query not supported in those postgresql versions, and soon will not be supported.", +) @pytest.mark.asyncio @pytest.mark.parametrize("_", range(2)) async def test_postgres_terminate_connection_async(postgresql2_async: AsyncConnection, _: int) -> None: diff --git a/tests/test_template_database.py b/tests/test_template_database.py index 05c71f83..f2882159 100644 --- a/tests/test_template_database.py +++ b/tests/test_template_database.py @@ -38,6 +38,7 @@ def test_template_database(postgresql_template: Connection, _: int) -> None: @pytest.mark.asyncio +@pytest.mark.xdist_group(name="template_database_async") @pytest.mark.parametrize("_", range(5)) async def test_template_database_async(async_postgresql_template: AsyncConnection, _: int) -> None: """Check that the database structure gets recreated out of a template.""" From c1a2adeb95848321de758ae7a0083510b03313db Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Mon, 15 Dec 2025 12:30:06 -0800 Subject: [PATCH 10/11] Address ci fails --- pytest_postgresql/factories/__init__.py | 2 +- pytest_postgresql/janitor.py | 3 +++ pytest_postgresql/types.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytest_postgresql/factories/__init__.py b/pytest_postgresql/factories/__init__.py index 15e84490..002304cb 100644 --- a/pytest_postgresql/factories/__init__.py +++ b/pytest_postgresql/factories/__init__.py @@ -21,4 +21,4 @@ from pytest_postgresql.factories.noprocess import postgresql_noproc from pytest_postgresql.factories.process import PortType, postgresql_proc -__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "PortType") +__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "postgresql_async", "PortType") diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 21d2b86c..357076a2 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -279,6 +279,7 @@ async def cursor(self, dbname: str = "postgres"): """Async context manager for postgresql cursor.""" async def connect() -> psycopg.AsyncConnection: + """Return postgresql async cursor.""" return await psycopg.AsyncConnection.connect( dbname=dbname, user=self.user, @@ -298,8 +299,10 @@ async def connect() -> psycopg.AsyncConnection: await conn.close() async def __aenter__(self): + """Initialize Database Janitor.""" await self.init() return self async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit from Database janitor context cleaning after itself.""" await self.drop() diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py index 508bf3ad..c4f716ad 100644 --- a/pytest_postgresql/types.py +++ b/pytest_postgresql/types.py @@ -1,3 +1,5 @@ +"""Pytest Postgresql Types""" + from typing import Literal FixtureScopeT = Literal["session", "package", "module", "class", "function"] From d43696965156cc3dbe9dd1176b7f56ff040b91d2 Mon Sep 17 00:00:00 2001 From: Ivan Webber Date: Mon, 15 Dec 2025 12:32:57 -0800 Subject: [PATCH 11/11] Docstring format check fixes --- pytest_postgresql/janitor.py | 4 ++-- pytest_postgresql/types.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 357076a2..6dc3d2f6 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -41,7 +41,7 @@ def __init__( :param host: postgresql host :param port: postgresql port :param dbname: database name - :param dbname: template database name + :param template_dbname: template database name :param version: postgresql version number :param password: optional postgresql password :param isolation_level: optional postgresql isolation level @@ -188,7 +188,7 @@ def __init__( :param host: postgresql host :param port: postgresql port :param dbname: database name - :param dbname: template database name + :param template_dbname: template database name :param version: postgresql version number :param password: optional postgresql password :param isolation_level: optional postgresql isolation level diff --git a/pytest_postgresql/types.py b/pytest_postgresql/types.py index c4f716ad..c5447d0f 100644 --- a/pytest_postgresql/types.py +++ b/pytest_postgresql/types.py @@ -1,4 +1,4 @@ -"""Pytest Postgresql Types""" +"""Pytest Postgresql Types.""" from typing import Literal