Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding back Rakshith's websocket changes #24410

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@

### Features Added

### Breaking Changes

### Bugs Fixed

### Other Changes
- Added support for connection using websocket and http proxy.

## 5.8.0a3 (2022-03-08)

Expand Down
11 changes: 8 additions & 3 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,6 @@ def _create_auth(self):
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
timeout=self._config.auth_timeout,
http_proxy=self._config.http_proxy,
transport_type=self._config.transport_type,
custom_endpoint_hostname=self._config.custom_endpoint_hostname,
port=self._config.connection_port,
verify=self._config.connection_verify,
Expand Down Expand Up @@ -379,8 +377,15 @@ def _management_request(self, mgmt_msg, op_type):
last_exception = None
while retried_times <= self._config.max_retries:
mgmt_auth = self._create_auth()
hostname = self._address.hostname
if self._config.transport_type.name == 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'
mgmt_client = AMQPClient(
self._address.hostname, auth=mgmt_auth, debug=self._config.network_tracing
hostname,
auth=mgmt_auth,
debug=self._config.network_tracing,
transport_type=self._config.transport_type,
http_proxy=self._config.http_proxy
)
try:
mgmt_client.open()
Expand Down
11 changes: 9 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,13 @@ def _create_handler(self, auth):
)
desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None

transport_type = self._client._config.transport_type # pylint:disable=protected-access
hostname = urlparse(source.address).hostname
if transport_type.name == 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'

self._handler = ReceiveClient(
urlparse(source.address).hostname,
hostname,
source,
auth=auth,
idle_timeout=self._idle_timeout,
Expand All @@ -164,7 +169,9 @@ def _create_handler(self, auth):
properties=create_properties(self._client._config.user_agent), # pylint:disable=protected-access
desired_capabilities=desired_capabilities,
streaming_receive=True,
message_received_callback=self._message_received
message_received_callback=self._message_received,
transport_type=transport_type,
http_proxy=self._client._config.http_proxy # pylint:disable=protected-access
)

def _open_with_retry(self):
Expand Down
8 changes: 7 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ def __init__(self, client, target, **kwargs):

def _create_handler(self, auth):
# type: (JWTTokenAuth) -> None
transport_type = self._client._config.transport_type # pylint:disable=protected-access
hostname = self._client._address.hostname # pylint: disable=protected-access
if transport_type.name == 'AmqpOverWebsocket':
hostname += '/$servicebus/websocket/'
self._handler = SendClient(
self._client._address.hostname, # pylint: disable=protected-access
hostname, # pylint: disable=protected-access
self._target,
auth=auth,
idle_timeout=self._idle_timeout,
Expand All @@ -136,6 +140,8 @@ def _create_handler(self, auth):
client_name=self._name,
link_properties=self._link_properties,
properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access
transport_type=transport_type,
http_proxy=self._client._config.http_proxy # pylint: disable=protected-access
)

def _open_with_retry(self):
Expand Down
23 changes: 18 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssl import SSLError

from ._transport import Transport
from .sasl import SASLTransport
from .sasl import SASLTransport, SASLWithWebSocket
from .session import Session
from .performatives import OpenFrame, CloseFrame
from .constants import (
Expand All @@ -22,7 +22,8 @@
MAX_FRAME_SIZE_BYTES,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from .error import (
Expand Down Expand Up @@ -77,12 +78,19 @@ class Connection(object):
Default value is `0.1`.
:keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames
will be logged at the logging.INFO level.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
# type(str, Any) -> None
parsed_url = urlparse(endpoint)
self._hostname = parsed_url.hostname
endpoint = self._hostname
if parsed_url.port:
self._port = parsed_url.port
elif parsed_url.scheme == 'amqps':
Expand All @@ -92,16 +100,21 @@ def __init__(self, endpoint, **kwargs):
self.state = None # type: Optional[ConnectionState]

transport = kwargs.get('transport')
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if transport:
self._transport = transport
elif 'sasl_credential' in kwargs:
self._transport = SASLTransport(
host=parsed_url.netloc,
sasl_transport = SASLTransport
if self._transport_type.name == 'AmqpOverWebsocket' or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
credential=kwargs['sasl_credential'],
**kwargs
)
else:
self._transport = Transport(parsed_url.netloc, **kwargs)
self._transport = Transport(parsed_url.netloc, transport_type=self._transport_type, **kwargs)

self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str
self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int
Expand Down
80 changes: 75 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack
from ._encode import encode_frame
from ._decode import decode_frame, decode_empty_frame
from .constants import TLS_HEADER_FRAME
from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL


try:
Expand Down Expand Up @@ -439,7 +439,7 @@ def write(self, s):

def receive_frame(self, *args, **kwargs):
try:
header, channel, payload = self.read(**kwargs)
header, channel, payload = self.read(**kwargs)
if not payload:
decoded = decode_empty_frame(header)
else:
Expand Down Expand Up @@ -646,12 +646,82 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result


def Transport(host, connect_timeout=None, ssl=False, **kwargs):
def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs):
"""Create transport.

Given a few parameters from the Connection constructor,
select and create a subclass of _AbstractTransport.
"""
transport = SSLTransport if ssl else TCPTransport
if transport_type == TransportType.AmqpOverWebsocket:
transport = WebSocketTransport
else:
transport = SSLTransport if ssl else TCPTransport
return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class WebSocketTransport(_AbstractTransport):
def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs):
self.sslopts = ssl if isinstance(ssl, dict) else {}
self._connect_timeout = connect_timeout
self._host = host
super().__init__(
host, port, connect_timeout, **kwargs
)
self.ws = None
self._http_proxy = kwargs.get('http_proxy', None)

