68 changes: 37 additions & 31 deletions cassandra/cluster.py
@@ -1683,14 +1683,7 @@ def protocol_downgrade(self, host_endpoint, previous_version):
"http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint)
self.protocol_version = new_version

def _add_resolved_hosts(self):
for endpoint in self.endpoints_resolved:
host, new = self.add_host(endpoint, signal=False)
if new:
host.set_up()
for listener in self.listeners:
listener.on_add(host)

def _populate_hosts(self):
self.profile_manager.populate(
weakref.proxy(self), self.metadata.all_hosts())
self.load_balancing_policy.populate(
@@ -1717,17 +1710,10 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self.contact_points, self.protocol_version)
self.connection_class.initialize_reactor()
_register_cluster_shutdown(self)

self._add_resolved_hosts()

try:
self.control_connection.connect()

# we set all contact points up for connecting, but we won't infer state after this
for endpoint in self.endpoints_resolved:
h = self.metadata.get_host(endpoint)
if h and self.profile_manager.distance(h) == HostDistance.IGNORED:
h.is_up = None
self._populate_hosts()

log.debug("Control connection created")
except Exception:
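For orientation, a minimal sketch of the reordered startup this hunk introduces, assuming `_populate_hosts` simply forwards metadata hosts to the policies (its full body is collapsed above); error handling and locking are elided:

    def connect_startup(cluster):
        # establish the control connection first; _connect_host() below
        # falls back to the raw contact points when no LBP is populated yet
        cluster.control_connection.connect()
        # only then populate the load-balancing policies, now from metadata
        # discovered over the control connection rather than from hosts
        # pre-registered out of the resolved contact points
        cluster._populate_hosts()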
@@ -3534,18 +3520,20 @@ def _set_new_connection(self, conn):
if old:
log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
old.close()
def _connect_host_in_lbp(self):

def _connect_host(self):
errors = {}

lbp = (
self._cluster.load_balancing_policy
if self._cluster._config_mode == _ConfigMode.LEGACY else
self._cluster._default_load_balancing_policy
)

# use endpoints from the LBP if it is already initialized
for host in lbp.make_query_plan():
try:
return (self._try_connect(host), None)
return (self._try_connect(host.endpoint), None)
except ConnectionException as exc:
errors[str(host.endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
@@ -3555,7 +3543,22 @@ def _connect_host_in_lbp(self):
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
if self._is_shutdown:
raise DriverException("[control connection] Reconnection in progress during shutdown")


# if the LBP is not initialized, fall back to the contact points provided to the cluster
if not errors:
for endpoint in self._cluster.endpoints_resolved:
try:
return (self._try_connect(endpoint), None)
except ConnectionException as exc:
errors[str(endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True)
self._cluster.signal_connection_failure(endpoint, exc, is_host_addition=False)
except Exception as exc:
errors[str(endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True)
if self._is_shutdown:
raise DriverException("[control connection] Reconnection in progress during shutdown")

return (None, errors)
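Reduced to its shape, `_connect_host` is a two-phase strategy: exhaust the LBP's query plan, and only if that plan produced no candidates at all (nothing was recorded in errors) fall back to the cluster's resolved contact points. A standalone sketch, with try_connect standing in for _try_connect:

    def connect_any(query_plan, contact_points, try_connect):
        errors = {}
        # phase 1: endpoints suggested by the load-balancing policy
        for host in query_plan:
            try:
                return try_connect(host.endpoint), None
            except Exception as exc:
                errors[str(host.endpoint)] = exc
        # phase 2: the plan was empty (LBP not yet populated), so try the
        # contact points handed to the Cluster directly
        if not errors:
            for endpoint in contact_points:
                try:
                    return try_connect(endpoint), None
                except Exception as exc:
                    errors[str(endpoint)] = exc
        return None, errors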

def _reconnect_internal(self):
@@ -3567,43 +3570,43 @@ def _reconnect_internal(self):
to the exception that was raised when an attempt was made to open
a connection to that host.
"""
(conn, _) = self._connect_host_in_lbp()
(conn, _) = self._connect_host()
if conn is not None:
return conn

# Try to re-resolve hostnames as a fallback when all hosts are unreachable
self._cluster._resolve_hostnames()

self._cluster._add_resolved_hosts()
self._cluster._populate_hosts()

(conn, errors) = self._connect_host_in_lbp()
(conn, errors) = self._connect_host()
if conn is not None:
return conn

raise NoHostAvailable("Unable to connect to any servers", errors)

def _try_connect(self, host):
def _try_connect(self, endpoint):
"""
Creates a new Connection, registers for pushed events, and refreshes
node/token and schema metadata.
"""
log.debug("[control connection] Opening new connection to %s", host)
log.debug("[control connection] Opening new connection to %s", endpoint)

while True:
try:
connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True)
connection = self._cluster.connection_factory(endpoint, is_control_connection=True)
if self._is_shutdown:
connection.close()
raise DriverException("Reconnecting during shutdown")
break
except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
self._cluster.protocol_downgrade(endpoint, e.startup_version)
except ProtocolException as e:
# protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
# protocol version. If the protocol version was not explicitly specified,
# and the server raises a beta protocol error, we should downgrade.
if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version)
else:
raise

@@ -3879,6 +3882,9 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,

self._cluster.metadata.update_host(host, old_endpoint=connection.endpoint)
connection.original_endpoint = connection.endpoint = host.endpoint
else:
log.info("Considering local host as a newly found host")
peers_result.append(local_row)
# Check metadata.partitioner to see if we haven't built anything yet. If
# every node in the cluster was in the contact points, we won't discover
# any new nodes, so we need this additional check. (See PYTHON-90)
@@ -4177,8 +4183,8 @@ def _get_peers_query(self, peers_query_type, connection=None):
query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version
host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version
host = self._cluster.metadata.get_host(connection.original_endpoint)
host_release_version = host.release_version if host is not None else None
host_dse_version = host.dse_version if host is not None else None
uses_native_address_query = (
host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)

3 changes: 3 additions & 0 deletions cassandra/metadata.py
@@ -139,6 +139,9 @@ def export_schema_as_string(self):
def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None,
metadata_request_timeout=None, **kwargs):

if not self.get_host(connection.original_endpoint):
return

server_version = self.get_host(connection.original_endpoint).release_version
dse_version = self.get_host(connection.original_endpoint).dse_version
parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size)
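The new guard makes `refresh()` a no-op when the control connection's endpoint is not yet present in metadata; previously `self.get_host(...).release_version` would raise AttributeError on the `None` return. In sketch form, under that assumption about get_host:

    def refresh_guard(metadata, connection):
        host = metadata.get_host(connection.original_endpoint)
        if host is None:
            return None  # unknown host: skip the refresh instead of crashing
        return host.release_version, host.dse_version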
3 changes: 3 additions & 0 deletions cassandra/policies.py
@@ -264,6 +264,9 @@ def populate(self, cluster, hosts):

def distance(self, host):
dc = self._dc(host)
if not self.local_dc:
self.local_dc = dc
return HostDistance.LOCAL
Comment on lines +267 to +269

Collaborator: Should not be in this PR

Collaborator: @sylwiaszunejko, what is the reason for having it here ?

if dc == self.local_dc:
return HostDistance.LOCAL

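The effect of the new branch (questioned in the review comments above): when no local_dc was configured, the first host whose distance is evaluated pins the policy's local datacenter. A rough illustration, assuming the patched driver; hosts are mocked for brevity:

    from unittest.mock import Mock
    from cassandra.policies import DCAwareRoundRobinPolicy, HostDistance

    policy = DCAwareRoundRobinPolicy()                 # no local_dc given
    h1, h2 = Mock(datacenter="dc1"), Mock(datacenter="dc2")
    assert policy.distance(h1) == HostDistance.LOCAL   # pins local_dc to "dc1"
    assert policy.local_dc == "dc1"
    assert policy.distance(h2) != HostDistance.LOCAL   # "dc2" is now remote/ignored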
2 changes: 1 addition & 1 deletion cassandra/pool.py
@@ -176,7 +176,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No
self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint)
self.conviction_policy = conviction_policy_factory(self)
if not host_id:
host_id = uuid.uuid4()
raise ValueError("host_id may not be None")
self.host_id = host_id
self.set_location_info(datacenter, rack)
self.lock = RLock()
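With this change a Host can no longer mint its own random id; callers must supply one, which is why the unit tests further down now pass host_id=uuid.uuid4(). A small sketch:

    import uuid
    from cassandra.pool import Host
    from cassandra.policies import SimpleConvictionPolicy

    try:
        Host("127.0.0.1", SimpleConvictionPolicy)      # host_id omitted
    except ValueError:
        pass                                           # raised by this change

    host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())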
36 changes: 24 additions & 12 deletions tests/integration/standard/test_cluster.py
@@ -900,8 +900,9 @@ def test_profile_lb_swap(self):
"""
Tests that profile load balancing policies are not shared

Creates two LBP, runs a few queries, and validates that each LBP is execised
seperately between EP's
Creates two LBP, runs a few queries, and validates that each LBP is exercised
separately between EP's. Each RoundRobinPolicy starts from its own random
position and maintains independent round-robin ordering.

@since 3.5
@jira_ticket PYTHON-569
@@ -916,17 +917,28 @@ def test_profile_lb_swap(self):
with TestCluster(execution_profiles=exec_profiles) as cluster:
session = cluster.connect(wait_for_all_pools=True)

# default is DCA RR for all hosts
expected_hosts = set(cluster.metadata.all_hosts())
rr1_queried_hosts = set()
rr2_queried_hosts = set()

rs = session.execute(query, execution_profile='rr1')
rr1_queried_hosts.add(rs.response_future._current_host)
rs = session.execute(query, execution_profile='rr2')
rr2_queried_hosts.add(rs.response_future._current_host)

assert rr2_queried_hosts == rr1_queried_hosts
num_hosts = len(expected_hosts)
assert num_hosts > 1, "Need at least 2 hosts for this test"

rr1_queried_hosts = []
rr2_queried_hosts = []

for _ in range(num_hosts * 2):
rs = session.execute(query, execution_profile='rr1')
rr1_queried_hosts.append(rs.response_future._current_host)
rs = session.execute(query, execution_profile='rr2')
rr2_queried_hosts.append(rs.response_future._current_host)

# Both policies should have queried all hosts
assert set(rr1_queried_hosts) == expected_hosts
assert set(rr2_queried_hosts) == expected_hosts

# The order of hosts should demonstrate round-robin behavior
# After num_hosts queries, the pattern should repeat
for i in range(num_hosts):
assert rr1_queried_hosts[i] == rr1_queried_hosts[i + num_hosts]
assert rr2_queried_hosts[i] == rr2_queried_hosts[i + num_hosts]
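The repeat assertion relies on round-robin periodicity: query i and query i + num_hosts must land on the same host. As a tiny standalone model of that invariant:

    hosts = ["a", "b", "c"]
    order = [hosts[i % len(hosts)] for i in range(2 * len(hosts))]
    for i in range(len(hosts)):
        assert order[i] == order[i + len(hosts)]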

def test_ta_lbp(self):
"""
8 changes: 6 additions & 2 deletions tests/integration/standard/test_control_connection.py
@@ -101,8 +101,12 @@ def test_get_control_connection_host(self):

# reconnect and make sure that the new host is reflected correctly
self.cluster.control_connection._reconnect()
new_host = self.cluster.get_control_connection_host()
assert host != new_host
new_host1 = self.cluster.get_control_connection_host()

self.cluster.control_connection._reconnect()
new_host2 = self.cluster.get_control_connection_host()

assert new_host1 != new_host2

# TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed
@unittest.skip('Fails on scylla due to the broadcast_rpc_port is None')
4 changes: 2 additions & 2 deletions tests/integration/standard/test_metrics.py
@@ -218,7 +218,7 @@ def test_metrics_per_cluster(self):
try:
# Test write
query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL)
with pytest.raises(WriteTimeout):
with pytest.raises((WriteTimeout, Unavailable)):
self.session.execute(query, timeout=None)
finally:
get_node(1).resume()
@@ -230,7 +230,7 @@ def test_metrics_per_cluster(self):
stats_cluster2 = cluster2.metrics.get_stats()

# Test direct access to stats
assert 1 == self.cluster.metrics.stats.write_timeouts
assert (1 == self.cluster.metrics.stats.write_timeouts or 1 == self.cluster.metrics.stats.unavailables)
assert 0 == cluster2.metrics.stats.write_timeouts

# Test direct access to a child stats
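Both loosened assertions reflect the same reality: pausing the node can surface as either a write timeout or an unavailable error. pytest.raises accepts a tuple of exception types for exactly this case; a sketch, where flaky_write and the Unavailable constructor arguments are illustrative stand-ins:

    import pytest
    from cassandra import Unavailable, WriteTimeout

    def flaky_write():
        # stand-in for the INSERT executed against a paused node
        raise Unavailable("replicas down", consistency=1,
                          required_replicas=2, alive_replicas=1)

    with pytest.raises((WriteTimeout, Unavailable)):
        flaky_write()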
8 changes: 3 additions & 5 deletions tests/integration/standard/test_policies.py
@@ -45,9 +45,6 @@ def test_predicate_changes(self):
external_event = True
contact_point = DefaultEndPoint("127.0.0.1")

single_host = {Host(contact_point, SimpleConvictionPolicy)}
all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)}

predicate = lambda host: host.endpoint == contact_point if external_event else True
hfp = ExecutionProfile(
load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate)
@@ -62,7 +59,8 @@ def test_predicate_changes(self):
response = session.execute("SELECT * from system.local WHERE key='local'")
queried_hosts.update(response.response_future.attempted_hosts)

assert queried_hosts == single_host
assert len(queried_hosts) == 1
assert queried_hosts.pop().endpoint == contact_point

external_event = False
futures = session.update_created_pools()
@@ -72,7 +70,7 @@ def test_predicate_changes(self):
for _ in range(10):
response = session.execute("SELECT * from system.local WHERE key='local'")
queried_hosts.update(response.response_future.attempted_hosts)
assert queried_hosts == all_hosts
assert len(queried_hosts) == 3


class WhiteListRoundRobinPolicyTests(unittest.TestCase):
3 changes: 2 additions & 1 deletion tests/integration/standard/test_query.py
@@ -460,7 +460,8 @@ def make_query_plan(self, working_keyspace=None, query=None):
live_hosts = sorted(list(self._live_hosts))
host = []
try:
host = [live_hosts[self.host_index_to_use]]
if len(live_hosts) > self.host_index_to_use:
host = [live_hosts[self.host_index_to_use]]
except IndexError as e:
raise IndexError(
'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format(
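The added length check means the custom policy now yields an empty plan, rather than relying on the except clause, when fewer hosts are live than the requested index. The guard in isolation:

    live_hosts = ["h1", "h2"]
    host_index_to_use = 5
    plan = [live_hosts[host_index_to_use]] if len(live_hosts) > host_index_to_use else []
    assert plan == []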
5 changes: 3 additions & 2 deletions tests/unit/advanced/test_policies.py
@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
from unittest.mock import Mock
import uuid

from cassandra.pool import Host
from cassandra.policies import RoundRobinPolicy
@@ -72,7 +73,7 @@ def test_target_no_host(self):

def test_target_host_down(self):
node_count = 4
hosts = [Host(i, Mock()) for i in range(node_count)]
hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)]
target_host = hosts[1]

policy = DSELoadBalancingPolicy(RoundRobinPolicy())
@@ -87,7 +88,7 @@ def test_target_host_down(self):

def test_target_host_nominal(self):
node_count = 4
hosts = [Host(i, Mock()) for i in range(node_count)]
hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)]
target_host = hosts[1]
target_host.is_up = True
