Skip to content

Commit b83c89f

Browse files
committed
add async connection pool
1 parent d2f7e62 commit b83c89f

File tree

11 files changed

+369
-281
lines changed

11 files changed

+369
-281
lines changed

aws_advanced_python_wrapper/errors.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,27 @@ class FailoverSuccessError(FailoverError):
4545

4646
class ReadWriteSplittingError(AwsWrapperError):
4747
__module__ = "aws_advanced_python_wrapper"
48+
49+
50+
class AsyncConnectionPoolError(AwsWrapperError):
51+
__module__ = "aws_advanced_python_wrapper"
52+
53+
54+
class PoolNotInitializedError(AsyncConnectionPoolError):
55+
__module__ = "aws_advanced_python_wrapper"
56+
57+
58+
class PoolClosingError(AsyncConnectionPoolError):
59+
__module__ = "aws_advanced_python_wrapper"
60+
61+
62+
class PoolExhaustedError(AsyncConnectionPoolError):
63+
__module__ = "aws_advanced_python_wrapper"
64+
65+
66+
class ConnectionReleasedError(AsyncConnectionPoolError):
67+
__module__ = "aws_advanced_python_wrapper"
68+
69+
70+
class PoolSizeLimitError(AsyncConnectionPoolError):
71+
__module__ = "aws_advanced_python_wrapper"

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
114114

115115
is_cached_token = (token_info is not None and not token_info.is_expired())
116116
if not self._plugin_service.is_login_exception(error=e) or not is_cached_token:
117-
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e
117+
raise
118118

119119
# Login unsuccessful with cached token
120120
# Try to generate a new token and try to connect again

aws_advanced_python_wrapper/tortoise/async_support/async_wrapper.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,38 +37,6 @@ async def close_aws_wrapper(connection: AwsWrapperConnection) -> None:
3737
"""Close an AWS wrapper connection asynchronously."""
3838
await asyncio.to_thread(connection.close)
3939

40-
41-
class AwsCursorAsyncWrapper:
42-
"""Wraps sync AwsCursor cursor with async support."""
43-
44-
def __init__(self, sync_cursor):
45-
self._cursor = sync_cursor
46-
47-
async def execute(self, query, params=None):
48-
"""Execute a query asynchronously."""
49-
return await asyncio.to_thread(self._cursor.execute, query, params)
50-
51-
async def executemany(self, query, params_list):
52-
"""Execute multiple queries asynchronously."""
53-
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
54-
55-
async def fetchall(self):
56-
"""Fetch all results asynchronously."""
57-
return await asyncio.to_thread(self._cursor.fetchall)
58-
59-
async def fetchone(self):
60-
"""Fetch one result asynchronously."""
61-
return await asyncio.to_thread(self._cursor.fetchone)
62-
63-
async def close(self):
64-
"""Close cursor asynchronously."""
65-
return await asyncio.to_thread(self._cursor.close)
66-
67-
def __getattr__(self, name):
68-
"""Delegate non-async attributes to the wrapped cursor."""
69-
return getattr(self._cursor, name)
70-
71-
7240
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
7341
"""Wraps sync AwsConnection with async cursor support."""
7442

@@ -96,6 +64,10 @@ async def set_autocommit(self, value: bool):
9664
"""Set autocommit mode."""
9765
return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value)
9866

67+
async def close(self):
68+
"""Close the connection asynchronously."""
69+
return await asyncio.to_thread(self._wrapped_connection.close)
70+
9971
def __getattr__(self, name):
10072
"""Delegate all other attributes/methods to the wrapped connection."""
10173
return getattr(self._wrapped_connection, name)
@@ -105,3 +77,33 @@ def __del__(self):
10577
if hasattr(self, '_wrapped_connection'):
10678
# Let the wrapped connection handle its own cleanup
10779
pass
80+
81+
class AwsCursorAsyncWrapper:
82+
"""Wraps sync AwsCursor cursor with async support."""
83+
84+
def __init__(self, sync_cursor):
85+
self._cursor = sync_cursor
86+
87+
async def execute(self, query, params=None):
88+
"""Execute a query asynchronously."""
89+
return await asyncio.to_thread(self._cursor.execute, query, params)
90+
91+
async def executemany(self, query, params_list):
92+
"""Execute multiple queries asynchronously."""
93+
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
94+
95+
async def fetchall(self):
96+
"""Fetch all results asynchronously."""
97+
return await asyncio.to_thread(self._cursor.fetchall)
98+
99+
async def fetchone(self):
100+
"""Fetch one result asynchronously."""
101+
return await asyncio.to_thread(self._cursor.fetchone)
102+
103+
async def close(self):
104+
"""Close cursor asynchronously."""
105+
return await asyncio.to_thread(self._cursor.close)
106+
107+
def __getattr__(self, name):
108+
"""Delegate non-async attributes to the wrapped cursor."""
109+
return getattr(self._cursor, name)

