Skip to content

Commit e863f98

Browse files
committed
apply review comments
1 parent ed6384e commit e863f98

File tree

3 files changed

+41
-40
lines changed

3 files changed

+41
-40
lines changed

aws_advanced_python_wrapper/read_write_splitting_plugin.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def __init__(
5656
self,
5757
plugin_service: PluginService,
5858
props: Properties,
59-
connection_handler: ConnectionHandler,
59+
connection_handler: ReadWriteConnectionHandler,
6060
):
6161
self._plugin_service: PluginService = plugin_service
6262
self._properties: Properties = props
63-
self._connection_handler: ConnectionHandler = connection_handler
63+
self._connection_handler: ReadWriteConnectionHandler = connection_handler
6464
self._writer_connection: Optional[Connection] = None
6565
self._reader_connection: Optional[Connection] = None
6666
self._writer_host_info: Optional[HostInfo] = None
@@ -310,7 +310,7 @@ def _switch_to_reader_connection(self):
310310

311311
if (
312312
self._reader_connection is not None
313-
and not self._connection_handler.old_reader_can_be_used(
313+
and not self._connection_handler.can_host_be_used(
314314
self._reader_host_info
315315
)
316316
):
@@ -344,7 +344,7 @@ def _switch_to_reader_connection(self):
344344
self._close_connection_if_idle(self._writer_connection)
345345

346346
def _initialize_reader_connection(self):
347-
if self._connection_handler.need_connect_to_writer():
347+
if self._connection_handler.has_no_readers():
348348
if not self._is_connection_usable(
349349
self._writer_connection, self._plugin_service.driver_dialect
350350
):
@@ -437,7 +437,7 @@ def close_connection(conn: Optional[Connection], driver_dialect: DriverDialect):
437437
return
438438

439439

440-
class ConnectionHandler(Protocol):
440+
class ReadWriteConnectionHandler(Protocol):
441441
"""Protocol for handling writer/reader connection logic."""
442442

443443
@property
@@ -491,15 +491,15 @@ def is_writer_host(self, current_host: HostInfo) -> bool:
491491
...
492492

493493
def is_reader_host(self, current_host: HostInfo) -> bool:
494-
"""Return true if the current host fits the criteria of a writer host."""
494+
"""Return true if the current host fits the criteria of a reader host."""
495495
...
496496

497-
def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
498-
"""Return true if the current host can be used to switch connection to."""
497+
def can_host_be_used(self, host_info: HostInfo) -> bool:
498+
"""Returns true if connections can be switched to the given host"""
499499
...
500500

501-
def need_connect_to_writer(self) -> bool:
502-
"""Return true if switching to reader should instead connect to writer."""
501+
def has_no_readers(self) -> bool:
502+
"""Return true if there are no readers in the host list"""
503503
...
504504

505505
def refresh_and_store_host_list(
@@ -509,7 +509,7 @@ def refresh_and_store_host_list(
509509
...
510510

511511

512-
class TopologyBasedConnectionHandler(ConnectionHandler):
512+
class TopologyBasedConnectionHandler(ReadWriteConnectionHandler):
513513
"""Topology based implementation of connection handling logic."""
514514

515515
def __init__(self, plugin_service: PluginService, props: Properties):
@@ -615,11 +615,11 @@ def get_verified_initial_connection(
615615

616616
return current_conn
617617

618-
def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
618+
def can_host_be_used(self, host_info: HostInfo) -> bool:
619619
hostnames = [host_info.host for host_info in self._hosts]
620-
return reader_host_info is not None and reader_host_info.host in hostnames
620+
return host_info.host in hostnames
621621

622-
def need_connect_to_writer(self) -> bool:
622+
def has_no_readers(self) -> bool:
623623
if len(self._hosts) == 1:
624624
return self._get_writer() is not None
625625
return False

aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from aws_advanced_python_wrapper.host_availability import HostAvailability
2121
from aws_advanced_python_wrapper.read_write_splitting_plugin import (
22-
ConnectionHandler, ReadWriteSplittingConnectionManager)
22+
ReadWriteConnectionHandler, ReadWriteSplittingConnectionManager)
2323
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
2424
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
2525

@@ -37,14 +37,14 @@
3737
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
3838

3939

40-
class EndpointBasedConnectionHandler(ConnectionHandler):
40+
class EndpointBasedConnectionHandler(ReadWriteConnectionHandler):
4141
"""Endpoint based implementation of connection handling logic."""
4242

4343
def __init__(self, plugin_service: PluginService, props: Properties):
44-
self._read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
44+
read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
4545
WrapperProperties.SRW_READ_ENDPOINT, props, str, required=True
4646
)
47-
self._write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
47+
write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
4848
WrapperProperties.SRW_WRITE_ENDPOINT, props, str, required=True
4949
)
5050

@@ -60,17 +60,19 @@ def __init__(self, plugin_service: PluginService, props: Properties):
6060
WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS, props, int, lambda x: x > 0
6161
)
6262

