diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 66bf7c7049..d0daf406a4 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1683,14 +1683,7 @@ def protocol_downgrade(self, host_endpoint, previous_version): "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) self.protocol_version = new_version - def _add_resolved_hosts(self): - for endpoint in self.endpoints_resolved: - host, new = self.add_host(endpoint, signal=False) - if new: - host.set_up() - for listener in self.listeners: - listener.on_add(host) - + def _populate_hosts(self): self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) self.load_balancing_policy.populate( @@ -1717,17 +1710,10 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) - - self._add_resolved_hosts() try: self.control_connection.connect() - - # we set all contact points up for connecting, but we won't infer state after this - for endpoint in self.endpoints_resolved: - h = self.metadata.get_host(endpoint) - if h and self.profile_manager.distance(h) == HostDistance.IGNORED: - h.is_up = None + self._populate_hosts() log.debug("Control connection created") except Exception: @@ -3534,18 +3520,20 @@ def _set_new_connection(self, conn): if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() - - def _connect_host_in_lbp(self): + + def _connect_host(self): errors = {} + lbp = ( self._cluster.load_balancing_policy if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy ) + # use endpoints from the default LBP if it is already initialized for host in lbp.make_query_plan(): try: - return (self._try_connect(host), None) + return (self._try_connect(host.endpoint), None) except ConnectionException as exc: errors[str(host.endpoint)] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) @@ -3555,7 +3543,22 @@ def _connect_host_in_lbp(self): log.warning("[control connection] Error connecting to %s:", host, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") - + + # if lbp not initialized use contact points provided to the cluster + if len(errors) == 0: + for endpoint in self._cluster.endpoints_resolved: + try: + return (self._try_connect(endpoint), None) + except ConnectionException as exc: + errors[str(endpoint)] = exc + log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True) + self._cluster.signal_connection_failure(endpoint, exc, is_host_addition=False) + except Exception as exc: + errors[str(endpoint)] = exc + log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True) + if self._is_shutdown: + raise DriverException("[control connection] Reconnection in progress during shutdown") + return (None, errors) def _reconnect_internal(self): @@ -3567,43 +3570,43 @@ def _reconnect_internal(self): to the exception that was raised when an attempt was made to open a connection to that host. """ - (conn, _) = self._connect_host_in_lbp() + (conn, _) = self._connect_host() if conn is not None: return conn # Try to re-resolve hostnames as a fallback when all hosts are unreachable self._cluster._resolve_hostnames() - self._cluster._add_resolved_hosts() + self._cluster._populate_hosts() - (conn, errors) = self._connect_host_in_lbp() + (conn, errors) = self._connect_host() if conn is not None: return conn - + raise NoHostAvailable("Unable to connect to any servers", errors) - def _try_connect(self, host): + def _try_connect(self, endpoint): """ Creates a new Connection, registers for pushed events, and refreshes node/token and schema metadata. """ - log.debug("[control connection] Opening new connection to %s", host) + log.debug("[control connection] Opening new connection to %s", endpoint) while True: try: - connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True) + connection = self._cluster.connection_factory(endpoint, is_control_connection=True) if self._is_shutdown: connection.close() raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: - self._cluster.protocol_downgrade(host.endpoint, e.startup_version) + self._cluster.protocol_downgrade(endpoint, e.startup_version) except ProtocolException as e: # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver # protocol version. If the protocol version was not explicitly specified, # and that the server raises a beta protocol error, we should downgrade. if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: - self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version) + self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version) else: raise @@ -3879,6 +3882,9 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._cluster.metadata.update_host(host, old_endpoint=connection.endpoint) connection.original_endpoint = connection.endpoint = host.endpoint + else: + log.info("Consider local host new found host") + peers_result.append(local_row) # Check metadata.partitioner to see if we haven't built anything yet. If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) @@ -4177,8 +4183,8 @@ def _get_peers_query(self, peers_query_type, connection=None): query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE if peers_query_type == self.PeersQueryType.PEERS_SCHEMA else self._SELECT_PEERS_NO_TOKENS_TEMPLATE) - host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version - host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version + host_release_version = None if self._cluster.metadata.get_host(connection.original_endpoint) == None else self._cluster.metadata.get_host(connection.original_endpoint).release_version + host_dse_version = None if self._cluster.metadata.get_host(connection.original_endpoint) == None else self._cluster.metadata.get_host(connection.original_endpoint).dse_version uses_native_address_query = ( host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 6379de069a..06632cfd5f 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -139,6 +139,9 @@ def export_schema_as_string(self): def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None, metadata_request_timeout=None, **kwargs): + if not self.get_host(connection.original_endpoint): + return + server_version = self.get_host(connection.original_endpoint).release_version dse_version = self.get_host(connection.original_endpoint).dse_version parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size) diff --git a/cassandra/policies.py b/cassandra/policies.py index bcfd797706..ddb06ba5ee 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -264,6 +264,9 @@ def populate(self, cluster, hosts): def distance(self, host): dc = self._dc(host) + if not self.local_dc: + self.local_dc = dc + return HostDistance.LOCAL if dc == self.local_dc: return HostDistance.LOCAL diff --git a/cassandra/pool.py b/cassandra/pool.py index b8a8ef7493..2da657256f 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -176,7 +176,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) self.conviction_policy = conviction_policy_factory(self) if not host_id: - host_id = uuid.uuid4() + raise ValueError("host_id may not be None") self.host_id = host_id self.set_location_info(datacenter, rack) self.lock = RLock() diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index d7f89ad598..1208edb9d2 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -900,8 +900,9 @@ def test_profile_lb_swap(self): """ Tests that profile load balancing policies are not shared - Creates two LBP, runs a few queries, and validates that each LBP is execised - seperately between EP's + Creates two LBP, runs a few queries, and validates that each LBP is exercised + separately between EP's. Each RoundRobinPolicy starts from its own random + position and maintains independent round-robin ordering. @since 3.5 @jira_ticket PYTHON-569 @@ -916,17 +917,28 @@ def test_profile_lb_swap(self): with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect(wait_for_all_pools=True) - # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) - rr1_queried_hosts = set() - rr2_queried_hosts = set() - - rs = session.execute(query, execution_profile='rr1') - rr1_queried_hosts.add(rs.response_future._current_host) - rs = session.execute(query, execution_profile='rr2') - rr2_queried_hosts.add(rs.response_future._current_host) - - assert rr2_queried_hosts == rr1_queried_hosts + num_hosts = len(expected_hosts) + assert num_hosts > 1, "Need at least 2 hosts for this test" + + rr1_queried_hosts = [] + rr2_queried_hosts = [] + + for _ in range(num_hosts * 2): + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.append(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr2') + rr2_queried_hosts.append(rs.response_future._current_host) + + # Both policies should have queried all hosts + assert set(rr1_queried_hosts) == expected_hosts + assert set(rr2_queried_hosts) == expected_hosts + + # The order of hosts should demonstrate round-robin behavior + # After num_hosts queries, the pattern should repeat + for i in range(num_hosts): + assert rr1_queried_hosts[i] == rr1_queried_hosts[i + num_hosts] + assert rr2_queried_hosts[i] == rr2_queried_hosts[i + num_hosts] def test_ta_lbp(self): """ diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index 206945f0b3..990f133962 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -101,8 +101,12 @@ def test_get_control_connection_host(self): # reconnect and make sure that the new host is reflected correctly self.cluster.control_connection._reconnect() - new_host = self.cluster.get_control_connection_host() - assert host != new_host + new_host1 = self.cluster.get_control_connection_host() + + self.cluster.control_connection._reconnect() + new_host2 = self.cluster.get_control_connection_host() + + assert new_host1 != new_host2 # TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed @unittest.skip('Fails on scylla due to the broadcast_rpc_port is None') diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 8ccd278ee4..48c7b49b95 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -218,7 +218,7 @@ def test_metrics_per_cluster(self): try: # Test write query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - with pytest.raises(WriteTimeout): + with pytest.raises((WriteTimeout, Unavailable)): self.session.execute(query, timeout=None) finally: get_node(1).resume() @@ -230,7 +230,7 @@ def test_metrics_per_cluster(self): stats_cluster2 = cluster2.metrics.get_stats() # Test direct access to stats - assert 1 == self.cluster.metrics.stats.write_timeouts + assert (1 == self.cluster.metrics.stats.write_timeouts or 1 == self.cluster.metrics.stats.unavailables) assert 0 == cluster2.metrics.stats.write_timeouts # Test direct access to a child stats diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py index 0c84fd06be..c7516995f0 100644 --- a/tests/integration/standard/test_policies.py +++ b/tests/integration/standard/test_policies.py @@ -45,9 +45,6 @@ def test_predicate_changes(self): external_event = True contact_point = DefaultEndPoint("127.0.0.1") - single_host = {Host(contact_point, SimpleConvictionPolicy)} - all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)} - predicate = lambda host: host.endpoint == contact_point if external_event else True hfp = ExecutionProfile( load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate) @@ -62,7 +59,8 @@ def test_predicate_changes(self): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) - assert queried_hosts == single_host + assert len(queried_hosts) == 1 + assert queried_hosts.pop().endpoint == contact_point external_event = False futures = session.update_created_pools() @@ -72,7 +70,7 @@ def test_predicate_changes(self): for _ in range(10): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) - assert queried_hosts == all_hosts + assert len(queried_hosts) == 3 class WhiteListRoundRobinPolicyTests(unittest.TestCase): diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index a3bdf8a735..36147ae581 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -460,7 +460,8 @@ def make_query_plan(self, working_keyspace=None, query=None): live_hosts = sorted(list(self._live_hosts)) host = [] try: - host = [live_hosts[self.host_index_to_use]] + if len(live_hosts) > self.host_index_to_use: + host = [live_hosts[self.host_index_to_use]] except IndexError as e: raise IndexError( 'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format( diff --git a/tests/unit/advanced/test_policies.py b/tests/unit/advanced/test_policies.py index 8e421a859d..75cfd3fbf9 100644 --- a/tests/unit/advanced/test_policies.py +++ b/tests/unit/advanced/test_policies.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest from unittest.mock import Mock +import uuid from cassandra.pool import Host from cassandra.policies import RoundRobinPolicy @@ -72,7 +73,7 @@ def test_target_no_host(self): def test_target_host_down(self): node_count = 4 - hosts = [Host(i, Mock()) for i in range(node_count)] + hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)] target_host = hosts[1] policy = DSELoadBalancingPolicy(RoundRobinPolicy()) @@ -87,7 +88,7 @@ def test_target_host_down(self): def test_target_host_nominal(self): node_count = 4 - hosts = [Host(i, Mock()) for i in range(node_count)] + hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)] target_host = hosts[1] target_host.is_up = True diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index f3efed9f54..49208ac53e 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -17,6 +17,7 @@ import socket from unittest.mock import patch, Mock +import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion @@ -200,7 +201,7 @@ def test_default_serial_consistency_level_ep(self, *_): PR #510 """ c = Cluster(protocol_version=4) - s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) c.connection_class.initialize_reactor() # default is None @@ -229,7 +230,7 @@ def test_default_serial_consistency_level_legacy(self, *_): PR #510 """ c = Cluster(protocol_version=4) - s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) c.connection_class.initialize_reactor() # default is None assert s.default_serial_consistency_level is None @@ -286,7 +287,7 @@ def test_default_exec_parameters(self): assert cluster.profile_manager.default.load_balancing_policy.__class__ == default_lbp_factory().__class__ assert cluster.default_retry_policy.__class__ == RetryPolicy assert cluster.profile_manager.default.retry_policy.__class__ == RetryPolicy - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert session.default_timeout == 10.0 assert cluster.profile_manager.default.request_timeout == 10.0 assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE @@ -300,7 +301,7 @@ def test_default_exec_parameters(self): def test_default_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) assert cluster._config_mode == _ConfigMode.LEGACY - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) session.default_timeout = 3.7 session.default_consistency_level = ConsistencyLevel.ALL session.default_serial_consistency_level = ConsistencyLevel.SERIAL @@ -314,7 +315,7 @@ def test_default_legacy(self): def test_default_profile(self): non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'non-default': non_default_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES @@ -347,7 +348,7 @@ def test_serial_consistency_level_validation(self): def test_statement_params_override_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) assert cluster._config_mode == _ConfigMode.LEGACY - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) @@ -368,7 +369,7 @@ def test_statement_params_override_legacy(self): def test_statement_params_override_profile(self): non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'non-default': non_default_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES @@ -406,7 +407,7 @@ def test_no_profile_with_legacy(self): # session settings lock out profiles cluster = Cluster() - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), @@ -432,7 +433,7 @@ def test_no_legacy_with_profile(self): ('load_balancing_policy', default_lbp_factory())): with pytest.raises(ValueError): setattr(cluster, attr, value) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), @@ -445,7 +446,7 @@ def test_profile_name_value(self): internalized_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'by-name': internalized_profile}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) assert cluster._config_mode == _ConfigMode.PROFILES rf = session.execute_async("query", execution_profile='by-name') @@ -459,7 +460,7 @@ def test_profile_name_value(self): def test_exec_profile_clone(self): cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()}) - session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) profile_attrs = {'request_timeout': 1, 'consistency_level': ConsistencyLevel.ANY, diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index a3587a3e16..9c85b1ccac 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -22,6 +22,7 @@ from queue import PriorityQueue import sys import platform +import uuid from cassandra.cluster import Cluster, Session from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args @@ -248,7 +249,7 @@ def test_recursion_limited(self): PYTHON-585 """ max_recursion = sys.getrecursionlimit() - s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy)]) + s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())]) with pytest.raises(TypeError): execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index e7b930a990..580eb336b2 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -14,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor import logging import time +import uuid from cassandra.protocol_features import ProtocolFeatures from cassandra.shard_info import _ShardingInfo @@ -205,20 +206,20 @@ def test_host_instantiations(self): """ with pytest.raises(ValueError): - Host(None, None) + Host(None, None, host_id=uuid.uuid4()) with pytest.raises(ValueError): - Host('127.0.0.1', None) + Host('127.0.0.1', None, host_id=uuid.uuid4()) with pytest.raises(ValueError): - Host(None, SimpleConvictionPolicy) + Host(None, SimpleConvictionPolicy, host_id=uuid.uuid4()) def test_host_equality(self): """ Test host equality has correct logic """ - a = Host('127.0.0.1', SimpleConvictionPolicy) - b = Host('127.0.0.1', SimpleConvictionPolicy) - c = Host('127.0.0.2', SimpleConvictionPolicy) + a = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + b = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + c = Host('127.0.0.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) assert a == b, 'Two Host instances should be equal when sharing.' assert a != c, 'Two Host instances should NOT be equal when using two different addresses.' @@ -253,7 +254,7 @@ def mock_connection_factory(self, *args, **kwargs): connection.is_shutdown = False connection.is_defunct = False connection.is_closed = False - connection.features = ProtocolFeatures(shard_id=self.connection_counter, + connection.features = ProtocolFeatures(shard_id=self.connection_counter, sharding_info=_ShardingInfo(shard_id=1, shards_count=14, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port="", shard_aware_port_ssl="")) diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index 3069f6bced..ec29979095 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -18,6 +18,7 @@ from unittest.mock import Mock import os import timeit +import uuid import cassandra from cassandra.cqltypes import strip_frozen @@ -121,7 +122,7 @@ def test_simple_replication_type_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert simple_int.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) @@ -139,7 +140,7 @@ def test_transient_replication_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert simple_transient.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) @@ -160,7 +161,7 @@ def test_nts_replication_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert nts_int.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) @@ -180,30 +181,30 @@ def test_nts_transient_parsing(self): # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] - hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in range(3)] token_to_host = dict(zip(ring, hosts)) assert nts_transient.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_make_token_replica_map(self): token_to_host_owner = {} - dc1_1 = Host('dc1.1', SimpleConvictionPolicy) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy) + dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) for host in (dc1_1, dc1_2, dc1_3): host.set_location_info('dc1', 'rack1') token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 - dc2_1 = Host('dc2.1', SimpleConvictionPolicy) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy) + dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_1.set_location_info('dc2', 'rack1') dc2_2.set_location_info('dc2', 'rack1') token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 - dc3_1 = Host('dc3.1', SimpleConvictionPolicy) + dc3_1 = Host('dc3.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc3_1.set_location_info('dc3', 'rack3') token_to_host_owner[MD5Token(2)] = dc3_1 @@ -238,7 +239,7 @@ def test_nts_token_performance(self): vnodes_per_host = 500 for i in range(dc1hostnum): - host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy) + host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info('dc1', "rack1") for vnode_num in range(vnodes_per_host): md5_token = MD5Token(current_token+vnode_num) @@ -262,10 +263,10 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner = {} # (A) not enough distinct racks, first skipped is used - dc1_1 = Host('dc1.1', SimpleConvictionPolicy) - dc1_2 = Host('dc1.2', SimpleConvictionPolicy) - dc1_3 = Host('dc1.3', SimpleConvictionPolicy) - dc1_4 = Host('dc1.4', SimpleConvictionPolicy) + dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc1_4 = Host('dc1.4', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_1.set_location_info('dc1', 'rack1') dc1_2.set_location_info('dc1', 'rack1') dc1_3.set_location_info('dc1', 'rack2') @@ -276,9 +277,9 @@ def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner[MD5Token(300)] = dc1_4 # (B) distinct racks, but not contiguous - dc2_1 = Host('dc2.1', SimpleConvictionPolicy) - dc2_2 = Host('dc2.2', SimpleConvictionPolicy) - dc2_3 = Host('dc2.3', SimpleConvictionPolicy) + dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + dc2_3 = Host('dc2.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_1.set_location_info('dc2', 'rack1') dc2_2.set_location_info('dc2', 'rack1') dc2_3.set_location_info('dc2', 'rack2') @@ -301,7 +302,7 @@ def test_nts_make_token_replica_map_multi_rack(self): assertCountEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) def test_nts_make_token_replica_map_empty_dc(self): - host = Host('1', SimpleConvictionPolicy) + host = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info('dc1', 'rack1') token_to_host_owner = {MD5Token(0): host} ring = [MD5Token(0)] @@ -315,9 +316,9 @@ def test_nts_export_for_schema(self): assert "{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" == strategy.export_for_schema() def test_simple_strategy_make_token_replica_map(self): - host1 = Host('1', SimpleConvictionPolicy) - host2 = Host('2', SimpleConvictionPolicy) - host3 = Host('3', SimpleConvictionPolicy) + host1 = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host2 = Host('2', SimpleConvictionPolicy, host_id=uuid.uuid4()) + host3 = Host('3', SimpleConvictionPolicy, host_id=uuid.uuid4()) token_to_host_owner = { MD5Token(0): host1, MD5Token(100): host2, @@ -406,7 +407,7 @@ def test_is_valid_name(self): class GetReplicasTest(unittest.TestCase): def _get_replicas(self, token_klass): tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] - hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))] + hosts = [Host("ip%d" % i, SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(len(tokens))] token_to_primary_replica = dict(zip(tokens, hosts)) keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) @@ -784,8 +785,8 @@ def test_iterate_all_hosts_and_modify(self): PYTHON-572 """ metadata = Metadata() - metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy)) - metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy)) + metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4())) + metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4())) assert len(metadata.all_hosts()) == 2 diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e15705c8f7..65feaf72e5 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -17,6 +17,7 @@ from itertools import islice, cycle from unittest.mock import Mock, patch, call from random import randint +import uuid import pytest from _thread import LockType import sys @@ -46,7 +47,7 @@ def test_non_implemented(self): """ policy = LoadBalancingPolicy() - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") with pytest.raises(NotImplementedError): @@ -192,11 +193,11 @@ class TestRackOrDCAwareRoundRobinPolicy: def test_no_remote(self, policy_specialization, constructor_args): hosts = [] for i in range(2): - h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack2") hosts.append(h) for i in range(2): - h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -208,7 +209,7 @@ def test_no_remote(self, policy_specialization, constructor_args): assert sorted(qplan) == sorted(hosts) def test_with_remotes(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(6)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(6)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:4]: @@ -263,7 +264,7 @@ def test_get_distance(self, policy_specialization, constructor_args): policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) # same dc, same rack - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") policy.populate(Mock(), [host]) @@ -273,14 +274,14 @@ def test_get_distance(self, policy_specialization, constructor_args): assert policy.distance(host) == HostDistance.LOCAL_RACK # same dc different rack - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack2") policy.populate(Mock(), [host]) assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) remote_host.set_location_info("dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED @@ -294,14 +295,14 @@ def test_get_distance(self, policy_specialization, constructor_args): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) second_remote_host.set_location_info("dc2", "rack1") policy.populate(Mock(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) def test_status_updates(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(5)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(5)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:4]: @@ -314,11 +315,11 @@ def test_status_updates(self, policy_specialization, constructor_args): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -343,7 +344,7 @@ def test_status_updates(self, policy_specialization, constructor_args): assert qplan == [] def test_modification_during_generation(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -357,7 +358,7 @@ def test_modification_during_generation(self, policy_specialization, constructor # approach that changes specific things during known phases of the # generator. - new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_host.set_location_info("dc1", "rack1") # new local before iteration @@ -468,7 +469,7 @@ def test_modification_during_generation(self, policy_specialization, constructor policy.on_up(hosts[2]) policy.on_up(hosts[3]) - another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) another_host.set_location_info("dc3", "rack1") new_host.set_location_info("dc3", "rack1") @@ -502,7 +503,7 @@ def test_no_live_nodes(self, policy_specialization, constructor_args): hosts = [] for i in range(4): - h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -527,7 +528,7 @@ def test_no_nodes(self, policy_specialization, constructor_args): assert qplan == [] def test_wrong_dc(self, policy_specialization, constructor_args): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(3)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(3)] for h in hosts[:3]: h.set_location_info("dc2", "rack2") @@ -539,9 +540,9 @@ def test_wrong_dc(self, policy_specialization, constructor_args): class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_default_dc(self): - host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local') - host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote') - host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy) + host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local', host_id=uuid.uuid4()) + host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote', host_id=uuid.uuid4()) + host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy, host_id=uuid.uuid4()) # contact point is '1' cluster = Mock(endpoints_resolved=[DefaultEndPoint(1)]) @@ -585,7 +586,7 @@ def test_wrap_round_robin(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -618,7 +619,7 @@ def test_wrap_dc_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() for h in hosts[:2]: @@ -667,7 +668,7 @@ def test_wrap_rack_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.table_has_tablets.return_value = [] - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(8)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() hosts[0].set_location_info("dc1", "rack1") @@ -731,7 +732,7 @@ def test_get_distance(self): """ policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)) - host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack1") policy.populate(self.FakeCluster(), [host]) @@ -739,7 +740,7 @@ def test_get_distance(self): assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) remote_host.set_location_info("dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED @@ -753,7 +754,7 @@ def test_get_distance(self): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) second_remote_host.set_location_info("dc2", "rack1") policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) @@ -764,7 +765,7 @@ def test_status_updates(self): Same test as DCAwareRoundRobinPolicyTest.test_status_updates() """ - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -775,11 +776,11 @@ def test_status_updates(self): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -802,7 +803,7 @@ def test_status_updates(self): assert qplan == [] def test_statement_keyspace(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -896,7 +897,7 @@ def test_no_shuffle_if_given_no_routing_key(self): self._assert_shuffle(cluster=self._prepare_cluster_with_tablets(), keyspace='keyspace', routing_key=None) def _prepare_cluster_with_vnodes(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() cluster = Mock(spec=Cluster) @@ -908,7 +909,7 @@ def _prepare_cluster_with_vnodes(self): return cluster def _prepare_cluster_with_tablets(self): - hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() cluster = Mock(spec=Cluster) @@ -1422,7 +1423,7 @@ class WhiteListRoundRobinPolicyTest(unittest.TestCase): def test_hosts_with_hostname(self): hosts = ['localhost'] policy = WhiteListRoundRobinPolicy(hosts) - host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy) + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) policy.populate(None, [host]) qplan = list(policy.make_query_plan()) @@ -1433,7 +1434,7 @@ def test_hosts_with_hostname(self): def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] policy = WhiteListRoundRobinPolicy(hosts) - host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy) + host = Host(UnixSocketEndPoint('/tmp/scylla-workdir/cql.m'), SimpleConvictionPolicy, host_id=uuid.uuid4()) policy.populate(None, [host]) qplan = list(policy.make_query_plan()) @@ -1559,8 +1560,8 @@ def setUp(self): child_policy=Mock(name='child_policy', distance=Mock(name='distance')), predicate=lambda host: host.address == 'acceptme' ) - self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock()) - self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock()) + self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock(), host_id=uuid.uuid4()) + self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock(), host_id=uuid.uuid4()) def test_ignored_with_filter(self): assert self.hfp.distance(self.ignored_host) == HostDistance.IGNORED @@ -1629,7 +1630,7 @@ def test_query_plan_deferred_to_child(self): def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) - hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1656,13 +1657,13 @@ def get_replicas(keyspace, packed_key): query_plan = hfp.make_query_plan("keyspace", mocked_query) # First the not filtered replica, and then the rest of the allowed hosts ordered query_plan = list(query_plan) - assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy) - assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)} + assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy, host_id=uuid.uuid4())} def test_create_whitelist(self): cluster = Mock(spec=Cluster) - hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(1, 6)] for host in hosts: host.set_up() @@ -1680,5 +1681,5 @@ def test_create_whitelist(self): mocked_query = Mock() query_plan = hfp.make_query_plan("keyspace", mocked_query) # Only the filtered replicas should be allowed - assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)} + assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy, host_id=uuid.uuid4())} diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 3390f6dbd6..a5bd028b26 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -1002,11 +1002,11 @@ def test_host_order(self): @test_category data_types """ - hosts = [Host(addr, SimpleConvictionPolicy) for addr in + hosts = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] - hosts_equal = [Host(addr, SimpleConvictionPolicy) for addr in + hosts_equal = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.1")] - hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy), Host("127.0.0.1", ConvictionPolicy)] + hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()), Host("127.0.0.1", ConvictionPolicy, host_id=uuid.uuid4())] check_sequence_consistency(hosts) check_sequence_consistency(hosts_equal, equal=True) check_sequence_consistency(hosts_equal_conviction, equal=True)