Skip to content

Commit d2f7e62

Browse files
committed
refactor async wrappers
1 parent cb67cdb commit d2f7e62

File tree

4 files changed

+125
-90
lines changed

4 files changed

+125
-90
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
from contextlib import asynccontextmanager
19+
from typing import Callable
20+
21+
from aws_advanced_python_wrapper import AwsWrapperConnection
22+
23+
24+
class AwsWrapperAsyncConnector:
25+
"""Class for creating and closing AWS wrapper connections."""
26+
27+
@staticmethod
28+
async def connect_with_aws_wrapper(connect_func: Callable, **kwargs) -> AwsConnectionAsyncWrapper:
29+
"""Create an AWS wrapper connection with async cursor support."""
30+
connection = await asyncio.to_thread(
31+
AwsWrapperConnection.connect, connect_func, **kwargs
32+
)
33+
return AwsConnectionAsyncWrapper(connection)
34+
35+
@staticmethod
36+
async def close_aws_wrapper(connection: AwsWrapperConnection) -> None:
37+
"""Close an AWS wrapper connection asynchronously."""
38+
await asyncio.to_thread(connection.close)
39+
40+
41+
class AwsCursorAsyncWrapper:
42+
"""Wraps sync AwsCursor cursor with async support."""
43+
44+
def __init__(self, sync_cursor):
45+
self._cursor = sync_cursor
46+
47+
async def execute(self, query, params=None):
48+
"""Execute a query asynchronously."""
49+
return await asyncio.to_thread(self._cursor.execute, query, params)
50+
51+
async def executemany(self, query, params_list):
52+
"""Execute multiple queries asynchronously."""
53+
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
54+
55+
async def fetchall(self):
56+
"""Fetch all results asynchronously."""
57+
return await asyncio.to_thread(self._cursor.fetchall)
58+
59+
async def fetchone(self):
60+
"""Fetch one result asynchronously."""
61+
return await asyncio.to_thread(self._cursor.fetchone)
62+
63+
async def close(self):
64+
"""Close cursor asynchronously."""
65+
return await asyncio.to_thread(self._cursor.close)
66+
67+
def __getattr__(self, name):
68+
"""Delegate non-async attributes to the wrapped cursor."""
69+
return getattr(self._cursor, name)
70+
71+
72+
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
73+
"""Wraps sync AwsConnection with async cursor support."""
74+
75+
def __init__(self, connection: AwsWrapperConnection):
76+
self._wrapped_connection = connection
77+
78+
@asynccontextmanager
79+
async def cursor(self):
80+
"""Create an async cursor context manager."""
81+
cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor)
82+
try:
83+
yield AwsCursorAsyncWrapper(cursor_obj)
84+
finally:
85+
await asyncio.to_thread(cursor_obj.close)
86+
87+
async def rollback(self):
88+
"""Rollback the current transaction."""
89+
return await asyncio.to_thread(self._wrapped_connection.rollback)
90+
91+
async def commit(self):
92+
"""Commit the current transaction."""
93+
return await asyncio.to_thread(self._wrapped_connection.commit)
94+
95+
async def set_autocommit(self, value: bool):
96+
"""Set autocommit mode."""
97+
return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value)
98+
99+
def __getattr__(self, name):
100+
"""Delegate all other attributes/methods to the wrapped connection."""
101+
return getattr(self._wrapped_connection, name)
102+
103+
def __del__(self):
104+
"""Delegate cleanup to wrapped connection."""
105+
if hasattr(self, '_wrapped_connection'):
106+
# Let the wrapped connection handle its own cleanup
107+
pass

aws_advanced_python_wrapper/tortoise/backends/base/client.py

Lines changed: 2 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
import asyncio
18-
from contextlib import asynccontextmanager
1917
from typing import Any, Callable, Dict, Generic, cast
2018

2119
import mysql.connector
@@ -25,93 +23,8 @@
2523
from tortoise.connection import connections
2624
from tortoise.exceptions import TransactionManagementError
2725

