Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ dependencies = [
]
requires-python = ">= 3.10"

[project.optional-dependencies]
async = [
"pytest-asyncio",
"aiofiles"
]

[project.urls]
"Source" = "https://github.com/dbfixtures/pytest-postgresql"
"Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues"
Expand Down
4 changes: 2 additions & 2 deletions pytest_postgresql/factories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# along with pytest-postgresql. If not, see <http://www.gnu.org/licenses/>.
"""Fixture factories for postgresql fixtures."""

from pytest_postgresql.factories.client import postgresql
from pytest_postgresql.factories.client import postgresql, postgresql_async
from pytest_postgresql.factories.noprocess import postgresql_noproc
from pytest_postgresql.factories.process import PortType, postgresql_proc

__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "PortType")
__all__ = ("postgresql_proc", "postgresql_noproc", "postgresql", "postgresql_async", "PortType")
73 changes: 68 additions & 5 deletions pytest_postgresql/factories/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,39 @@
# along with pytest-postgresql. If not, see <http://www.gnu.org/licenses/>.
"""Fixture factory for postgresql client."""

from typing import Callable, Iterator
from typing import AsyncIterator, Callable, Iterator

import psycopg
import pytest
from psycopg import Connection
from psycopg import AsyncConnection, Connection
from pytest import FixtureRequest

from pytest_postgresql.config import get_config
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.executor_noop import NoopExecutor
from pytest_postgresql.janitor import DatabaseJanitor
from pytest_postgresql.janitor import AsyncDatabaseJanitor, DatabaseJanitor
from pytest_postgresql.types import FixtureScopeT


def postgresql(
process_fixture_name: str,
dbname: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
scope: FixtureScopeT = "function",
) -> Callable[[FixtureRequest], Iterator[Connection]]:
"""Return connection fixture factory for PostgreSQL.

:param process_fixture_name: name of the process fixture
:param dbname: database name
:param isolation_level: optional postgresql isolation level
defaults to server's default
:param scope: fixture scope; by default "function" which is recommended.
:returns: function which makes a connection to postgresql
"""

@pytest.fixture
@pytest.fixture(scope=scope)
def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
"""Fixture factory for PostgreSQL.
"""Fixture connection factory for PostgreSQL.

:param request: fixture request object
:returns: postgresql client
Expand Down Expand Up @@ -85,3 +88,63 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
db_connection.close()

return postgresql_factory


def postgresql_async(
process_fixture_name: str,
dbname: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
scope: FixtureScopeT = "function",
) -> Callable[[FixtureRequest], AsyncIterator[AsyncConnection]]:
"""Return async connection fixture factory for PostgreSQL.

:param process_fixture_name: name of the process fixture
:param dbname: database name
:param isolation_level: optional postgresql isolation level
defaults to server's default
:param scope: fixture scope; by default "function" which is recommended.
:returns: function which makes a connection to postgresql
"""
import pytest_asyncio

@pytest_asyncio.fixture(scope=scope)
async def postgresql_factory(request: FixtureRequest) -> AsyncIterator[AsyncConnection]:
"""Async connection fixture factory for PostgreSQL.

:param request: fixture request object
:returns: postgresql client
"""
proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name)
config = get_config(request)

pg_host = proc_fixture.host
pg_port = proc_fixture.port
pg_user = proc_fixture.user
pg_password = proc_fixture.password
pg_options = proc_fixture.options
pg_db = dbname or proc_fixture.dbname
janitor = AsyncDatabaseJanitor(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need AsyncDatabaseJanitor? Is there something in Janitor that would require tests to block on I/O?

This might come in handy in case test would have more async fixtures to prepare, not in a straightforward scenario with a single fixture?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your assessment is accurate; you can get around without it for sure unless you're using it directly in an async test or fixture.

user=pg_user,
host=pg_host,
port=pg_port,
dbname=pg_db,
template_dbname=proc_fixture.template_dbname,
version=proc_fixture.version,
password=pg_password,
isolation_level=isolation_level,
)
if config["drop_test_database"]:
await janitor.drop()
async with janitor:
db_connection: AsyncConnection = await AsyncConnection.connect(
dbname=pg_db,
user=pg_user,
password=pg_password,
host=pg_host,
port=pg_port,
options=pg_options,
)
yield db_connection
await db_connection.close()