def connect(self):
http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None
if self._http_proxy:
http_proxy_host = self._http_proxy['proxy_hostname']
http_proxy_port = self._http_proxy['proxy_port']
username = self._http_proxy.get('username', None)
password = self._http_proxy.get('password', None)
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
self.ws = create_connection(
url="wss://{}".format(self._host),
subprotocols=[AMQP_WS_SUBPROTOCOL],
timeout=self._connect_timeout,
skip_utf8_validation=True,
sslopt=self.sslopts,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth
)
except ImportError:
raise ValueError("Please install websocket-client library to use websocket transport.")

def _read(self, n, initial=False, buffer=None, **kwargs): # pylint: disable=unused-arguments
"""Read exactly n bytes from the peer."""

length = 0
view = buffer or memoryview(bytearray(n))
nbytes = self._read_buffer.readinto(view)
length += nbytes
n -= nbytes
while n:
data = self.ws.recv()

if len(data) <= n:
view[length: length + len(data)] = data
n -= len(data)
else:
view[length: length + n] = data[0:n]
self._read_buffer = BytesIO(data[n:])
n = 0

return view

def _shutdown_transport(self):
"""Do any preliminary work in shutting down the connection."""
self.ws.close()

def _write(self, s):
"""Completely write a string to the peer.
ABNF, OPCODE_BINARY = 0x2
See http://tools.ietf.org/html/rfc5234
http://tools.ietf.org/html/rfc6455#section-5.2
"""
self.ws.send_binary(s)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ._receiver_async import ReceiverLink
from ._sender_async import SenderLink
from ._session_async import Session
from ._sasl_async import SASLTransport
from ._cbs_async import CBSAuthenticator
from ..client import AMQPClient as AMQPClientSync
from ..client import ReceiveClient as ReceiveClientSync
Expand Down Expand Up @@ -201,7 +200,9 @@ async def open_async(self):
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
properties=self._properties,
network_trace=self._network_trace
network_trace=self._network_trace,
transport_type=self._transport_type,
http_proxy=self._http_proxy
)
await self._connection.open()
if not self._session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio

from ._transport_async import AsyncTransport
from ._sasl_async import SASLTransport
from ._sasl_async import SASLTransport, SASLWithWebSocket
from ._session_async import Session
from ..performatives import OpenFrame, CloseFrame
from .._connection import get_local_timeout
Expand All @@ -27,7 +27,8 @@
MAX_CHANNELS,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from ..error import (
Expand Down Expand Up @@ -58,11 +59,19 @@ class Connection(object):
:param list(str) offered_capabilities: The extension capabilities the sender supports.
:param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports
:param dict properties: Connection properties.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
parsed_url = urlparse(endpoint)
self.hostname = parsed_url.hostname
endpoint = self.hostname
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
if parsed_url.port:
self.port = parsed_url.port
elif parsed_url.scheme == 'amqps':
Expand All @@ -75,8 +84,12 @@ def __init__(self, endpoint, **kwargs):
if transport:
self.transport = transport
elif 'sasl_credential' in kwargs:
self.transport = SASLTransport(
host=parsed_url.netloc,
sasl_transport = SASLTransport
if self._transport_type.name == "AmqpOverWebsocket" or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self.transport = sasl_transport(
host=endpoint,
credential=kwargs['sasl_credential'],
**kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import struct
from enum import Enum

from ._transport_async import AsyncTransport
from ._transport_async import AsyncTransport, WebSocketTransportAsync
from ..types import AMQPTypes, TYPE, VALUE
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType
from .._transport import AMQPS_PORT
from ..performatives import (
SASLOutcome,
Expand Down Expand Up @@ -73,14 +73,8 @@ def start(self):
return b''


class SASLTransport(AsyncTransport):

def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

async def negotiate(self):
class SASLTransportMixinAsync():
async def _negotiate(self):
await self.write(SASL_HEADER_FRAME)
_, returned_header = await self.receive_frame()
if returned_header[1] != SASL_HEADER_FRAME:
Expand All @@ -104,3 +98,35 @@ async def negotiate(self):
return
else:
raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields))


class SASLTransport(AsyncTransport, SASLTransportMixinAsync):

def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

async def negotiate(self):
await self._negotiate()


class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync):
def __init__(
self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
):
self.credential = credential
ssl = ssl or True
http_proxy = kwargs.pop('http_proxy', None)
self._transport = WebSocketTransportAsync(
host,
port=port,
connect_timeout=connect_timeout,
ssl=ssl,
http_proxy=http_proxy,
**kwargs
)
super().__init__(host, port, connect_timeout, ssl, **kwargs)

async def negotiate(self):
await self._negotiate()
Loading