28-
from aws_advanced_python_wrapper import AwsWrapperConnection
29-
30-
31-
class AwsWrapperAsyncConnector:
32-
"""Class for creating and closing AWS wrapper connections."""
33-
34-
@staticmethod
35-
async def connect_with_aws_wrapper(connect_func: Callable, **kwargs) -> AwsConnectionAsyncWrapper:
36-
"""Create an AWS wrapper connection with async cursor support."""
37-
connection = await asyncio.to_thread(
38-
AwsWrapperConnection.connect, connect_func, **kwargs
39-
)
40-
return AwsConnectionAsyncWrapper(connection)
41-
42-
@staticmethod
43-
async def close_aws_wrapper(connection: AwsWrapperConnection) -> None:
44-
"""Close an AWS wrapper connection asynchronously."""
45-
await asyncio.to_thread(connection.close)
46-
47-
48-
class AwsCursorAsyncWrapper:
49-
"""Wraps sync AwsCursor cursor with async support."""
50-
51-
def __init__(self, sync_cursor):
52-
self._cursor = sync_cursor
53-
54-
async def execute(self, query, params=None):
55-
"""Execute a query asynchronously."""
56-
return await asyncio.to_thread(self._cursor.execute, query, params)
57-
58-
async def executemany(self, query, params_list):
59-
"""Execute multiple queries asynchronously."""
60-
return await asyncio.to_thread(self._cursor.executemany, query, params_list)
61-
62-
async def fetchall(self):
63-
"""Fetch all results asynchronously."""
64-
return await asyncio.to_thread(self._cursor.fetchall)
65-
66-
async def fetchone(self):
67-
"""Fetch one result asynchronously."""
68-
return await asyncio.to_thread(self._cursor.fetchone)
69-
70-
async def close(self):
71-
"""Close cursor asynchronously."""
72-
return await asyncio.to_thread(self._cursor.close)
73-
74-
def __getattr__(self, name):
75-
"""Delegate non-async attributes to the wrapped cursor."""
76-
return getattr(self._cursor, name)
77-
78-
79-
class AwsConnectionAsyncWrapper(AwsWrapperConnection):
80-
"""Wraps sync AwsConnection with async cursor support."""
81-
82-
def __init__(self, connection: AwsWrapperConnection):
83-
self._wrapped_connection = connection
84-
85-
@asynccontextmanager
86-
async def cursor(self):
87-
"""Create an async cursor context manager."""
88-
cursor_obj = await asyncio.to_thread(self._wrapped_connection.cursor)
89-
try:
90-
yield AwsCursorAsyncWrapper(cursor_obj)
91-
finally:
92-
await asyncio.to_thread(cursor_obj.close)
93-
94-
async def rollback(self):
95-
"""Rollback the current transaction."""
96-
return await asyncio.to_thread(self._wrapped_connection.rollback)
97-
98-
async def commit(self):
99-
"""Commit the current transaction."""
100-
return await asyncio.to_thread(self._wrapped_connection.commit)
101-
102-
async def set_autocommit(self, value: bool):
103-
"""Set autocommit mode."""
104-
return await asyncio.to_thread(setattr, self._wrapped_connection, 'autocommit', value)
105-
106-
def __getattr__(self, name):
107-
"""Delegate all other attributes/methods to the wrapped connection."""
108-
return getattr(self._wrapped_connection, name)
109-
110-
def __del__(self):
111-
"""Delegate cleanup to wrapped connection."""
112-
if hasattr(self, '_wrapped_connection'):
113-
# Let the wrapped connection handle its own cleanup
114-
pass
26+
from aws_advanced_python_wrapper.tortoise.async_support.async_wrapper import (
27+
AwsConnectionAsyncWrapper, AwsWrapperAsyncConnector)
11528

11629

11730
class AwsBaseDBAsyncClient(BaseDBAsyncClient):

aws_advanced_python_wrapper/tortoise/backends/mysql/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
OperationalError, TransactionManagementError)
3131

3232
from aws_advanced_python_wrapper.errors import AwsWrapperError, FailoverError
33+
from aws_advanced_python_wrapper.tortoise.async_support.async_wrapper import (
34+
AwsConnectionAsyncWrapper)
3335
from aws_advanced_python_wrapper.tortoise.backends.base.client import (
34-
AwsBaseDBAsyncClient, AwsConnectionAsyncWrapper, AwsTransactionalDBClient,
36+
AwsBaseDBAsyncClient, AwsTransactionalDBClient,
3537
TortoiseAwsClientConnectionWrapper, TortoiseAwsClientTransactionContext)
3638
from aws_advanced_python_wrapper.tortoise.backends.mysql.executor import \
3739
AwsMySQLExecutor

0 commit comments

Comments
 (0)