Skip to content

Commit

Permalink
#151: Add unit tests for resolving host names
Browse files Browse the repository at this point in the history
  • Loading branch information
kaklakariada committed Sep 2, 2024
1 parent 282a362 commit 5e0ae80
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
27 changes: 18 additions & 9 deletions pyexasol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from . import callback as cb

from typing import (NamedTuple, Optional)
from .exceptions import *
from .statement import ExaStatement
from .logger import ExaLogger
Expand All @@ -27,6 +28,13 @@
from .version import __version__


class ResolvedHost(NamedTuple):
    """
    One connection target obtained by resolving a host name from the DSN.

    As a NamedTuple it still unpacks like, and compares equal to, a plain
    (hostname, ip_address, port, fingerprint) tuple.
    """
    # Host name as reported by the resolver for the DSN entry
    hostname: str
    # One of the IP addresses the host name resolved to
    ip_address: str
    # Port number to connect to
    port: int
    # Fingerprint taken from the DSN, or None when the DSN does not carry one
    fingerprint: Optional[str]

class ExaConnection(object):
cls_statement = ExaStatement
cls_formatter = ExaFormatter
Expand Down Expand Up @@ -671,7 +679,7 @@ def _init_ws(self):
failed_attempts += 1

if failed_attempts == len(dsn_items):
raise ExaConnectionFailedError(self, 'Could not connect to Exasol: ' + str(e))
raise ExaConnectionFailedError(self, 'Could not connect to Exasol: ' + str(e)) from e
else:
self._ws.settimeout(self.options['socket_timeout'])

Expand Down Expand Up @@ -729,13 +737,13 @@ def _get_login_attributes(self):

return attributes

def _process_dsn(self, dsn):
def _process_dsn(self, dsn: str, shuffle_host_names: bool=True) -> list[ResolvedHost]:
"""
Parse DSN, expand ranges and resolve IP addresses for all hostnames
Return list of ResolvedHost (hostname, ip_address, port, fingerprint) tuples, in random order unless shuffle_host_names is False
Randomness is required to guarantee proper distribution of workload across all nodes
"""
if len(dsn.strip()) == 0:
if dsn is None or len(dsn.strip()) == 0:
raise ExaConnectionDsnError(self, 'Connection string is empty')

current_port = constant.DEFAULT_PORT
Expand Down Expand Up @@ -789,22 +797,23 @@ def _process_dsn(self, dsn):
else:
result.extend(self._resolve_hostname(m.group('hostname_prefix'), current_port, current_fingerprint))

random.shuffle(result)
if shuffle_host_names:
random.shuffle(result)

return result

def _resolve_hostname(self, hostname, port, fingerprint):
def _resolve_hostname(self, hostname: str, port: int, fingerprint: Optional[str]) -> list[ResolvedHost]:
"""
Resolve all IP addresses for hostname and add port
It also implicitly checks that all hostnames mentioned in DSN can be resolved
"""
try:
hostname, alias_list, ipaddr_list = socket.gethostbyname_ex(hostname)
except OSError:
hostname, _, ipaddr_list = socket.gethostbyname_ex(hostname)
except OSError as ex:
raise ExaConnectionDsnError(self, f'Could not resolve IP address of hostname [{hostname}] '
f'derived from connection string')
f'derived from connection string') from ex

return [(hostname, ipaddr, port, fingerprint) for ipaddr in ipaddr_list]
return [ResolvedHost(hostname, ipaddr, port, fingerprint) for ipaddr in ipaddr_list]

def _validate_fingerprint(self, provided_fingerprint):
server_fingerprint = hashlib.sha256(self._ws.sock.getpeercert(True)).hexdigest().upper()
Expand Down
57 changes: 57 additions & 0 deletions test/integration/connection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest

from unittest import mock

from pyexasol.exceptions import ExaConnectionDsnError
from pyexasol.connection import ResolvedHost

# pylint: disable=protected-access