aws_advanced_python_wrapper/tortoise/backends/base/client.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Any, Callable, Dict, Generic, cast
1818

19+
import asyncio
1920
import mysql.connector
2021
from tortoise.backends.base.client import (BaseDBAsyncClient, T_conn,
2122
TransactionalDBClient,
@@ -110,3 +111,85 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
110111
finally:
111112
await AwsWrapperAsyncConnector.close_aws_wrapper(self.client._connection)
112113
connections.reset(self.token)
114+
115+
class TortoiseAwsClientPooledConnectionWrapper(Generic[T_conn]):
116+
"""Manages acquiring from and releasing connections to a pool."""
117+
118+
__slots__ = ("client", "connection", "_pool_init_lock", "with_db")
119+
120+
def __init__(
121+
self,
122+
client: BaseDBAsyncClient,
123+
pool_init_lock: asyncio.Lock,
124+
with_db: bool = True
125+
) -> None:
126+
self.client = client
127+
self.connection: T_conn | None = None
128+
self._pool_init_lock = pool_init_lock
129+
self.with_db = with_db
130+
131+
async def ensure_connection(self) -> None:
132+
"""Ensure the connection pool is initialized."""
133+
if not self.client._pool:
134+
async with self._pool_init_lock:
135+
if not self.client._pool:
136+
await self.client.create_connection(with_db=self.with_db)
137+
138+
async def __aenter__(self) -> T_conn:
139+
"""Acquire connection from pool."""
140+
await self.ensure_connection()
141+
self.connection = await self.client._pool.acquire()
142+
return self.connection
143+
144+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
145+
"""Close connection and release back to pool."""
146+
if self.connection:
147+
await self.connection.release()
148+
149+
class TortoiseAwsClientPooledTransactionContext(TransactionContext):
150+
"""Transaction context that uses a pool to acquire connections."""
151+
152+
__slots__ = ("client", "connection_name", "token", "_pool_init_lock", "connection")
153+
154+
def __init__(self, client: TransactionalDBClient, pool_init_lock: asyncio.Lock) -> None:
155+
self.client = client
156+
self.connection_name = client.connection_name
157+
self._pool_init_lock = pool_init_lock
158+
self.connection = None
159+
160+
async def ensure_connection(self) -> None:
161+
"""Ensure the connection pool is initialized."""
162+
if not self.client._parent._pool:
163+
# a safeguard against multiple concurrent tasks trying to initialize the pool
164+
async with self._pool_init_lock:
165+
if not self.client._parent._pool:
166+
await self.client._parent.create_connection(with_db=True)
167+
168+
async def __aenter__(self) -> TransactionalDBClient:
169+
"""Enter transaction context."""
170+
await self.ensure_connection()
171+
172+
# Set the context variable so the current task sees a TransactionWrapper connection
173+
self.token = connections.set(self.connection_name, self.client)
174+
175+
# Create connection and begin transaction
176+
self.connection = await self.client._parent._pool.acquire()
177+
self.client._connection = self.connection
178+
await self.client.begin()
179+
return self.client
180+
181+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
182+
"""Exit transaction context with proper cleanup."""
183+
try:
184+
if not self.client._finalized:
185+
if exc_type:
186+
# Can't rollback a transaction that already failed
187+
if exc_type is not TransactionManagementError:
188+
await self.client.rollback()
189+
else:
190+
await self.client.commit()
191+
finally:
192+
if self.client._connection:
193+
await self.client._connection.release()
194+
# self.client._connection = None
195+
connections.reset(self.token)

aws_advanced_python_wrapper/tortoise/backends/mysql/client.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@
3131

3232
from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError
3333
from aws_advanced_python_wrapper.tortoise.async_support.async_wrapper import (
34-
AwsConnectionAsyncWrapper)
34+
AwsConnectionAsyncWrapper, AwsWrapperAsyncConnector)
3535
from aws_advanced_python_wrapper.tortoise.backends.base.client import (
3636
AwsBaseDBAsyncClient, AwsTransactionalDBClient,
37-
TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext)
37+
TortoiseAwsClientPooledConnectionWrapper, TortoiseAwsClientPooledTransactionContext)
3838
from aws_advanced_python_wrapper.tortoise.backends.mysql.executor import \
3939
AwsMySQLExecutor
4040
from aws_advanced_python_wrapper.tortoise.backends.mysql.schema_generator import \
4141
AwsMySQLSchemaGenerator
42+
from aws_advanced_python_wrapper.async_connection_pool import AsyncConnectionPool, PoolConfig
43+
from dataclasses import fields
4244
from aws_advanced_python_wrapper.utils.log import Logger
4345

