1515from __future__ import annotations
1616
1717from time import perf_counter_ns , sleep
18- from typing import TYPE_CHECKING , Callable , Optional , TypeVar , Type
18+ from typing import TYPE_CHECKING , Callable , Optional , Type , TypeVar
1919
2020from aws_advanced_python_wrapper .host_availability import HostAvailability
2121from aws_advanced_python_wrapper .read_write_splitting_plugin import (
@@ -67,15 +67,10 @@ def __init__(self, plugin_service: PluginService, props: Properties):
6767 )
6868
6969 self ._plugin_service : PluginService = plugin_service
70- self ._properties : Properties = props
7170 self ._rds_utils : RdsUtils = RdsUtils ()
7271 self ._host_list_provider_service : Optional [HostListProviderService ] = None
73- self ._write_endpoint_host_info : HostInfo = self ._create_host_info (
74- self ._write_endpoint , HostRole .WRITER
75- )
76- self ._read_endpoint_host_info : HostInfo = self ._create_host_info (
77- self ._read_endpoint , HostRole .READER
78- )
72+ self ._write_endpoint_host_info : HostInfo = self ._create_host_info (self ._write_endpoint , HostRole .WRITER )
73+ self ._read_endpoint_host_info : HostInfo = self ._create_host_info (self ._read_endpoint , HostRole .READER )
7974
8075 @property
8176 def host_list_provider_service (self ) -> Optional [HostListProviderService ]:
@@ -87,39 +82,29 @@ def host_list_provider_service(self, new_value: HostListProviderService) -> None
8782
8883 def open_new_writer_connection (
8984 self ,
85+ plugin_service_connect_func : Callable [[HostInfo ], Connection ],
9086 ) -> tuple [Optional [Connection ], Optional [HostInfo ]]:
91- conn : Optional [Connection ] = None
9287 if self ._verify_new_connections :
93- conn = self ._get_verified_connection (
94- self ._properties , self ._write_endpoint_host_info , HostRole .WRITER
95- )
96- else :
97- conn = self ._plugin_service .connect (
98- self ._write_endpoint_host_info , self ._properties , None
99- )
88+ return self ._get_verified_connection (self ._write_endpoint_host_info , HostRole .WRITER , plugin_service_connect_func ), \
89+ self ._write_endpoint_host_info
10090
101- return conn , self ._write_endpoint_host_info
91+ return plugin_service_connect_func ( self . _write_endpoint_host_info ) , self ._write_endpoint_host_info
10292
10393 def open_new_reader_connection (
10494 self ,
95+ plugin_service_connect_func : Callable [[HostInfo ], Connection ],
10596 ) -> tuple [Optional [Connection ], Optional [HostInfo ]]:
106- conn : Optional [Connection ] = None
10797 if self ._verify_new_connections :
108- conn = self ._get_verified_connection (
109- self ._properties , self ._read_endpoint_host_info , HostRole .READER
110- )
111- else :
112- conn = self ._plugin_service .connect (
113- self ._read_endpoint_host_info , self ._properties , None
114- )
98+ return self ._get_verified_connection (self ._read_endpoint_host_info , HostRole .READER , plugin_service_connect_func ), \
99+ self ._read_endpoint_host_info
115100
116- return conn , self ._read_endpoint_host_info
101+ return plugin_service_connect_func ( self . _read_endpoint_host_info ) , self ._read_endpoint_host_info
117102
118103 def get_verified_initial_connection (
119104 self ,
120105 host_info : HostInfo ,
121- props : Properties ,
122106 is_initial_connection : bool ,
107+ plugin_service_connect_func : Callable [[HostInfo ], Connection ],
123108 connect_func : Callable ,
124109 ) -> Connection :
125110 if not is_initial_connection or not self ._verify_new_connections :
@@ -133,34 +118,30 @@ def get_verified_initial_connection(
133118 url_type == RdsUrlType .RDS_WRITER_CLUSTER
134119 or self ._verify_opened_connection_type == HostRole .WRITER
135120 ):
136- conn = self ._get_verified_connection (
137- props , host_info , HostRole .WRITER , connect_func
138- )
121+ conn = self ._get_verified_connection (host_info , HostRole .WRITER , plugin_service_connect_func , connect_func )
139122 elif (
140123 url_type == RdsUrlType .RDS_READER_CLUSTER
141124 or self ._verify_opened_connection_type == HostRole .READER
142125 ):
143- conn = self ._get_verified_connection (
144- props , host_info , HostRole .READER , connect_func
145- )
126+ conn = self ._get_verified_connection (host_info , HostRole .READER , plugin_service_connect_func , connect_func )
146127
147128 if conn is None :
148129 conn = connect_func ()
149130
150- self ._set_initial_connection_host_info (conn , host_info )
131+ self ._set_initial_connection_host_info (host_info )
151132 return conn
152133
153- def _set_initial_connection_host_info (self , conn : Connection , host_info : HostInfo ):
134+ def _set_initial_connection_host_info (self , host_info : HostInfo ):
154135 if self ._host_list_provider_service is None :
155136 return
156137
157138 self ._host_list_provider_service .initial_connection_host_info = host_info
158139
159140 def _get_verified_connection (
160141 self ,
161- props : Properties ,
162142 host_info : HostInfo ,
163143 role : HostRole ,
144+ plugin_service_connect_func : Callable [[HostInfo ], Connection ],
164145 connect_func : Optional [Callable ] = None ,
165146 ) -> Optional [Connection ]:
166147 end_time_nano = perf_counter_ns () + (self ._connect_retry_timeout_ms * 1000000 )
@@ -174,9 +155,7 @@ def _get_verified_connection(
174155 if connect_func is not None :
175156 candidate_conn = connect_func ()
176157 elif host_info is not None :
177- candidate_conn = self ._plugin_service .connect (
178- host_info , props , None
179- )
158+ candidate_conn = plugin_service_connect_func (host_info )
180159 else :
181160 return None
182161
@@ -187,14 +166,14 @@ def _get_verified_connection(
187166 actual_role = self ._plugin_service .get_host_role (candidate_conn )
188167
189168 if actual_role != role :
190- ReadWriteSplittingConnectionManager .close_connection (candidate_conn )
169+ ReadWriteSplittingConnectionManager .close_connection (candidate_conn , self . _plugin_service . driver_dialect )
191170 self ._delay ()
192171 continue
193172
194173 return candidate_conn
195174
196175 except Exception :
197- ReadWriteSplittingConnectionManager .close_connection (candidate_conn )
176+ ReadWriteSplittingConnectionManager .close_connection (candidate_conn , self . _plugin_service . driver_dialect )
198177 self ._delay ()
199178
200179 return None
@@ -249,28 +228,18 @@ def is_reader_host(self, current_host: HostInfo) -> bool:
249228 or current_host .url .casefold () == self ._read_endpoint .casefold ()
250229 )
251230
252- def _create_host_info (self , endpoint , role : HostRole ) -> HostInfo :
231+ def _create_host_info (self , endpoint : str , role : HostRole ) -> HostInfo :
253232 endpoint = endpoint .strip ()
254233 host = endpoint
255- port = self ._plugin_service .database_dialect .default_port
234+ port = self ._plugin_service .database_dialect .default_port if not self ._plugin_service .current_host_info .is_port_specified () \
235+ else self ._plugin_service .current_host_info .port
256236 colon_index = endpoint .rfind (":" )
257237
258238 if colon_index != - 1 :
239+ host = endpoint [:colon_index ]
259240 port_str = endpoint [colon_index + 1 :]
260241 if port_str .isdigit ():
261- host = endpoint [:colon_index ]
262242 port = int (port_str )
263- else :
264- if (
265- self ._host_list_provider_service is not None
266- and self ._host_list_provider_service .initial_connection_host_info
267- is not None
268- and self ._host_list_provider_service .initial_connection_host_info .port
269- != HostInfo .NO_PORT
270- ):
271- port = (
272- self ._host_list_provider_service .initial_connection_host_info .port
273- )
274243
275244 return HostInfo (
276245 host = host , port = port , role = role , availability = HostAvailability .AVAILABLE
@@ -322,12 +291,9 @@ def _parse_connection_type(phase_str: Optional[str]) -> HostRole:
322291
323292
324293class SimpleReadWriteSplittingPlugin (ReadWriteSplittingConnectionManager ):
325- def __init__ (self , plugin_service , props : Properties ):
294+ def __init__ (self , plugin_service : PluginService , props : Properties ):
326295 # The simple read/write splitting plugin handles connections based on configuration parameter endpoints.
327- connection_handler = EndpointBasedConnectionHandler (
328- plugin_service ,
329- props ,
330- )
296+ connection_handler = EndpointBasedConnectionHandler (plugin_service , props )
331297
332298 super ().__init__ (plugin_service , props , connection_handler )
333299
0 commit comments