diff --git a/test/test_client.py b/test/test_client.py index 7ad6a4b..e4c62fb 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -104,6 +104,15 @@ def test_parse_wss_scheme_with_query_string(self): self.assertEqual(c.resource, "/?token=value") self.assertEqual(c.bind_addr, ("127.0.0.1", 443)) + def test_overriding_host_from_headers(self): + c = WebSocketBaseClient(url="wss://127.0.0.1", headers=[("Host", "example123.com")]) + self.assertEqual(c.host, "127.0.0.1") + self.assertEqual(c.port, 443) + self.assertEqual(c.bind_addr, ("127.0.0.1", 443)) + for h in c.handshake_headers: + if h[0].lower() == "host": + self.assertEqual(h[1], "example123.com") + @patch('ws4py.client.socket') def test_connect_and_close(self, sock): diff --git a/ws4py/client/__init__.py b/ws4py/client/__init__.py index 411638f..32d28a8 100644 --- a/ws4py/client/__init__.py +++ b/ws4py/client/__init__.py @@ -253,7 +253,6 @@ def handshake_headers(self): handshake. """ headers = [ - ('Host', '%s:%s' % (self.host, self.port)), ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), ('Sec-WebSocket-Key', self.key.decode('utf-8')), @@ -266,6 +265,12 @@ def handshake_headers(self): if self.extra_headers: headers.extend(self.extra_headers) + # keep old logic if no overriding Host in headers + if not any(x for x in headers if x[0].lower() == 'host') and \ + 'host' not in self.exclude_headers: + headers.append(('Host', '%s:%s' % (self.host, self.port))) + + if not any(x for x in headers if x[0].lower() == 'origin') and \ 'origin' not in self.exclude_headers: