Skip to content

Commit

Permalink
feat(allow_hosts): add support for hostname resolution (#189)
Browse files Browse the repository at this point in the history
Co-authored-by: jguer <me@jguer.space>
  • Loading branch information
cabarnes and Jguer committed Feb 2, 2023
1 parent 7f8c9df commit 268a5d0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
41 changes: 40 additions & 1 deletion pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ipaddress
import socket

import pytest
Expand Down Expand Up @@ -170,16 +171,54 @@ def host_from_connect_args(args):
return host_from_address(address)


def is_ipaddress(address: str) -> bool:
"""
Determine if the address is a valid IPv4 or IPv6 address.
"""
try:
ipaddress.ip_address(address)
return True
except ValueError:
return False


def resolve_hostname(hostname):
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None


def normalize_allowed_hosts(allowed_hosts):
"""Convert all items in `allowed_hosts` to an IP address."""
ip_hosts = []
for host in allowed_hosts:
host = host.strip()
if is_ipaddress(host):
ip_hosts.append(host)
else:
resolved = resolve_hostname(host)
if resolved:
ip_hosts.append(resolved)

return ip_hosts


def socket_allow_hosts(allowed=None, allow_unix_socket=False):
"""disable socket.socket.connect() to disable the Internet. useful in testing."""
if isinstance(allowed, str):
allowed = allowed.split(",")

if not isinstance(allowed, list):
return

allowed_hosts = normalize_allowed_hosts(allowed)

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
if host in allowed or (_is_unix_socket(inst.family) and allow_unix_socket):
if host in allowed_hosts or (
_is_unix_socket(inst.family) and allow_unix_socket
):
return _true_connect(inst, *args)

raise SocketConnectBlockedError(allowed, host)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ def test_single_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost)


def test_single_cli_arg_connect_enabled_hostname_resolved(assert_connect):
assert_connect(True, cli_arg="localhost")


def test_single_cli_arg_connect_enabled_hostname_unresolvable(assert_connect):
assert_connect(False, cli_arg="unresolvable")


def test_single_cli_arg_connect_unicode_enabled(assert_connect):
assert_connect(True, cli_arg=localhost, code_template=connect_unicode_code_template)

Expand Down

0 comments on commit 268a5d0

Please sign in to comment.