return postgresql_factory
5 changes: 4 additions & 1 deletion pytest_postgresql/factories/noprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pytest_postgresql.config import get_config
from pytest_postgresql.executor_noop import NoopExecutor
from pytest_postgresql.janitor import DatabaseJanitor
from pytest_postgresql.types import FixtureScopeT


def xdistify_dbname(dbname: str) -> str:
Expand All @@ -45,6 +46,7 @@ def postgresql_noproc(
dbname: str | None = None,
options: str = "",
load: list[Callable | str | Path] | None = None,
scope: FixtureScopeT = "session",
) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]:
"""Postgresql noprocess factory.

Expand All @@ -55,10 +57,11 @@ def postgresql_noproc(
:param dbname: postgresql database name
:param options: Postgresql connection options
:param load: List of functions used to initialize database's template.
:param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""

@pytest.fixture(scope="session")
@pytest.fixture(scope=scope)
def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]:
"""Noop Process fixture for PostgreSQL.

Expand Down
5 changes: 4 additions & 1 deletion pytest_postgresql/factories/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from pytest_postgresql.exceptions import ExecutableMissingException
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.janitor import DatabaseJanitor
from pytest_postgresql.types import FixtureScopeT

PortType = port_for.PortType # mypy requires explicit export

Expand Down Expand Up @@ -81,6 +82,7 @@ def postgresql_proc(
unixsocketdir: str | None = None,
postgres_options: str | None = None,
load: list[Callable | str | Path] | None = None,
scope: FixtureScopeT = "session",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you want to add the scope parameter to all the factories there are? What's the use case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For codebases that already have thousands of tests, switching everything to a different scope just to try out this library and see whether it speeds up the tests might be rather prohibitive. We have 6000 tests and over a hundred fixtures, for example, and I used this to get things running; next I can work on the side effects of changing scope. There could be a warning or something if it's somehow critical.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see now; yes it's probably not necessary for proc and noproc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have 6000 tests
😍 love it!

So, the original intention behind those scopes would be that:

process sets up the database instance (with later addition of template database)
And client fixture creates a one-off database and drops it after the test (with later addition of doing that based on the template). This makes sure the state does not leak between tests, and they do not influence each other or make the tests order-dependent.

I've tested it manually, and the speed is rather similar if I had created a single database and used transactions. However, then all your tests are inside transactions, and you have to use that single connection that set up the transaction (that's why warehouse is using just DatabaseJanitor in their tests).

With the template database approach, you can create your own client connection to read the state as you prepare it, and not worry about hanging transactions with data fixtures.
Even run external processes in tests, if you wish to do so.

To address something like what you described, I planned #890, although I got distracted along the way while thinking about it. Maybe it's time to return to that idea...

) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]:
"""Postgresql process factory.

Expand All @@ -101,10 +103,11 @@ def postgresql_proc(
:param unixsocketdir: directory to create postgresql's unixsockets
:param postgres_options: Postgres executable options for use by pg_ctl
:param load: List of functions used to initialize database's template.
:param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""

@pytest.fixture(scope="session")
@pytest.fixture(scope=scope)
def postgresql_proc_fixture(
request: FixtureRequest, tmp_path_factory: TempPathFactory
) -> Iterator[PostgreSQLExecutor]:
Expand Down
155 changes: 149 additions & 6 deletions pytest_postgresql/janitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Database Janitor."""

from contextlib import contextmanager
import inspect
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
from types import TracebackType
from typing import Callable, Iterator, Type, TypeVar
Expand All @@ -9,8 +10,8 @@
from packaging.version import parse
from psycopg import Connection, Cursor

from pytest_postgresql.loader import build_loader
from pytest_postgresql.retry import retry
from pytest_postgresql.loader import build_loader, build_loader_async
from pytest_postgresql.retry import retry, retry_async