63-
self._verify_opened_connection_type: Optional[HostRole] = (
64-
EndpointBasedConnectionHandler._parse_connection_type(
63+
self._verify_initial_connection_type: Optional[HostRole] = (
64+
EndpointBasedConnectionHandler._parse_role(
6565
WrapperProperties.SRW_VERIFY_INITIAL_CONNECTION_TYPE.get(props)
6666
)
6767
)
6868

6969
self._plugin_service: PluginService = plugin_service
7070
self._rds_utils: RdsUtils = RdsUtils()
7171
self._host_list_provider_service: Optional[HostListProviderService] = None
72-
self._write_endpoint_host_info: HostInfo = self._create_host_info(self._write_endpoint, HostRole.WRITER)
73-
self._read_endpoint_host_info: HostInfo = self._create_host_info(self._read_endpoint, HostRole.READER)
72+
self._write_endpoint_host_info: HostInfo = self._create_host_info(write_endpoint, HostRole.WRITER)
73+
self._read_endpoint_host_info: HostInfo = self._create_host_info(read_endpoint, HostRole.READER)
74+
self._write_endpoint = write_endpoint.casefold()
75+
self._read_endpoint = read_endpoint.casefold()
7476

7577
@property
7678
def host_list_provider_service(self) -> Optional[HostListProviderService]:
@@ -116,12 +118,12 @@ def get_verified_initial_connection(
116118

117119
if (
118120
url_type == RdsUrlType.RDS_WRITER_CLUSTER
119-
or self._verify_opened_connection_type == HostRole.WRITER
121+
or self._verify_initial_connection_type == HostRole.WRITER
120122
):
121123
conn = self._get_verified_connection(host_info, HostRole.WRITER, plugin_service_connect_func, connect_func)
122124
elif (
123125
url_type == RdsUrlType.RDS_READER_CLUSTER
124-
or self._verify_opened_connection_type == HostRole.READER
126+
or self._verify_initial_connection_type == HostRole.READER
125127
):
126128
conn = self._get_verified_connection(host_info, HostRole.READER, plugin_service_connect_func, connect_func)
127129

@@ -154,10 +156,8 @@ def _get_verified_connection(
154156
try:
155157
if connect_func is not None:
156158
candidate_conn = connect_func()
157-
elif host_info is not None:
158-
candidate_conn = plugin_service_connect_func(host_info)
159159
else:
160-
return None
160+
candidate_conn = plugin_service_connect_func(host_info)
161161

162162
if candidate_conn is None:
163163
self._delay()
@@ -178,12 +178,12 @@ def _get_verified_connection(
178178

179179
return None
180180

181-
def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
182-
# Assume that the old reader can always be used, no topology-based information to check.
181+
def can_host_be_used(self, host_info: HostInfo) -> bool:
182+
# Assume that the host can always be used, no topology-based information to check.
183183
return True
184184

185-
def need_connect_to_writer(self) -> bool:
186-
# SetReadOnly(true) will always connect to the read_endpoint, and not the writer.
185+
def has_no_readers(self) -> bool:
186+
# SetReadOnly(true) will always connect to the read_endpoint, regardless of number of readers.
187187
return False
188188

189189
def refresh_and_store_host_list(
@@ -218,14 +218,14 @@ def should_update_reader_with_current_conn(
218218

219219
def is_writer_host(self, current_host: HostInfo) -> bool:
220220
return (
221-
current_host.host.casefold() == self._write_endpoint.casefold()
222-
or current_host.url.casefold() == self._write_endpoint.casefold()
221+
current_host.host.casefold() == self._write_endpoint
222+
or current_host.url.casefold() == self._write_endpoint
223223
)
224224

225225
def is_reader_host(self, current_host: HostInfo) -> bool:
226226
return (
227-
current_host.host.casefold() == self._read_endpoint.casefold()
228-
or current_host.url.casefold() == self._read_endpoint.casefold()
227+
current_host.host.casefold() == self._read_endpoint
228+
or current_host.url.casefold() == self._read_endpoint
229229
)
230230

231231
def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo:
@@ -272,14 +272,14 @@ def _delay(self):
272272
sleep(self._connect_retry_interval_ms / 1000)
273273

274274
@staticmethod
275-
def _parse_connection_type(phase_str: Optional[str]) -> HostRole:
276-
if not phase_str:
275+
def _parse_role(role_str: Optional[str]) -> HostRole:
276+
if not role_str:
277277
return HostRole.UNKNOWN
278278

279-
phase_upper = phase_str.lower()
280-
if phase_upper == "reader":
279+
phase_lower = role_str.lower()
280+
if phase_lower == "reader":
281281
return HostRole.READER
282-
elif phase_upper == "writer":
282+
elif phase_lower == "writer":
283283
return HostRole.WRITER
284284
else:
285285
raise ValueError(

tests/integration/container/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def pytest_runtest_setup(item):
140140
CustomEndpointPlugin._monitors.clear()
141141
CustomEndpointMonitor._custom_endpoint_info_cache.clear()
142142
MonitoringThreadContainer.clean_up()
143+
ConnectionProviderManager.release_resources()
143144

144145
ConnectionProviderManager.reset_provider()
145146
DatabaseDialectManager.reset_custom_dialect()

0 commit comments

Comments
 (0)