Skip to content

Commit a67c200

Browse files
committed
Added custom pool
1 parent b83c89f commit a67c200

File tree

13 files changed

+677
-583
lines changed

13 files changed

+677
-583
lines changed

aws_advanced_python_wrapper/async_connection_pool.py

Lines changed: 503 additions & 0 deletions
Large diffs are not rendered by default.

aws_advanced_python_wrapper/custom_endpoint_plugin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def _run(self):
169169
len(endpoints),
170170
endpoint_hostnames)
171171

172-
sleep(self._refresh_rate_ns / 1_000_000_000)
172+
if self._stop_event.wait(self._refresh_rate_ns / 1_000_000_000):
173+
break
173174
continue
174175

175176
endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0])
@@ -178,7 +179,8 @@ def _run(self):
178179
if cached_info is not None and cached_info == endpoint_info:
179180
elapsed_time = perf_counter_ns() - start_ns
180181
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
181-
sleep(sleep_duration / 1_000_000_000)
182+
if self._stop_event.wait(sleep_duration / 1_000_000_000):
183+
break
182184
continue
183185

184186
logger.debug(
@@ -196,7 +198,8 @@ def _run(self):
196198

197199
elapsed_time = perf_counter_ns() - start_ns
198200
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
199-
sleep(sleep_duration / 1_000_000_000)
201+
if self._stop_event.wait(sleep_duration / 1_000_000_000):
202+
break
200203
continue
201204
except InterruptedError as e:
202205
raise e
@@ -219,7 +222,6 @@ def close(self):
219222
CustomEndpointMonitor._custom_endpoint_info_cache.remove(self._custom_endpoint_host_info.host)
220223
self._stop_event.set()
221224

222-
223225
class CustomEndpointPlugin(Plugin):
224226
"""
225227
A plugin that analyzes custom endpoints for custom endpoint information and custom endpoint changes, such as adding

aws_advanced_python_wrapper/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,7 @@ class ConnectionReleasedError(AsyncConnectionPoolError):
6969

7070
class PoolSizeLimitError(AsyncConnectionPoolError):
7171
__module__ = "aws_advanced_python_wrapper"
72+
73+
74+
class PoolHealthCheckError(AsyncConnectionPoolError):
75+
__module__ = "aws_advanced_python_wrapper"

aws_advanced_python_wrapper/tortoise/__init__.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@
1414

1515
from tortoise.backends.base.config_generator import DB_LOOKUP
1616

17+
18+
def cast_to_bool(value):
19+
"""Generic function to cast various types to boolean."""
20+
if isinstance(value, bool):
21+
return value
22+
if isinstance(value, str):
23+
return value.lower() in ('true', '1', 'yes', 'on')
24+
return bool(value)
25+
26+
1727
# Register AWS MySQL backend
1828
DB_LOOKUP["aws-mysql"] = {
1929
"engine": "aws_advanced_python_wrapper.tortoise.backends.mysql",
@@ -29,9 +39,9 @@
2939
"minsize": int,
3040
"maxsize": int,
3141
"connect_timeout": int,
32-
"echo": bool,
33-
"use_unicode": bool,
34-
"ssl": bool,
35-
"use_pure": bool,
42+
"echo": cast_to_bool,
43+
"use_unicode": cast_to_bool,
44+
"ssl": cast_to_bool,
45+
"use_pure": cast_to_bool
3646
},
3747
}

aws_advanced_python_wrapper/tortoise/backends/base/client.py

Lines changed: 2 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -37,103 +37,26 @@ class AwsTransactionalDBClient(TransactionalDBClient):
3737
_parent: AwsBaseDBAsyncClient
3838
pass
3939

40-
41-
class TortoiseAwsClientConnectionWrapper(Generic[T_conn]):
42-
"""Manages acquiring from and releasing connections to a pool."""
43-
44-
__slots__ = ("client", "connection", "connect_func", "with_db")
45-
46-
def __init__(
47-
self,
48-
client: AwsBaseDBAsyncClient,
49-
connect_func: Callable,
50-
with_db: bool = True
51-
) -> None:
52-
self.connect_func = connect_func
53-
self.client = client
54-
self.connection: AwsConnectionAsyncWrapper | None = None
55-
self.with_db = with_db
56-
57-
async def ensure_connection(self) -> None:
58-
"""Ensure the connection pool is initialized."""
59-
await self.client.create_connection(with_db=self.with_db)
60-
61-
async def __aenter__(self) -> T_conn:
62-
"""Acquire connection from pool."""
63-
await self.ensure_connection()
64-
self.connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(self.connect_func, **self.client._template)
65-
return cast("T_conn", self.connection)
66-
67-
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
68-
"""Close connection and release back to pool."""
69-
if self.connection:
70-
await AwsWrapperAsyncConnector.close_aws_wrapper(self.connection)
71-
72-
73-
class TortoiseAwsClientTransactionContext(TransactionContext):
74-
"""Transaction context that uses a pool to acquire connections."""
75-
76-
__slots__ = ("client", "connection_name", "token")
77-
78-
def __init__(self, client: AwsTransactionalDBClient) -> None:
79-
self.client: AwsTransactionalDBClient = client
80-
self.connection_name = client.connection_name
81-
82-
async def ensure_connection(self) -> None:
83-
"""Ensure the connection pool is initialized."""
84-
await self.client._parent.create_connection(with_db=True)
85-
86-
async def __aenter__(self) -> TransactionalDBClient:
87-
"""Enter transaction context."""
88-
await self.ensure_connection()
89-
90-
# Set the context variable so the current task sees a TransactionWrapper connection
91-
self.token = connections.set(self.connection_name, self.client)
92-
93-
# Create connection and begin transaction
94-
self.client._connection = await AwsWrapperAsyncConnector.connect_with_aws_wrapper(
95-
mysql.connector.Connect,
96-
**self.client._parent._template
97-
)
98-
await self.client.begin()
99-
return self.client
100-
101-
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
102-
"""Exit transaction context with proper cleanup."""
103-
try:
104-
if not self.client._finalized:
105-
if exc_type:
106-
# Can't rollback a transaction that already failed
107-
if exc_type is not TransactionManagementError:
108-
await self.client.rollback()
109-
else:
110-
await self.client.commit()
111-
finally:
112-
await AwsWrapperAsyncConnector.close_aws_wrapper(self.client._connection)
113-
connections.reset(self.token)
114-
11540
class TortoiseAwsClientPooledConnectionWrapper(Generic[T_conn]):
11641
"""Manages acquiring from and releasing connections to a pool."""
11742

118-
__slots__ = ("client", "connection", "_pool_init_lock", "with_db")
43+
__slots__ = ("client", "connection", "_pool_init_lock",)
11944

12045
def __init__(
12146
self,
12247
client: BaseDBAsyncClient,
12348
pool_init_lock: asyncio.Lock,
124-
with_db: bool = True
12549
) -> None:
12650
self.client = client
12751
self.connection: T_conn | None = None
12852
self._pool_init_lock = pool_init_lock
129-
self.with_db = with_db
13053

13154
async def ensure_connection(self) -> None:
13255
"""Ensure the connection pool is initialized."""
13356
if not self.client._pool:
13457
async with self._pool_init_lock:
13558
if not self.client._pool:
136-
await self.client.create_connection(with_db=self.with_db)
59+
await self.client.create_connection(with_db=True)
13760

13861
async def __aenter__(self) -> T_conn:
13962
"""Acquire connection from pool."""

aws_advanced_python_wrapper/tortoise/backends/mysql/client.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ def __init__(
135135
self._pool_config = PoolConfig(
136136
min_size = self.extra.pop("min_size", default_pool_config["min_size"]),
137137
max_size = self.extra.pop("max_size", default_pool_config["max_size"]),
138-
timeout = self.extra.pop("pool_timeout", default_pool_config["timeout"]),
139-
max_lifetime = self.extra.pop("pool_lifetime", default_pool_config["max_lifetime"]),
140-
max_idle_time = self.extra.pop("pool_max_idle_time", default_pool_config["max_idle_time"]),
141-
health_check_interval = self.extra.pop("pool_health_check_interval", default_pool_config["health_check_interval"]),
138+
acquire_conn_timeout = self.extra.pop("acquire_conn_timeout", default_pool_config["acquire_conn_timeout"]),
139+
max_conn_lifetime = self.extra.pop("max_conn_lifetime", default_pool_config["max_conn_lifetime"]),
140+
max_conn_idle_time = self.extra.pop("max_conn_idle_time", default_pool_config["max_conn_idle_time"]),
141+
health_check_interval = self.extra.pop("health_check_interval", default_pool_config["health_check_interval"]),
142142
pre_ping = self.extra.pop("pre_ping", default_pool_config["pre_ping"])
143143
)
144144

@@ -165,17 +165,8 @@ async def _init_pool(self) -> None:
165165
async def create_connection():
166166
return await AwsWrapperAsyncConnector.connect_with_aws_wrapper(mysql.connector.Connect, **self._template)
167167

168-
async def health_check(conn):
169-
is_closed = await asyncio.to_thread(lambda: conn._wrapped_connection.is_closed)
170-
if is_closed:
171-
raise Exception("Connection is closed")
172-
else:
173-
print("NOT CLOSED!")
174-
175168
self._pool: AsyncConnectionPool = AsyncConnectionPool(
176169
creator=create_connection,
177-
# closer=close_connection,
178-
health_check=health_check,
179170
config=self._pool_config
180171
)
181172
await self._pool.initialize()
@@ -195,8 +186,13 @@ async def create_connection(self, with_db: bool) -> None:
195186
self._template = self._template_with_db if with_db else self._template_no_db
196187

197188
await self._init_pool()
198-
print("Pool is initialized")
199-
print(self._pool.get_stats())
189+
190+
def _disable_pool_for_testing(self) -> None:
191+
"""Disable pool initialization for unit testing."""
192+
self._pool = None
193+
async def _no_op():
194+
pass
195+
self._init_pool = _no_op
200196

201197
async def close(self) -> None:
202198
"""Close connections - AWS wrapper handles cleanup internally."""
@@ -206,25 +202,21 @@ async def close(self) -> None:
206202

207203
def acquire_connection(self):
208204
"""Acquire a connection from the pool."""
209-
return self._acquire_connection(with_db=True)
210-
211-
def _acquire_connection(self, with_db: bool) -> TortoiseAwsClientPooledConnectionWrapper:
212-
"""Create connection wrapper for specified database mode."""
213205
return TortoiseAwsClientPooledConnectionWrapper(
214-
self, pool_init_lock=self._pool_init_lock, with_db=with_db
206+
self, pool_init_lock=self._pool_init_lock
215207
)
216208

217209
# Database Operations
218210
async def db_create(self) -> None:
219211
"""Create the database."""
220212
await self.create_connection(with_db=False)
221-
await self._execute_script(f"CREATE DATABASE {self.database};", False)
213+
await self.execute_script(f"CREATE DATABASE {self.database};")
222214
await self.close()
223215

224216
async def db_delete(self) -> None:
225217
"""Delete the database."""
226218
await self.create_connection(with_db=False)
227-
await self._execute_script(f"DROP DATABASE {self.database};", False)
219+
await self.execute_script(f"DROP DATABASE {self.database};")
228220
await self.close()
229221

230222
# Query Execution Methods
@@ -279,14 +271,10 @@ async def execute_query_dict(self, query: str, values: Optional[List[Any]] = Non
279271
"""Execute a query and return only the results as dictionaries."""
280272
return (await self.execute_query(query, values))[1]
281273

274+
@translate_exceptions
282275
async def execute_script(self, query: str) -> None:
283276
"""Execute a script query."""
284-
await self._execute_script(query, True)
285-
286-
@translate_exceptions
287-
async def _execute_script(self, query: str, with_db: bool) -> None:
288-
"""Execute a multi-statement query by parsing and running statements sequentially."""
289-
async with self._acquire_connection(with_db) as connection:
277+
async with self.acquire_connection() as connection:
290278
logger.debug(f"Executing script: {query}")
291279
async with connection.cursor() as cursor:
292280
# Parse multi-statement queries since MySQL Connector doesn't handle them well

tests/integration/container/conftest.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,17 @@ def conn_utils():
6262

6363
def pytest_runtest_setup(item):
6464
test_name: Optional[str] = None
65+
full_test_name = item.nodeid # Full test path including class and method
66+
6567
if hasattr(item, "callspec"):
6668
current_driver = item.callspec.params.get("test_driver")
6769
TestEnvironment.get_current().set_current_driver(current_driver)
68-
test_name = item.callspec.id
70+
test_name = f"{item.name}[{item.callspec.id}]"
6971
else:
7072
TestEnvironment.get_current().set_current_driver(None)
71-
# Fallback to item.name if no callspec (for non-parameterized tests)
72-
test_name = getattr(item, 'name', None) or str(item)
73+
test_name = item.name
7374

74-
logger.info(f"Starting test preparation for: {test_name}")
75+
logger.info(f"Starting test preparation for: {test_name} (full: {full_test_name})")
7576

7677
segment: Optional[Segment] = None
7778
if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features():
@@ -82,6 +83,7 @@ def pytest_runtest_setup(item):
8283
.get_info().get_request().get_target_python_version().name)
8384
if test_name is not None:
8485
segment.put_annotation("test_name", test_name)
86+
segment.put_annotation("full_test_name", full_test_name)
8587

8688
info = TestEnvironment.get_current().get_info()
8789
request = info.get_request()

tests/integration/container/tortoise/test_tortoise_custom_endpoint.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def create_custom_endpoint(self, rds_utils):
7272
finally:
7373
try:
7474
rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id)
75+
self._wait_until_endpoint_deleted(rds_client)
7576
except ClientError as e:
7677
if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault':
7778
pass # Ignore if endpoint doesn't exist
@@ -108,18 +109,46 @@ def _wait_until_endpoint_available(self, rds_client):
108109

109110
if not available:
110111
pytest.fail(f"Timed out waiting for custom endpoint to become available: {self.endpoint_id}")
112+
113+
def _wait_until_endpoint_deleted(self, rds_client):
114+
"""Wait for the custom endpoint to be deleted."""
115+
end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes
116+
117+
while perf_counter_ns() < end_ns:
118+
try:
119+
rds_client.describe_db_cluster_endpoints(DBClusterEndpointIdentifier=self.endpoint_id)
120+
sleep(5) # Still exists, keep waiting
121+
except ClientError as e:
122+
if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault':
123+
return # Successfully deleted
124+
raise # Other error, re-raise
111125

112126
@pytest_asyncio.fixture
113-
async def setup_tortoise_custom_endpoint(self, conn_utils, create_custom_endpoint):
127+
async def setup_tortoise_custom_endpoint(self, conn_utils, create_custom_endpoint, request):
114128
"""Setup Tortoise with custom endpoint plugin."""
115-
async for result in setup_tortoise(conn_utils, plugins="custom_endpoint,aurora_connection_tracker", host=create_custom_endpoint):
129+
plugins, user = request.param
130+
user_value = getattr(conn_utils, user) if user != "default" else None
131+
132+
kwargs = {}
133+
if "fastest_response_strategy" in plugins:
134+
kwargs["reader_host_selector_strategy"] = "fastest_response"
135+
136+
async for result in setup_tortoise(conn_utils, plugins=plugins, host=create_custom_endpoint, user=user_value, **kwargs):
116137
yield result
117138

139+
@pytest.mark.parametrize("setup_tortoise_custom_endpoint", [
140+
("custom_endpoint,aurora_connection_tracker", "default"),
141+
("failover,iam,aurora_connection_tracker,custom_endpoint,fastest_response_strategy", "iam_user")
142+
], indirect=True)
118143
@pytest.mark.asyncio
119144
async def test_basic_read_operations(self, setup_tortoise_custom_endpoint):
120145
"""Test basic read operations with custom endpoint plugin."""
121146
await run_basic_read_operations("Custom Test", "custom")
122147

148+
@pytest.mark.parametrize("setup_tortoise_custom_endpoint", [
149+
("custom_endpoint,aurora_connection_tracker", "default"),
150+
("failover,iam,aurora_connection_tracker,custom_endpoint,fastest_response_strategy", "iam_user")
151+
], indirect=True)
123152
@pytest.mark.asyncio
124153
async def test_basic_write_operations(self, setup_tortoise_custom_endpoint):
125154
"""Test basic write operations with custom endpoint plugin."""

0 commit comments

Comments
 (0)