4446
logger = Logger(__name__)
@@ -125,6 +127,21 @@ def __init__(
125127
# Initialize state
126128
self._template: Dict[str, Any] = {}
127129
self._connection = None
130+
self._pool_init_lock: asyncio.Lock = asyncio.Lock()
131+
self._pool: Optional[AsyncConnectionPool] = None
132+
133+
# Pool configuration
134+
default_pool_config = {field.name: field.default for field in fields(PoolConfig)}
135+
self._pool_config = PoolConfig(
136+
min_size = self.extra.pop("min_size", default_pool_config["min_size"]),
137+
max_size = self.extra.pop("max_size", default_pool_config["max_size"]),
138+
timeout = self.extra.pop("pool_timeout", default_pool_config["timeout"]),
139+
max_lifetime = self.extra.pop("pool_lifetime", default_pool_config["max_lifetime"]),
140+
max_idle_time = self.extra.pop("pool_max_idle_time", default_pool_config["max_idle_time"]),
141+
health_check_interval = self.extra.pop("pool_health_check_interval", default_pool_config["health_check_interval"]),
142+
pre_ping = self.extra.pop("pre_ping", default_pool_config["pre_ping"])
143+
)
144+
128145

129146
def _init_connection_templates(self) -> None:
130147
"""Initialize connection templates for with/without database."""
@@ -140,6 +157,29 @@ def _init_connection_templates(self) -> None:
140157
self._template_with_db = {**base_template, "database": self.database}
141158
self._template_no_db = {**base_template, "database": None}
142159

160+
async def _init_pool(self) -> None:
161+
"""Initialize the connection pool."""
162+
if self._pool is not None:
163+
return
164+
165+
async def create_connection():
166+
return await AwsWrapperAsyncConnector.connect_with_aws_wrapper(mysql.connector.Connect, **self._template)
167+
168+
async def health_check(conn):
169+
is_closed = await asyncio.to_thread(lambda: conn._wrapped_connection.is_closed)
170+
if is_closed:
171+
raise Exception("Connection is closed")
172+
else:
173+
print("NOT CLOSED!")
174+
175+
self._pool: AsyncConnectionPool = AsyncConnectionPool(
176+
creator=create_connection,
177+
# closer=close_connection,
178+
health_check=health_check,
179+
config=self._pool_config
180+
)
181+
await self._pool.initialize()
182+
143183
# Connection Management
144184
async def create_connection(self, with_db: bool) -> None:
145185
"""Initialize connection pool and configure database settings."""
@@ -154,18 +194,24 @@ async def create_connection(self, with_db: bool) -> None:
154194
# Set template based on database requirement
155195
self._template = self._template_with_db if with_db else self._template_no_db
156196

197+
await self._init_pool()
198+
print("Pool is initialized")
199+
print(self._pool.get_stats())
200+
157201
async def close(self) -> None:
158202
"""Close connections - AWS wrapper handles cleanup internally."""
159-
pass
203+
if hasattr(self, '_pool') and self._pool:
204+
await self._pool.close()
205+
self._pool = None
160206

161207
def acquire_connection(self):
162208
"""Acquire a connection from the pool."""
163209
return self._acquire_connection(with_db=True)
164210

165-
def _acquire_connection(self, with_db: bool) -> TortoiseAwsClientConnectionWrapper:
211+
def _acquire_connection(self, with_db: bool) -> TortoiseAwsClientPooledConnectionWrapper:
166212
"""Create connection wrapper for specified database mode."""
167-
return TortoiseAwsClientConnectionWrapper(
168-
self, mysql.connector.Connect, with_db=with_db
213+
return TortoiseAwsClientPooledConnectionWrapper(
214+
self, pool_init_lock=self._pool_init_lock, with_db=with_db
169215
)
170216

171217
# Database Operations
@@ -253,7 +299,7 @@ async def _execute_script(self, query: str, with_db: bool) -> None:
253299
# Transaction Support
254300
def _in_transaction(self) -> TransactionContext:
255301
"""Create a new transaction context."""
256-
return TortoiseAwsClientTransactionContext(TransactionWrapper(self))
302+
return TortoiseAwsClientPooledTransactionContext(TransactionWrapper(self), self._pool_init_lock)
257303

258304

259305
class TransactionWrapper(AwsMySQLClient, AwsTransactionalDBClient):
@@ -262,6 +308,7 @@ class TransactionWrapper(AwsMySQLClient, AwsTransactionalDBClient):
262308
def __init__(self, connection: AwsMySQLClient) -> None:
263309
self.connection_name = connection.connection_name
264310
self._connection: AwsConnectionAsyncWrapper = connection._connection
311+
265312
self._lock = asyncio.Lock()
266313
self._savepoint: Optional[str] = None
267314
self._finalized: bool = False

tests/integration/container/tortoise/test_tortoise_common.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar
4040
plugins=plugins,
4141
**kwargs,
4242
)
43-
4443
config = {
4544
"connections": {
4645
"default": db_url
@@ -53,17 +52,13 @@ async def setup_tortoise(conn_utils, plugins="aurora_connection_tracker", **kwar
5352
}
5453
}
5554

56-
from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \
57-
setup_tortoise_connection_provider
58-
setup_tortoise_connection_provider()
5955
await Tortoise.init(config=config)
60-
await Tortoise.generate_schemas()
6156

57+
await Tortoise.generate_schemas()
6258
await clear_test_models()
6359

6460
yield
6561

66-
await clear_test_models()
6762
await reset_tortoise()
6863

6964

tests/integration/container/tortoise/test_tortoise_config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
# Import to register the aws-mysql backend
2020
import aws_advanced_python_wrapper.tortoise # noqa: F401
21-
from aws_advanced_python_wrapper.tortoise.sql_alchemy_tortoise_connection_provider import \
22-
setup_tortoise_connection_provider
2321
from tests.integration.container.tortoise.models.test_models import User
2422
from tests.integration.container.tortoise.test_tortoise_common import \
2523
reset_tortoise
@@ -67,7 +65,6 @@ async def setup_tortoise_dict_config(self, conn_utils):
6765
}
6866
}
6967

70-
setup_tortoise_connection_provider()
7168
await Tortoise.init(config=config)
7269
await Tortoise.generate_schemas()
7370
await self._clear_all_test_models()
@@ -121,7 +118,6 @@ async def setup_tortoise_multi_db(self, conn_utils):
121118
}
122119
}
123120

124-
setup_tortoise_connection_provider()
125121
await Tortoise.init(config=config)
126122

127123
# Create second database
@@ -169,7 +165,6 @@ async def setup_tortoise_with_router(self, conn_utils):
169165
"routers": ["tests.integration.container.tortoise.router.test_router.TestRouter"]
170166
}
171167

172-
setup_tortoise_connection_provider()
173168
await Tortoise.init(config=config)
174169
await Tortoise.generate_schemas()
175170
await self._clear_all_test_models()

0 commit comments

Comments
 (0)