Version = type(parse("1"))

Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(
:param host: postgresql host
:param port: postgresql port
:param dbname: database name
:param dbname: template database name
:param template_dbname: template database name
:param version: postgresql version number
:param password: optional postgresql password
:param isolation_level: optional postgresql isolation level
Expand All @@ -52,8 +53,8 @@ def __init__(
self.password = password
self.host = host
self.port = port
# At least one of the dbname or template_dbname has to be filled.
assert any([dbname, template_dbname])
if not (dbname or template_dbname):
raise ValueError("At least one of the dbname or template_dbname has to be filled.")
self.dbname = dbname
self.template_dbname = template_dbname
self._connection_timeout = connection_timeout
Expand Down Expand Up @@ -163,3 +164,145 @@ def __exit__(
) -> None:
"""Exit from Database janitor context cleaning after itself."""
self.drop()


class AsyncDatabaseJanitor:
"""Manage database state for specific tasks."""

def __init__(
self,
*,
user: str,
host: str,
port: str | int,
version: str | float | Version, # type: ignore[valid-type]
dbname: str | None = None,
template_dbname: str | None = None,
password: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
connection_timeout: int = 60,
) -> None:
"""Initialize janitor.

:param user: postgresql username
:param host: postgresql host
:param port: postgresql port
:param dbname: database name
:param template_dbname: template database name
:param version: postgresql version number
:param password: optional postgresql password
:param isolation_level: optional postgresql isolation level
defaults to server's default
:param connection_timeout: how long to retry connection before
raising a TimeoutError
"""
self.user = user
self.password = password
self.host = host
self.port = port
if not (dbname or template_dbname):
raise ValueError("At least one of the dbname or template_dbname has to be filled.")
self.dbname = dbname
self.template_dbname = template_dbname
self._connection_timeout = connection_timeout
self.isolation_level = isolation_level
if not isinstance(version, Version):
self.version = parse(str(version))
else:
self.version = version

async def init(self) -> None:
"""Create database in postgresql."""
async with self.cursor() as cur:
if self.is_template():
await cur.execute(f'CREATE DATABASE "{self.template_dbname}" WITH is_template = true;')
elif self.template_dbname is None:
await cur.execute(f'CREATE DATABASE "{self.dbname}";')
else:
# And make sure no-one is left connected to the template database.
# Otherwise, Creating database from template will fail
await self._terminate_connection(cur, self.template_dbname)
await cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";')

def is_template(self) -> bool:
"""Determine whether the DatabaseJanitor maintains template or database."""
return self.dbname is None

async def drop(self) -> None:
"""Drop database in postgresql (async)."""
db_to_drop = self.template_dbname if self.is_template() else self.dbname
assert db_to_drop
async with self.cursor() as cur:
await self._dont_datallowconn(cur, db_to_drop)
await self._terminate_connection(cur, db_to_drop)
if self.is_template():
await cur.execute(f'ALTER DATABASE "{db_to_drop}" with is_template false;')
await cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";')

@staticmethod
async def _dont_datallowconn(cur, dbname: str) -> None:
await cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;')

@staticmethod
async def _terminate_connection(cur, dbname: str) -> None:
await cur.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid)"
"FROM pg_stat_activity "
"WHERE pg_stat_activity.datname = %s;",
(dbname,),
)

async def load(self, load: Callable | str | Path) -> None:
"""Load data into a database (async).

Expects:

* a Path to sql file, that'll be loaded
* an import path to import callable
* a callable that expects: host, port, user, dbname and password arguments.

"""
db_to_load = self.template_dbname if self.is_template() else self.dbname
_loader = build_loader_async(load)
cor = _loader(
host=self.host,
port=self.port,
user=self.user,
dbname=db_to_load,
password=self.password,
)
if inspect.isawaitable(cor):
await cor

@asynccontextmanager
async def cursor(self, dbname: str = "postgres"):
"""Async context manager for postgresql cursor."""

async def connect() -> psycopg.AsyncConnection:
"""Return postgresql async cursor."""
return await psycopg.AsyncConnection.connect(
dbname=dbname,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
)

conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError)
try:
await conn.set_isolation_level(self.isolation_level)
await conn.set_autocommit(True)
# We must not run a transaction since we create a database.
async with conn.cursor() as cur:
yield cur
finally:
await conn.close()

async def __aenter__(self):
"""Initialize Database Janitor."""
await self.init()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit from Database janitor context cleaning after itself."""
await self.drop()
Loading
Loading