def test_resolve_hostname(connection):
    """_resolve_hostname yields one ResolvedHost per IP address the resolver returns."""
    with mock.patch("socket.gethostbyname_ex") as get_hostname:
        get_hostname.return_value = ("host", [], ["ip1", "ip2"])
        actual = connection._resolve_hostname("host", 1234, "fingerprint")

    expected = [
        ResolvedHost("host", "ip1", 1234, "fingerprint"),
        ResolvedHost("host", "ip2", 1234, "fingerprint"),
    ]
    assert actual == expected
    # A NamedTuple compares equal to a plain tuple, so equality alone would not
    # notice a regression to untyped tuples; pin the element type explicitly.
    assert all(isinstance(host, ResolvedHost) for host in actual)
    # The host name must be handed to the resolver exactly once.
    get_hostname.assert_called_once_with("host")


@pytest.mark.parametrize("empty_dsn", [None, "", " ", "\t"])
def test_process_empty_dsn_fails(connection, empty_dsn):
    """A missing or whitespace-only DSN is rejected with ExaConnectionDsnError."""
    expected_message = "Connection string is empty"
    with pytest.raises(ExaConnectionDsnError, match=expected_message):
        connection._process_dsn(empty_dsn)

def test_process_dsn_shuffles_hosts(connection):
    """By default _process_dsn returns the resolved hosts in a random order."""
    dsn = "host1:1234,host2:4321"

    def process_dsn_once():
        # Each call installs a fresh mock, so side_effect only needs to cover
        # the two host name lookups performed by a single _process_dsn() run.
        with mock.patch("socket.gethostbyname_ex") as get_hostname:
            get_hostname.side_effect = [
                ("host1", [], ["ip11", "ip12"]),
                ("host2", [], ["ip21", "ip22"]),
            ]
            # Convert to a tuple so each ordering is hashable and can be
            # collected in a set.
            return tuple(connection._process_dsn(dsn))

    count = 100
    results = {process_dsn_once() for _ in range(count)}
    # More than one distinct ordering across the runs proves shuffling happens.
    assert len(results) > 1

def test_process_dsn_without_shuffling(connection):
    """With shuffling disabled the hosts come back in a fixed order, all on port 1234."""
    resolver_answers = [
        ("host1", [], ["ip11", "ip12"]),
        ("host2", [], ["ip21", "ip22"]),
    ]
    with mock.patch("socket.gethostbyname_ex", side_effect=resolver_answers):
        resolved = connection._process_dsn("host1,host2:1234", shuffle_host_names=False)

    host_ip_pairs = (
        ("host1", "ip11"),
        ("host1", "ip12"),
        ("host2", "ip21"),
        ("host2", "ip22"),
    )
    assert resolved == [ResolvedHost(host, ip, 1234, None) for host, ip in host_ip_pairs]

def test_process_dsn_without_port(connection):
    """A DSN that names no port resolves to port 8563."""
    resolver_answers = [("host1", [], ["ip1"])]
    with mock.patch("socket.gethostbyname_ex", side_effect=resolver_answers):
        resolved = connection._process_dsn("host1")

    assert resolved == [ResolvedHost("host1", "ip1", 8563, None)]

def test_process_dsn_with_fingerprint(connection):
    """A lower-case fingerprint in the DSN is returned upper-cased."""
    dsn = "host1/135a1d2dce102de866f58267521f4232153545a075dc85f8f7596f57e588a181:1234"
    upper_fingerprint = "135A1D2DCE102DE866F58267521F4232153545A075DC85F8F7596F57E588A181"
    resolver_answers = [("host1", [], ["ip1"])]

    with mock.patch("socket.gethostbyname_ex", side_effect=resolver_answers):
        resolved = connection._process_dsn(dsn)

    assert resolved == [ResolvedHost("host1", "ip1", 1234, upper_fingerprint)]

0 comments on commit 5e0ae80

Please sign in to comment.