1919
2020from aws_advanced_python_wrapper .host_availability import HostAvailability
2121from aws_advanced_python_wrapper .read_write_splitting_plugin import (
22- ConnectionHandler , ReadWriteSplittingConnectionManager )
22+ ReadWriteConnectionHandler , ReadWriteSplittingConnectionManager )
2323from aws_advanced_python_wrapper .utils .rds_url_type import RdsUrlType
2424from aws_advanced_python_wrapper .utils .rdsutils import RdsUtils
2525
3737from aws_advanced_python_wrapper .utils .properties import WrapperProperties
3838
3939
40- class EndpointBasedConnectionHandler (ConnectionHandler ):
40+ class EndpointBasedConnectionHandler (ReadWriteConnectionHandler ):
4141 """Endpoint based implementation of connection handling logic."""
4242
4343 def __init__ (self , plugin_service : PluginService , props : Properties ):
44- self . _read_endpoint : str = EndpointBasedConnectionHandler ._verify_parameter (
44+ read_endpoint : str = EndpointBasedConnectionHandler ._verify_parameter (
4545 WrapperProperties .SRW_READ_ENDPOINT , props , str , required = True
4646 )
47- self . _write_endpoint : str = EndpointBasedConnectionHandler ._verify_parameter (
47+ write_endpoint : str = EndpointBasedConnectionHandler ._verify_parameter (
4848 WrapperProperties .SRW_WRITE_ENDPOINT , props , str , required = True
4949 )
5050
@@ -60,17 +60,19 @@ def __init__(self, plugin_service: PluginService, props: Properties):
6060 WrapperProperties .SRW_CONNECT_RETRY_INTERVAL_MS , props , int , lambda x : x > 0
6161 )
6262
63- self ._verify_opened_connection_type : Optional [HostRole ] = (
64- EndpointBasedConnectionHandler ._parse_connection_type (
63+ self ._verify_initial_connection_type : Optional [HostRole ] = (
64+ EndpointBasedConnectionHandler ._parse_role (
6565 WrapperProperties .SRW_VERIFY_INITIAL_CONNECTION_TYPE .get (props )
6666 )
6767 )
6868
6969 self ._plugin_service : PluginService = plugin_service
7070 self ._rds_utils : RdsUtils = RdsUtils ()
7171 self ._host_list_provider_service : Optional [HostListProviderService ] = None
72- self ._write_endpoint_host_info : HostInfo = self ._create_host_info (self ._write_endpoint , HostRole .WRITER )
73- self ._read_endpoint_host_info : HostInfo = self ._create_host_info (self ._read_endpoint , HostRole .READER )
72+ self ._write_endpoint_host_info : HostInfo = self ._create_host_info (write_endpoint , HostRole .WRITER )
73+ self ._read_endpoint_host_info : HostInfo = self ._create_host_info (read_endpoint , HostRole .READER )
74+ self ._write_endpoint = write_endpoint .casefold ()
75+ self ._read_endpoint = read_endpoint .casefold ()
7476
7577 @property
7678 def host_list_provider_service (self ) -> Optional [HostListProviderService ]:
@@ -116,12 +118,12 @@ def get_verified_initial_connection(
116118
117119 if (
118120 url_type == RdsUrlType .RDS_WRITER_CLUSTER
119- or self ._verify_opened_connection_type == HostRole .WRITER
121+ or self ._verify_initial_connection_type == HostRole .WRITER
120122 ):
121123 conn = self ._get_verified_connection (host_info , HostRole .WRITER , plugin_service_connect_func , connect_func )
122124 elif (
123125 url_type == RdsUrlType .RDS_READER_CLUSTER
124- or self ._verify_opened_connection_type == HostRole .READER
126+ or self ._verify_initial_connection_type == HostRole .READER
125127 ):
126128 conn = self ._get_verified_connection (host_info , HostRole .READER , plugin_service_connect_func , connect_func )
127129
@@ -154,10 +156,8 @@ def _get_verified_connection(
154156 try :
155157 if connect_func is not None :
156158 candidate_conn = connect_func ()
157- elif host_info is not None :
158- candidate_conn = plugin_service_connect_func (host_info )
159159 else :
160- return None
160+ candidate_conn = plugin_service_connect_func ( host_info )
161161
162162 if candidate_conn is None :
163163 self ._delay ()
@@ -178,12 +178,12 @@ def _get_verified_connection(
178178
179179 return None
180180
181- def old_reader_can_be_used (self , reader_host_info : HostInfo ) -> bool :
182- # Assume that the old reader can always be used, no topology-based information to check.
181+ def can_host_be_used (self , host_info : HostInfo ) -> bool :
182+ # Assume that the host can always be used, no topology-based information to check.
183183 return True
184184
185- def need_connect_to_writer (self ) -> bool :
186- # SetReadOnly(true) will always connect to the read_endpoint, and not the writer .
185+ def has_no_readers (self ) -> bool :
186+ # SetReadOnly(true) will always connect to the read_endpoint, regardless of number of readers .
187187 return False
188188
189189 def refresh_and_store_host_list (
@@ -218,14 +218,14 @@ def should_update_reader_with_current_conn(
218218
219219 def is_writer_host (self , current_host : HostInfo ) -> bool :
220220 return (
221- current_host .host .casefold () == self ._write_endpoint . casefold ()
222- or current_host .url .casefold () == self ._write_endpoint . casefold ()
221+ current_host .host .casefold () == self ._write_endpoint
222+ or current_host .url .casefold () == self ._write_endpoint
223223 )
224224
225225 def is_reader_host (self , current_host : HostInfo ) -> bool :
226226 return (
227- current_host .host .casefold () == self ._read_endpoint . casefold ()
228- or current_host .url .casefold () == self ._read_endpoint . casefold ()
227+ current_host .host .casefold () == self ._read_endpoint
228+ or current_host .url .casefold () == self ._read_endpoint
229229 )
230230
231231 def _create_host_info (self , endpoint : str , role : HostRole ) -> HostInfo :
@@ -272,14 +272,14 @@ def _delay(self):
272272 sleep (self ._connect_retry_interval_ms / 1000 )
273273
274274 @staticmethod
275- def _parse_connection_type ( phase_str : Optional [str ]) -> HostRole :
276- if not phase_str :
275+ def _parse_role ( role_str : Optional [str ]) -> HostRole :
276+ if not role_str :
277277 return HostRole .UNKNOWN
278278
279- phase_upper = phase_str .lower ()
280- if phase_upper == "reader" :
279+ phase_lower = role_str .lower ()
280+ if phase_lower == "reader" :
281281 return HostRole .READER
282- elif phase_upper == "writer" :
282+ elif phase_lower == "writer" :
283283 return HostRole .WRITER
284284 else :
285285 raise ValueError (
0 commit comments