Skip to content

Commit ed6384e

Browse files
committed
refactor: parametrize unit tests
1 parent c7253b5 commit ed6384e

File tree

4 files changed

+372
-525
lines changed

4 files changed

+372
-525
lines changed

aws_advanced_python_wrapper/read_write_splitting_plugin.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class ReadWriteSplittingConnectionManager(Plugin):
5050
"Connection.set_read_only",
5151
}
5252
_POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider"
53+
_CLOSE_METHOD = "Connection.close"
5354

5455
def __init__(
5556
self,
@@ -94,8 +95,7 @@ def connect(
9495
connect_func: Callable,
9596
) -> Connection:
9697
return self._connection_handler.get_verified_initial_connection(
97-
host_info, props, is_initial_connection, connect_func
98-
)
98+
host_info, is_initial_connection, lambda x: self._plugin_service.connect(x, props, self), connect_func)
9999

100100
def notify_connection_changed(
101101
self, changes: Set[ConnectionEvent]
@@ -140,13 +140,13 @@ def execute(
140140
if isinstance(ex, FailoverError):
141141
logger.debug(
142142
"ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand",
143-
method_name,
143+
method_name
144144
)
145145
self._close_idle_connections()
146146
else:
147147
logger.debug(
148148
"ReadWriteSplittingPlugin.ExceptionWhileExecutingCommand",
149-
method_name,
149+
method_name
150150
)
151151
raise ex
152152

@@ -184,13 +184,13 @@ def _set_reader_connection(
184184
)
185185

186186
def _initialize_writer_connection(self):
187-
conn, writer_host = self._connection_handler.open_new_writer_connection()
187+
conn, writer_host = self._connection_handler.open_new_writer_connection(lambda x: self._plugin_service.connect(x, self._properties, self))
188188

189189
if conn is None:
190190
self.log_and_raise_exception(
191191
"ReadWriteSplittingPlugin.FailedToConnectToWriter"
192192
)
193-
return
193+
return None
194194

195195
provider = self._conn_provider_manager.get_connection_provider(
196196
writer_host, self._properties
@@ -335,9 +335,7 @@ def _switch_to_reader_connection(self):
335335
self._reader_host_info.url,
336336
)
337337

338-
ReadWriteSplittingConnectionManager.close_connection(
339-
self._reader_connection
340-
)
338+
ReadWriteSplittingConnectionManager.close_connection(self._reader_connection, driver_dialect)
341339
self._reader_connection = None
342340
self._reader_host_info = None
343341
self._initialize_reader_connection()
@@ -356,7 +354,7 @@ def _initialize_reader_connection(self):
356354
)
357355
return
358356

359-
conn, reader_host = self._connection_handler.open_new_reader_connection()
357+
conn, reader_host = self._connection_handler.open_new_reader_connection(lambda x: self._plugin_service.connect(x, self._properties, self))
360358

361359
if conn is None or reader_host is None:
362360
self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable")
@@ -392,7 +390,7 @@ def _close_connection_if_idle(self, internal_conn: Optional[Connection]):
392390
if internal_conn != current_conn and self._is_connection_usable(
393391
internal_conn, driver_dialect
394392
):
395-
internal_conn.close()
393+
driver_dialect.execute(ReadWriteSplittingConnectionManager._CLOSE_METHOD, lambda: internal_conn.close())
396394
if internal_conn == self._writer_connection:
397395
self._writer_connection = None
398396
self._writer_host_info = None
@@ -420,10 +418,8 @@ def log_and_raise_exception(log_msg: str):
420418
raise ReadWriteSplittingError(Messages.get(log_msg))
421419

422420
@staticmethod
423-
def _is_connection_usable(
424-
conn: Optional[Connection], driver_dialect: Optional[DriverDialect]
425-
):
426-
if conn is None or driver_dialect is None:
421+
def _is_connection_usable(conn: Optional[Connection], driver_dialect: DriverDialect):
422+
if conn is None:
427423
return False
428424
try:
429425
return not driver_dialect.is_closed(conn)
@@ -432,10 +428,10 @@ def _is_connection_usable(
432428
return False
433429

434430
@staticmethod
435-
def close_connection(connection: Optional[Connection]):
436-
if connection is not None:
431+
def close_connection(conn: Optional[Connection], driver_dialect: DriverDialect):
432+
if conn is not None:
437433
try:
438-
connection.close()
434+
driver_dialect.execute(ReadWriteSplittingConnectionManager._CLOSE_METHOD, lambda: conn.close())
439435
except Exception:
440436
# Swallow exception
441437
return
@@ -456,21 +452,23 @@ def host_list_provider_service(self, new_value: int) -> None:
456452

457453
def open_new_writer_connection(
458454
self,
455+
plugin_service_connect_func: Callable[[HostInfo], Connection],
459456
) -> tuple[Optional[Connection], Optional[HostInfo]]:
460457
"""Open a writer connection."""
461458
...
462459

463460
def open_new_reader_connection(
464461
self,
462+
plugin_service_connect_func: Callable[[HostInfo], Connection],
465463
) -> tuple[Optional[Connection], Optional[HostInfo]]:
466464
"""Open a reader connection."""
467465
...
468466

469467
def get_verified_initial_connection(
470468
self,
471469
host_info: HostInfo,
472-
props: Properties,
473470
is_initial_connection: bool,
471+
plugin_service_connect_func: Callable[[HostInfo], Connection],
474472
connect_func: Callable,
475473
) -> Connection:
476474
"""Verify initial connection or return normal workflow."""
@@ -516,9 +514,8 @@ class TopologyBasedConnectionHandler(ConnectionHandler):
516514

517515
def __init__(self, plugin_service: PluginService, props: Properties):
518516
self._plugin_service: PluginService = plugin_service
519-
self._properties: Properties = props
520517
self._host_list_provider_service: Optional[HostListProviderService] = None
521-
strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(self._properties)
518+
strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props)
522519
if strategy is not None:
523520
self._reader_selector_strategy = strategy
524521
else:
@@ -539,17 +536,19 @@ def host_list_provider_service(self, new_value: HostListProviderService) -> None
539536

540537
def open_new_writer_connection(
541538
self,
539+
plugin_service_connect_func: Callable[[HostInfo], Connection],
542540
) -> tuple[Optional[Connection], Optional[HostInfo]]:
543541
writer_host = self._get_writer()
544542
if writer_host is None:
545543
return None, None
546544

547-
conn = self._plugin_service.connect(writer_host, self._properties, None)
545+
conn = plugin_service_connect_func(writer_host)
548546

549547
return conn, writer_host
550548

551549
def open_new_reader_connection(
552550
self,
551+
plugin_service_connect_func: Callable[[HostInfo], Connection],
553552
) -> tuple[Optional[Connection], Optional[HostInfo]]:
554553
conn: Optional[Connection] = None
555554
reader_host: Optional[HostInfo] = None
@@ -561,7 +560,7 @@ def open_new_reader_connection(
561560
)
562561
if host is not None:
563562
try:
564-
conn = self._plugin_service.connect(host, self._properties, None)
563+
conn = plugin_service_connect_func(host)
565564
reader_host = host
566565
break
567566
except Exception:
@@ -574,8 +573,8 @@ def open_new_reader_connection(
574573
def get_verified_initial_connection(
575574
self,
576575
host_info: HostInfo,
577-
props: Properties,
578576
is_initial_connection: bool,
577+
plugin_service_connect_func: Callable[[HostInfo], Connection],
579578
connect_func: Callable,
580579
) -> Connection:
581580
if not self._plugin_service.accepts_strategy(
@@ -670,12 +669,9 @@ def _get_writer(self) -> Optional[HostInfo]:
670669

671670

672671
class ReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager):
673-
def __init__(self, plugin_service, props: Properties):
672+
def __init__(self, plugin_service: PluginService, props: Properties):
674673
# The read/write splitting plugin handles connections based on topology.
675-
connection_handler = TopologyBasedConnectionHandler(
676-
plugin_service,
677-
props,
678-
)
674+
connection_handler = TopologyBasedConnectionHandler(plugin_service, props)
679675

680676
super().__init__(plugin_service, props, connection_handler)
681677

aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py

Lines changed: 26 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from time import perf_counter_ns, sleep
18-
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Type
18+
from typing import TYPE_CHECKING, Callable, Optional, Type, TypeVar
1919

2020
from aws_advanced_python_wrapper.host_availability import HostAvailability
2121
from aws_advanced_python_wrapper.read_write_splitting_plugin import (
@@ -67,15 +67,10 @@ def __init__(self, plugin_service: PluginService, props: Properties):
6767
)
6868

6969
self._plugin_service: PluginService = plugin_service
70-
self._properties: Properties = props
7170
self._rds_utils: RdsUtils = RdsUtils()
7271
self._host_list_provider_service: Optional[HostListProviderService] = None
73-
self._write_endpoint_host_info: HostInfo = self._create_host_info(
74-
self._write_endpoint, HostRole.WRITER
75-
)
76-
self._read_endpoint_host_info: HostInfo = self._create_host_info(
77-
self._read_endpoint, HostRole.READER
78-
)
72+
self._write_endpoint_host_info: HostInfo = self._create_host_info(self._write_endpoint, HostRole.WRITER)
73+
self._read_endpoint_host_info: HostInfo = self._create_host_info(self._read_endpoint, HostRole.READER)
7974

8075
@property
8176
def host_list_provider_service(self) -> Optional[HostListProviderService]:
@@ -87,39 +82,29 @@ def host_list_provider_service(self, new_value: HostListProviderService) -> None
8782

8883
def open_new_writer_connection(
8984
self,
85+
plugin_service_connect_func: Callable[[HostInfo], Connection],
9086
) -> tuple[Optional[Connection], Optional[HostInfo]]:
91-
conn: Optional[Connection] = None
9287
if self._verify_new_connections:
93-
conn = self._get_verified_connection(
94-
self._properties, self._write_endpoint_host_info, HostRole.WRITER
95-
)
96-
else:
97-
conn = self._plugin_service.connect(
98-
self._write_endpoint_host_info, self._properties, None
99-
)
88+
return self._get_verified_connection(self._write_endpoint_host_info, HostRole.WRITER, plugin_service_connect_func), \
89+
self._write_endpoint_host_info
10090

101-
return conn, self._write_endpoint_host_info
91+
return plugin_service_connect_func(self._write_endpoint_host_info), self._write_endpoint_host_info
10292

10393
def open_new_reader_connection(
10494
self,
95+
plugin_service_connect_func: Callable[[HostInfo], Connection],
10596
) -> tuple[Optional[Connection], Optional[HostInfo]]:
106-
conn: Optional[Connection] = None
10797
if self._verify_new_connections:
108-
conn = self._get_verified_connection(
109-
self._properties, self._read_endpoint_host_info, HostRole.READER
110-
)
111-
else:
112-
conn = self._plugin_service.connect(
113-
self._read_endpoint_host_info, self._properties, None
114-
)
98+
return self._get_verified_connection(self._read_endpoint_host_info, HostRole.READER, plugin_service_connect_func), \
99+
self._read_endpoint_host_info
115100

116-
return conn, self._read_endpoint_host_info
101+
return plugin_service_connect_func(self._read_endpoint_host_info), self._read_endpoint_host_info
117102

118103
def get_verified_initial_connection(
119104
self,
120105
host_info: HostInfo,
121-
props: Properties,
122106
is_initial_connection: bool,
107+
plugin_service_connect_func: Callable[[HostInfo], Connection],
123108
connect_func: Callable,
124109
) -> Connection:
125110
if not is_initial_connection or not self._verify_new_connections:
@@ -133,34 +118,30 @@ def get_verified_initial_connection(
133118
url_type == RdsUrlType.RDS_WRITER_CLUSTER
134119
or self._verify_opened_connection_type == HostRole.WRITER
135120
):
136-
conn = self._get_verified_connection(
137-
props, host_info, HostRole.WRITER, connect_func
138-
)
121+
conn = self._get_verified_connection(host_info, HostRole.WRITER, plugin_service_connect_func, connect_func)
139122
elif (
140123
url_type == RdsUrlType.RDS_READER_CLUSTER
141124
or self._verify_opened_connection_type == HostRole.READER
142125
):
143-
conn = self._get_verified_connection(
144-
props, host_info, HostRole.READER, connect_func
145-
)
126+
conn = self._get_verified_connection(host_info, HostRole.READER, plugin_service_connect_func, connect_func)
146127

147128
if conn is None:
148129
conn = connect_func()
149130

150-
self._set_initial_connection_host_info(conn, host_info)
131+
self._set_initial_connection_host_info(host_info)
151132
return conn
152133

153-
def _set_initial_connection_host_info(self, conn: Connection, host_info: HostInfo):
134+
def _set_initial_connection_host_info(self, host_info: HostInfo):
154135
if self._host_list_provider_service is None:
155136
return
156137

157138
self._host_list_provider_service.initial_connection_host_info = host_info
158139

159140
def _get_verified_connection(
160141
self,
161-
props: Properties,
162142
host_info: HostInfo,
163143
role: HostRole,
144+
plugin_service_connect_func: Callable[[HostInfo], Connection],
164145
connect_func: Optional[Callable] = None,
165146
) -> Optional[Connection]:
166147
end_time_nano = perf_counter_ns() + (self._connect_retry_timeout_ms * 1000000)
@@ -174,9 +155,7 @@ def _get_verified_connection(
174155
if connect_func is not None:
175156
candidate_conn = connect_func()
176157
elif host_info is not None:
177-
candidate_conn = self._plugin_service.connect(
178-
host_info, props, None
179-
)
158+
candidate_conn = plugin_service_connect_func(host_info)
180159
else:
181160
return None
182161

@@ -187,14 +166,14 @@ def _get_verified_connection(
187166
actual_role = self._plugin_service.get_host_role(candidate_conn)
188167

189168
if actual_role != role:
190-
ReadWriteSplittingConnectionManager.close_connection(candidate_conn)
169+
ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect)
191170
self._delay()
192171
continue
193172

194173
return candidate_conn
195174

196175
except Exception:
197-
ReadWriteSplittingConnectionManager.close_connection(candidate_conn)
176+
ReadWriteSplittingConnectionManager.close_connection(candidate_conn, self._plugin_service.driver_dialect)
198177
self._delay()
199178

200179
return None
@@ -249,28 +228,18 @@ def is_reader_host(self, current_host: HostInfo) -> bool:
249228
or current_host.url.casefold() == self._read_endpoint.casefold()
250229
)
251230

252-
def _create_host_info(self, endpoint, role: HostRole) -> HostInfo:
231+
def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo:
253232
endpoint = endpoint.strip()
254233
host = endpoint
255-
port = self._plugin_service.database_dialect.default_port
234+
port = self._plugin_service.database_dialect.default_port if not self._plugin_service.current_host_info.is_port_specified() \
235+
else self._plugin_service.current_host_info.port
256236
colon_index = endpoint.rfind(":")
257237

258238
if colon_index != -1:
239+
host = endpoint[:colon_index]
259240
port_str = endpoint[colon_index + 1:]
260241
if port_str.isdigit():
261-
host = endpoint[:colon_index]
262242
port = int(port_str)
263-
else:
264-
if (
265-
self._host_list_provider_service is not None
266-
and self._host_list_provider_service.initial_connection_host_info
267-
is not None
268-
and self._host_list_provider_service.initial_connection_host_info.port
269-
!= HostInfo.NO_PORT
270-
):
271-
port = (
272-
self._host_list_provider_service.initial_connection_host_info.port
273-
)
274243

275244
return HostInfo(
276245
host=host, port=port, role=role, availability=HostAvailability.AVAILABLE
@@ -322,12 +291,9 @@ def _parse_connection_type(phase_str: Optional[str]) -> HostRole:
322291

323292

324293
class SimpleReadWriteSplittingPlugin(ReadWriteSplittingConnectionManager):
325-
def __init__(self, plugin_service, props: Properties):
294+
def __init__(self, plugin_service: PluginService, props: Properties):
326295
# The simple read/write splitting plugin handles connections based on configuration parameter endpoints.
327-
connection_handler = EndpointBasedConnectionHandler(
328-
plugin_service,
329-
props,
330-
)
296+
connection_handler = EndpointBasedConnectionHandler(plugin_service, props)
331297

332298
super().__init__(plugin_service, props, connection_handler)
333299

0 commit comments

Comments
 (0)