Skip to content

Commit 9e9ac20

Browse files
karenc-bqsophia-bq
authored andcommitted
refactor: config properties retrieval and validation (#1053)
1 parent 6f2376c commit 9e9ac20

File tree

3 files changed

+140
-66
lines changed

3 files changed

+140
-66
lines changed

aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from time import perf_counter_ns, sleep
18-
from typing import TYPE_CHECKING, Callable, Optional
18+
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Type
1919

2020
from aws_advanced_python_wrapper.host_availability import HostAvailability
2121
from aws_advanced_python_wrapper.read_write_splitting_plugin import (
@@ -28,7 +28,7 @@
2828
from aws_advanced_python_wrapper.host_list_provider import HostListProviderService
2929
from aws_advanced_python_wrapper.pep249 import Connection
3030
from aws_advanced_python_wrapper.plugin_service import PluginService
31-
from aws_advanced_python_wrapper.utils.properties import Properties
31+
from aws_advanced_python_wrapper.utils.properties import Properties, WrapperProperty
3232

3333
from aws_advanced_python_wrapper.errors import AwsWrapperError
3434
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
@@ -41,53 +41,24 @@ class EndpointBasedConnectionHandler(ConnectionHandler):
4141
"""Endpoint based implementation of connection handling logic."""
4242

4343
def __init__(self, plugin_service: PluginService, props: Properties):
44-
srw_read_endpoint = WrapperProperties.SRW_READ_ENDPOINT.get(props)
45-
if srw_read_endpoint is None:
46-
raise AwsWrapperError(
47-
Messages.get_formatted(
48-
"SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter",
49-
WrapperProperties.SRW_READ_ENDPOINT.name,
50-
)
51-
)
52-
self._read_endpoint: str = srw_read_endpoint
53-
54-
srw_write_endpoint = WrapperProperties.SRW_WRITE_ENDPOINT.get(props)
55-
if srw_write_endpoint is None:
56-
raise AwsWrapperError(
57-
Messages.get_formatted(
58-
"SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter",
59-
WrapperProperties.SRW_WRITE_ENDPOINT.name,
60-
)
61-
)
62-
self._write_endpoint: str = srw_write_endpoint
44+
self._read_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
45+
WrapperProperties.SRW_READ_ENDPOINT, props, str, required=True
46+
)
47+
self._write_endpoint: str = EndpointBasedConnectionHandler._verify_parameter(
48+
WrapperProperties.SRW_WRITE_ENDPOINT, props, str, required=True
49+
)
6350

64-
self._verify_new_connections: bool = (
65-
WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS.get_bool(props)
51+
self._verify_new_connections: bool = EndpointBasedConnectionHandler._verify_parameter(
52+
WrapperProperties.SRW_VERIFY_NEW_CONNECTIONS, props, bool
6653
)
67-
if self._verify_new_connections is True:
68-
srw_connect_retry_timeout_ms: int = (
69-
WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.get_int(props)
70-
)
71-
if srw_connect_retry_timeout_ms <= 0:
72-
raise ValueError(
73-
Messages.get_formatted(
74-
"SimpleReadWriteSplittingPlugin.IncorrectConfiguration",
75-
WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS.name,
76-
)
77-
)
78-
self._connect_retry_timeout_ms: int = srw_connect_retry_timeout_ms
7954

80-
srw_connect_retry_interval_ms: int = (
81-
WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.get_int(props)
55+
if self._verify_new_connections:
56+
self._connect_retry_timeout_ms: int = EndpointBasedConnectionHandler._verify_parameter(
57+
WrapperProperties.SRW_CONNECT_RETRY_TIMEOUT_MS, props, int, lambda x: x > 0
58+
)
59+
self._connect_retry_interval_ms: int = EndpointBasedConnectionHandler._verify_parameter(
60+
WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS, props, int, lambda x: x > 0
8261
)
83-
if srw_connect_retry_interval_ms <= 0:
84-
raise ValueError(
85-
Messages.get_formatted(
86-
"SimpleReadWriteSplittingPlugin.IncorrectConfiguration",
87-
WrapperProperties.SRW_CONNECT_RETRY_INTERVAL_MS.name,
88-
)
89-
)
90-
self._connect_retry_interval_ms: int = srw_connect_retry_interval_ms
9162

9263
self._verify_opened_connection_type: Optional[HostRole] = (
9364
EndpointBasedConnectionHandler._parse_connection_type(
@@ -305,6 +276,29 @@ def _create_host_info(self, endpoint, role: HostRole) -> HostInfo:
305276
host=host, port=port, role=role, availability=HostAvailability.AVAILABLE
306277
)
307278

279+
T = TypeVar('T')
280+
281+
@staticmethod
282+
def _verify_parameter(prop: WrapperProperty, props: Properties, expected_type: Type[T], validator=None, required=False):
283+
value = prop.get_type(props, expected_type)
284+
if required:
285+
if value is None:
286+
raise AwsWrapperError(
287+
Messages.get_formatted(
288+
"SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter",
289+
prop.name,
290+
)
291+
)
292+
293+
if validator and not validator(value):
294+
raise ValueError(
295+
Messages.get_formatted(
296+
"SimpleReadWriteSplittingPlugin.IncorrectConfiguration",
297+
prop.name,
298+
)
299+
)
300+
return value
301+
308302
def _delay(self):
309303
sleep(self._connect_retry_interval_ms / 1000)
310304

aws_advanced_python_wrapper/utils/properties.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import copy
15-
from typing import Any, Dict, Optional
16+
from typing import Any, Dict, Optional, TypeVar, Type
1617
from urllib.parse import unquote
1718

1819
from aws_advanced_python_wrapper.errors import AwsWrapperError
@@ -25,6 +26,9 @@ def put_if_absent(self, key: str, value: Any):
2526
self[key] = value
2627

2728

29+
T = TypeVar('T')
30+
31+
2832
class WrapperProperty:
2933
def __init__(
3034
self, name: str, description: str, default_value: Optional[Any] = None
@@ -34,41 +38,43 @@ def __init__(
3438
self.description = description
3539

3640
def __str__(self):
37-
return f"WrapperProperty(name={self.name}, default_value={self.default_value}"
41+
return f"WrapperProperty(name={self.name}, default_value={self.default_value})"
3842

3943
def get(self, props: Properties) -> Optional[str]:
4044
if self.default_value:
4145
return props.get(self.name, self.default_value)
4246
return props.get(self.name)
4347

48+
def get_type(self, props: Properties, type_class: Type[T]) -> T:
49+
value = props.get(self.name, self.default_value) if self.default_value else props.get(self.name)
50+
if value is None:
51+
if type_class == int:
52+
return -1 # type: ignore
53+
elif type_class == float:
54+
return -1.0 # type: ignore
55+
elif type_class == bool:
56+
return False # type: ignore
57+
else:
58+
return None # type: ignore
59+
if type_class == bool:
60+
if isinstance(value, bool):
61+
return value # type: ignore
62+
return value.lower() == "true" if isinstance(value, str) else bool(value) # type: ignore
63+
return type_class(value) # type: ignore
64+
4465
def get_or_default(self, props: Properties) -> str:
4566
if not self.default_value:
4667
raise ValueError(f"No default value found for property {self}")
4768
return props.get(self.name, self.default_value)
4869

4970
def get_int(self, props: Properties) -> int:
50-
if self.default_value:
51-
return int(props.get(self.name, self.default_value))
52-
53-
val = props.get(self.name)
54-
return int(val) if val else -1
71+
return self.get_type(props, int)
5572

5673
def get_float(self, props: Properties) -> float:
57-
if self.default_value:
58-
return float(props.get(self.name, self.default_value))
59-
60-
val = props.get(self.name)
61-
return float(val) if val else -1
74+
return self.get_type(props, float)
6275

6376
def get_bool(self, props: Properties) -> bool:
64-
if not self.default_value:
65-
value = props.get(self.name)
66-
else:
67-
value = props.get(self.name, self.default_value)
68-
if isinstance(value, bool):
69-
return value
70-
else:
71-
return value is not None and value.lower() == "true"
77+
return self.get_type(props, bool)
7278

7379
def set(self, props: Properties, value: Any):
7480
props[self.name] = value

tests/unit/test_properties_utils.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from aws_advanced_python_wrapper.errors import AwsWrapperError
1818
from aws_advanced_python_wrapper.utils.properties import (Properties,
19-
PropertiesUtils)
19+
PropertiesUtils,
20+
WrapperProperty)
2021

2122

2223
@pytest.mark.parametrize(
@@ -93,3 +94,76 @@ def test_create_monitoring_properties(expected, test_props):
9394
props_copy = test_props.copy()
9495
props_copy = PropertiesUtils.create_monitoring_properties(props_copy)
9596
assert expected == props_copy
97+
98+
99+
@pytest.mark.parametrize("expected, props, type_class", [
100+
# Int type tests
101+
(123, Properties({"test_prop": "123"}), int),
102+
(-1, Properties(), int),
103+
(456, Properties({"test_prop": 456}), int),
104+
105+
# Float type tests
106+
(12.5, Properties({"test_prop": "12.5"}), float),
107+
(-1.0, Properties(), float),
108+
(3.14, Properties({"test_prop": 3.14}), float),
109+
110+
# Bool type tests
111+
(True, Properties({"test_prop": "true"}), bool),
112+
(True, Properties({"test_prop": "TRUE"}), bool),
113+
(False, Properties({"test_prop": "false"}), bool),
114+
(True, Properties({"test_prop": True}), bool),
115+
(False, Properties({"test_prop": False}), bool),
116+
(False, Properties(), bool),
117+
118+
# String type tests
119+
("test_value", Properties({"test_prop": "test_value"}), str),
120+
(None, Properties(), str),
121+
("", Properties({"test_prop": ""}), str),
122+
])
123+
def test_get_type(expected, props, type_class):
124+
wrapper_prop = WrapperProperty("test_prop", "Test property")
125+
result = wrapper_prop.get_type(props, type_class)
126+
assert result == expected
127+
128+
129+
def test_get_type_with_default():
130+
wrapper_prop = WrapperProperty("test_prop", "Test property", "default_value")
131+
props = Properties()
132+
result = wrapper_prop.get_type(props, str)
133+
assert result == "default_value"
134+
135+
136+
@pytest.mark.parametrize("expected, props", [
137+
(123, Properties({"test_prop": "123"})),
138+
(-1, Properties()),
139+
(456, Properties({"test_prop": 456})),
140+
])
141+
def test_get_int(expected, props):
142+
wrapper_prop = WrapperProperty("test_prop", "Test property")
143+
result = wrapper_prop.get_int(props)
144+
assert result == expected
145+
146+
147+
@pytest.mark.parametrize("expected, props", [
148+
(12.5, Properties({"test_prop": "12.5"})),
149+
(-1.0, Properties()),
150+
(3.14, Properties({"test_prop": 3.14})),
151+
])
152+
def test_get_float(expected, props):
153+
wrapper_prop = WrapperProperty("test_prop", "Test property")
154+
result = wrapper_prop.get_float(props)
155+
assert result == expected
156+
157+
158+
@pytest.mark.parametrize("expected, props", [
159+
(True, Properties({"test_prop": "true"})),
160+
(True, Properties({"test_prop": "TRUE"})),
161+
(False, Properties({"test_prop": "false"})),
162+
(True, Properties({"test_prop": True})),
163+
(False, Properties({"test_prop": False})),
164+
(False, Properties()),
165+
])
166+
def test_get_bool(expected, props):
167+
wrapper_prop = WrapperProperty("test_prop", "Test property")
168+
result = wrapper_prop.get_bool(props)
169+
assert result == expected

0 commit comments

Comments
 (0)