From de96d1755911a866cf66570b7953078b8fcf43da Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Fri, 29 Apr 2022 14:08:18 -0700 Subject: [PATCH 01/21] remove uamqp from EventData/EventDataBatch --- .../azure-eventhub/azure/eventhub/_common.py | 52 ++++---- .../azure/eventhub/_producer_client.py | 12 +- .../_transport/transport_message_base.py | 32 +++++ .../_transport/uamqp_plugins/constants.py | 17 +++ .../_transport/uamqp_plugins/message.py | 35 ++++++ .../_transport/uamqp_plugins/utils.py | 61 +++++++++ .../azure-eventhub/azure/eventhub/_utils.py | 17 +-- .../eventhub/aio/_producer_client_async.py | 4 +- .../azure/eventhub/amqp/_amqp_message.py | 116 +++++++----------- .../azure/eventhub/amqp/_amqp_utils.py | 27 ++++ .../tests/unittest/test_event_data.py | 5 +- 11 files changed, 269 insertions(+), 109 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 5806b67caf8d..1f17cb91c71b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -17,13 +17,12 @@ List, TYPE_CHECKING, cast, + Callable ) from typing_extensions import TypedDict import six -from uamqp import BatchMessage, Message, constants - from ._utils import ( set_message_partition_key, trace_message, @@ -125,7 +124,7 @@ def __init__( self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore data_body=body, annotations={}, application_properties={} ) - self.message = (self._raw_amqp_message._message) # pylint:disable=protected-access + self.message = None # amqp message to be set right before sending self._raw_amqp_message.header = AmqpMessageHeader() self._raw_amqp_message.properties = AmqpMessageProperties() self.message_id = None @@ -224,15 +223,15 @@ def _from_message(cls, message, raw_amqp_message=None): event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else AmqpAnnotatedMessage(message=message) return event_data - def _encode_message(self): - # type: () -> bytes - # pylint: disable=protected-access - return self._raw_amqp_message._message.encode_message() + #def _encode_message(self): + # # type: () -> bytes + # # pylint: disable=protected-access + # return self._raw_amqp_message._message.encode_message() def _decode_non_data_body_as_str(self, encoding="UTF-8"): # type: (str) -> str # pylint: disable=protected-access - body = self.raw_amqp_message._message._body + body = self.raw_amqp_message.body if self.body_type == AmqpMessageBodyType.VALUE: if not body.data: return "" @@ -241,11 +240,6 @@ def _decode_non_data_body_as_str(self, encoding="UTF-8"): seq_list = [d for seq_section in body.data for d in seq_section] return str(decode_with_recurse(seq_list, encoding)) - def _to_outgoing_message(self): - # type: () -> EventData - self.message = (self._raw_amqp_message._to_outgoing_amqp_message()) # pylint:disable=protected-access - return self - @property def raw_amqp_message(self): # type: () -> AmqpAnnotatedMessage @@ -513,8 +507,22 @@ class EventDataBatch(object): Event Hub decided by the service. """ - def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None): - # type: (Optional[int], Optional[str], Optional[Union[str, bytes]]) -> None + def __init__( + self, + max_size_in_bytes: Optional [int] = None, + partition_id: Optional[str] = None, + partition_key: Optional[Union[str, bytes]] = None, + **kwargs + ) -> None: + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + if self._uamqp_transport: + from ._transport.uamqp_plugins import utils as transport_utils, constants as transport_constants + self._transport_utils = transport_utils + self._transport_constants = transport_constants + self._to_outgoing_amqp_message = self._transport_utils.to_outgoing_amqp_message + else: + # TODO: add pyamqp support, will have to switch default to pyamqp + pass if partition_key and not isinstance( partition_key, (six.text_type, six.binary_type) @@ -526,8 +534,8 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None "partition_key to only be string type, they might fail to parse the non-string value." ) - self.max_size_in_bytes = max_size_in_bytes or constants.MAX_MESSAGE_LENGTH_BYTES - self.message = BatchMessage(data=[], multi_messages=False, properties=None) + self.max_size_in_bytes = max_size_in_bytes or self._transport_constants.MAX_MESSAGE_LENGTH_BYTES + self.message = self._transport_constants.BATCH_MESSAGE(data=[]) self._partition_id = partition_id self._partition_key = partition_key @@ -546,9 +554,11 @@ def __len__(self): return self._count @classmethod - def _from_batch(cls, batch_data, partition_key=None): - # type: (Iterable[EventData], Optional[AnyStr]) -> EventDataBatch - outgoing_batch_data = [transform_outbound_single_message(m, EventData) for m in batch_data] + def _from_batch(cls, batch_data, to_outgoing_amqp_message, partition_key=None): + # type: (Iterable[EventData], Callable, Optional[AnyStr]) -> EventDataBatch + outgoing_batch_data = [ + transform_outbound_single_message(m, EventData, to_outgoing_amqp_message) for m in batch_data + ] batch_data_instance = cls(partition_key=partition_key) batch_data_instance.message._body_gen = ( # pylint:disable=protected-access outgoing_batch_data @@ -589,7 +599,7 @@ def add(self, event_data): :raise: :class:`ValueError`, when exceeding the size limit. """ - outgoing_event_data = transform_outbound_single_message(event_data, EventData) + outgoing_event_data = transform_outbound_single_message(event_data, EventData, self._to_outgoing_amqp_message) if self._partition_key: if ( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index b8bf46cd3733..ec44acd9b806 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -92,6 +92,14 @@ def __init__( **kwargs # type: Any ): # type:(...) -> None + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + if self._uamqp_transport: + from ._transport import uamqp_plugins as transport_plugins + self._transport_plugins = transport_plugins + else: + # TODO: add pyamqp support, will have to switch if/else default + pass + super(EventHubProducerClient, self).__init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, @@ -119,7 +127,7 @@ def _get_partitions(self): for p_id in cast(List[str], self._partition_ids): self._producers[p_id] = None - def _get_max_mesage_size(self): + def _get_max_message_size(self): # type: () -> None # pylint: disable=protected-access,line-too-long with self._lock: @@ -345,7 +353,7 @@ def create_batch(self, **kwargs): """ if not self._max_message_size_on_link: - self._get_max_mesage_size() + self._get_max_message_size() max_size_in_bytes = kwargs.get("max_size_in_bytes", None) partition_id = kwargs.get("partition_id", None) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py new file mode 100644 index 000000000000..7d9d5e661cf1 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from abc import abstractmethod + +class TransportMessageBase: + """ + Abstract class that acts as a wrapper for the transport Message class. + """ + @property + @abstractmethod + def body_type(self): + """The body type of the underlying AMQP message. + + :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType + """ + + @property + @abstractmethod + def body(self): + """The body of the Message. The format may vary depending on the body type: + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, + the body could be bytes or Iterable[bytes]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, + the body could be List or Iterable[List]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, + the body could be any type. + + :rtype: Any + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py new file mode 100644 index 000000000000..6505becaf384 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py @@ -0,0 +1,17 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +from uamqp import BatchMessage, constants, MessageBodyType +from ...amqp._constants import AmqpMessageBodyType + +AMQP_MESSAGE_BODY_TYPE_MAP = { + MessageBodyType.Data.value: AmqpMessageBodyType.DATA, + MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, + MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, +} + +BATCH_MESSAGE = BatchMessage +MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py new file mode 100644 index 000000000000..56807de7d5c7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py @@ -0,0 +1,35 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from uamqp import Message as UamqpMessage +from .constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType +from ..transport_message_base import TransportMessageBase + +class TransportMessage(TransportMessageBase, UamqpMessage): + + @property + def body_type(self): + # type: () -> AmqpMessageBodyType + """The body type of the underlying AMQP message. + + :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType + """ + return AMQP_MESSAGE_BODY_TYPE_MAP.get( + self._body.type, AmqpMessageBodyType.VALUE # pylint: disable=protected-access + ) + + @property + def body(self): + """The body of the Message. The format may vary depending on the body type: + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, + the body could be bytes or Iterable[bytes]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, + the body could be List or Iterable[List]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, + the body could be any type. + + :rtype: Any + """ + return self.get_data() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py new file mode 100644 index 000000000000..376941757ca2 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from .constants import AmqpMessageBodyType +import uamqp + +def to_outgoing_amqp_message(annotated_message): + message_header = None + if annotated_message.header: + message_header = uamqp.message.MessageHeader() + message_header.delivery_count = annotated_message.header.delivery_count + message_header.time_to_live = annotated_message.header.time_to_live + message_header.first_acquirer = annotated_message.header.first_acquirer + message_header.durable = annotated_message.header.durable + message_header.priority = annotated_message.header.priority + + message_properties = None + if annotated_message.properties: + message_properties = uamqp.message.MessageProperties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time else None, + absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) + if annotated_message.properties.absolute_expiry_time else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + encoding=annotated_message._encoding # pylint: disable=protected-access + ) + + amqp_body_type = annotated_message.body_type # pylint: disable=protected-access + amqp_body = annotated_message.body + if amqp_body_type == AmqpMessageBodyType.DATA: + amqp_body_type = uamqp.MessageBodyType.Data + amqp_body = list(amqp_body) + elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: + amqp_body_type = uamqp.MessageBodyType.Sequence + amqp_body = list(amqp_body) + else: + # amqp_body_type is type of AmqpMessageBodyType.VALUE + amqp_body_type = uamqp.MessageBodyType.Value + + return uamqp.message.Message( + body=amqp_body, + body_type=amqp_body_type, + header=message_header, + properties=message_properties, + application_properties=annotated_message.application_properties, + annotations=annotated_message.annotations, + delivery_annotations=annotated_message.delivery_annotations, + footer=annotated_message.footer + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index cf913f4e3335..41011b691db5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -10,7 +10,7 @@ import datetime import calendar import logging -from typing import TYPE_CHECKING, Type, Optional, Dict, Union, Any, Iterable, Tuple, Mapping +from typing import TYPE_CHECKING, Type, Optional, Dict, Union, Any, Iterable, Tuple, Mapping, Callable import six @@ -271,8 +271,8 @@ def parse_sas_credential(credential): return (sas, expiry) -def transform_outbound_single_message(message, message_type): - # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData]) -> EventData +def transform_outbound_single_message(message, message_type, to_outgoing_amqp_message): + # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData], Callable) -> EventData """ This method serves multiple goals: 1. update the internal message to reflect any updates to settable properties on EventData @@ -284,17 +284,18 @@ def transform_outbound_single_message(message, message_type): :rtype: EventData """ try: - # EventData # pylint: disable=protected-access - return message._to_outgoing_message() # type: ignore + # EventData.message stores uamqp/pyamqp.Message during sending + message.message = to_outgoing_amqp_message(message.raw_amqp_message) + return message # type: ignore except AttributeError: - # AmqpAnnotatedMessage # pylint: disable=protected-access + # AmqpAnnotatedMessage is converted to uamqp/pyamqp.Message during sending + amqp_message = to_outgoing_amqp_message(message.raw_amqp_message) return message_type._from_message( - message=message._to_outgoing_amqp_message(), raw_amqp_message=message # type: ignore + message=amqp_message, raw_amqp_message=message # type: ignore ) - def decode_with_recurse(data, encoding="UTF-8"): # type: (Any, str) -> Any # pylint:disable=isinstance-second-argument-not-valid-type diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index e628cf1817de..0c6b8a3ef5e2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -119,7 +119,7 @@ async def _get_partitions(self) -> None: for p_id in cast(List[str], self._partition_ids): self._producers[p_id] = None - async def _get_max_mesage_size(self) -> None: + async def _get_max_message_size(self) -> None: # pylint: disable=protected-access,line-too-long async with self._lock: if not self._max_message_size_on_link: @@ -380,7 +380,7 @@ async def create_batch( """ if not self._max_message_size_on_link: - await self._get_max_mesage_size() + await self._get_max_message_size() if max_size_in_bytes and max_size_in_bytes > self._max_message_size_on_link: raise ValueError( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 7669b4aef901..a67d23fd0ed2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -6,8 +6,7 @@ from typing import Optional, Any, cast, Mapping, Dict -import uamqp - +from ._amqp_utils import normalized_data_body, normalized_sequence_body from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType from .._mixin import DictMixin @@ -47,12 +46,14 @@ class AmqpAnnotatedMessage(object): def __init__(self, **kwargs): # type: (Any) -> None - self._message = kwargs.pop("message", None) self._encoding = kwargs.pop("encoding", "UTF-8") # internal usage only for Event Hub received message - if self._message: - self._from_amqp_message(self._message) + message = kwargs.pop("message", None) + if message: + self._from_amqp_message(message) + self._body = message.body + self._body_type = message.body_type return # manually constructed AMQPAnnotatedMessage @@ -66,16 +67,15 @@ def __init__(self, **kwargs): self._body = None self._body_type = None if "data_body" in kwargs: - self._body = kwargs.get("data_body") - self._body_type = uamqp.MessageBodyType.Data + self._body = normalized_data_body(kwargs.get("data_body")) + self._body_type = AmqpMessageBodyType.DATA elif "sequence_body" in kwargs: - self._body = kwargs.get("sequence_body") - self._body_type = uamqp.MessageBodyType.Sequence + self._body = normalized_sequence_body(kwargs.get("sequence_body")) + self._body_type = AmqpMessageBodyType.SEQUENCE elif "value_body" in kwargs: self._body = kwargs.get("value_body") - self._body_type = uamqp.MessageBodyType.Value + self._body_type = AmqpMessageBodyType.VALUE - self._message = uamqp.message.Message(body=self._body, body_type=self._body_type) header_dict = cast(Mapping, kwargs.get("header")) self._header = AmqpMessageHeader(**header_dict) if "header" in kwargs else None self._footer = kwargs.get("footer") @@ -86,7 +86,30 @@ def __init__(self, **kwargs): self._delivery_annotations = kwargs.get("delivery_annotations") def __str__(self): - return str(self._message) + if self._body_type == AmqpMessageBodyType.DATA: + output_str = "" + for data_section in self.body: + try: + output_str += data_section.decode(self._encoding) + except AttributeError: + output_str += str(data_section) + return output_str + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + output_str = "" + for sequence_section in self.body: + for d in sequence_section: + try: + output_str += d.decode(self._encoding) + except AttributeError: + output_str += str(d) + return output_str + else: + if not self.body: + return "" + try: + return self.body.decode(self._encoding) + except AttributeError: + return str(self.body) def __repr__(self): # type: () -> str @@ -122,7 +145,7 @@ def __repr__(self): return "AmqpAnnotatedMessage({})".format(message_repr)[:1024] def _from_amqp_message(self, message): - # populate the properties from an uamqp message + # populate the properties from a amqp transport message self._properties = AmqpMessageProperties( message_id=message.properties.message_id, user_id=message.properties.user_id, @@ -145,63 +168,10 @@ def _from_amqp_message(self, message): durable=message.header.durable, priority=message.header.priority ) if message.header else None - self._footer = message.footer - self._annotations = message.annotations - self._delivery_annotations = message.delivery_annotations - self._application_properties = message.application_properties - - def _to_outgoing_amqp_message(self): - message_header = None - if self.header: - message_header = uamqp.message.MessageHeader() - message_header.delivery_count = self.header.delivery_count - message_header.time_to_live = self.header.time_to_live - message_header.first_acquirer = self.header.first_acquirer - message_header.durable = self.header.durable - message_header.priority = self.header.priority - - message_properties = None - if self.properties: - message_properties = uamqp.message.MessageProperties( - message_id=self.properties.message_id, - user_id=self.properties.user_id, - to=self.properties.to, - subject=self.properties.subject, - reply_to=self.properties.reply_to, - correlation_id=self.properties.correlation_id, - content_type=self.properties.content_type, - content_encoding=self.properties.content_encoding, - creation_time=int(self.properties.creation_time) if self.properties.creation_time else None, - absolute_expiry_time=int(self.properties.absolute_expiry_time) - if self.properties.absolute_expiry_time else None, - group_id=self.properties.group_id, - group_sequence=self.properties.group_sequence, - reply_to_group_id=self.properties.reply_to_group_id, - encoding=self._encoding - ) - - amqp_body = self._message._body # pylint: disable=protected-access - if isinstance(amqp_body, uamqp.message.DataBody): - amqp_body_type = uamqp.MessageBodyType.Data - amqp_body = list(amqp_body.data) - elif isinstance(amqp_body, uamqp.message.SequenceBody): - amqp_body_type = uamqp.MessageBodyType.Sequence - amqp_body = list(amqp_body.data) - else: - # amqp_body is type of uamqp.message.ValueBody - amqp_body_type = uamqp.MessageBodyType.Value - amqp_body = amqp_body.data - - return uamqp.message.Message( - body=amqp_body, - body_type=amqp_body_type, - header=message_header, - properties=message_properties, - application_properties=self.application_properties, - annotations=self.annotations, - delivery_annotations=self.delivery_annotations, - footer=self.footer - ) + self._footer = message.footer if message.footer else {} + self._annotations = message.annotations if message.annotations else {} + self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} + self._application_properties = message.application_properties if message.application_properties else {} @property def body(self): @@ -216,7 +186,7 @@ def body(self): :rtype: Any """ - return self._message.get_data() + return self._body @property def body_type(self): @@ -225,9 +195,7 @@ def body_type(self): :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ - return AMQP_MESSAGE_BODY_TYPE_MAP.get( - self._message._body.type, AmqpMessageBodyType.VALUE # pylint: disable=protected-access - ) + return self._body_type @property def properties(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py new file mode 100644 index 000000000000..4bb676392f89 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +def encode_str(data, encoding='utf-8'): + try: + return data.encode(encoding) + except AttributeError: + return data + +def normalized_data_body(data, **kwargs): + # A helper method to normalize input into AMQP Data Body format + encoding = kwargs.get("encoding", "utf-8") + if isinstance(data, list): + return [encode_str(item, encoding) for item in data] + else: + return [encode_str(data, encoding)] + + +def normalized_sequence_body(sequence): + # A helper method to normalize input into AMQP Sequence Body format + if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): + return sequence + elif isinstance(sequence, list): + return [sequence] diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index ef562c2628a5..55511e0d6f57 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -2,6 +2,7 @@ import pytest import uamqp from packaging import version +from azure.eventhub._transport.uamqp_plugins.message import TransportMessage from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common @@ -70,7 +71,7 @@ def test_sys_properties(): properties.group_id = "group_id" properties.group_sequence = 1 properties.reply_to_group_id = "reply_to_group_id" - message = uamqp.Message(properties=properties) + message = TransportMessage(properties=properties) message.annotations = {_common.PROP_OFFSET: "@latest"} ed = EventData._from_message(message) # type: EventData @@ -108,7 +109,7 @@ def test_event_data_batch(): batch.add(EventData("A")) def test_event_data_from_message(): - message = uamqp.Message('A') + message = TransportMessage('A') event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None From 5277aed3ad55e1c23cb38eaee57977a613942c48 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Tue, 3 May 2022 16:44:18 -0700 Subject: [PATCH 02/21] move all bits to Transport class --- .../azure-eventhub/azure/eventhub/_common.py | 18 +-- .../azure/eventhub/_producer_client.py | 14 +-- .../azure/eventhub/_transport/_base.py | 17 +++ .../eventhub/_transport/_uamqp_transport.py | 116 ++++++++++++++++++ .../_transport/uamqp_plugins/constants.py | 17 --- .../_transport/uamqp_plugins/message.py | 35 ------ .../_transport/uamqp_plugins/utils.py | 61 --------- .../tests/unittest/test_event_data.py | 4 +- 8 files changed, 146 insertions(+), 136 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py delete mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py delete mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py delete mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 1f17cb91c71b..0905fb96354c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -514,15 +514,7 @@ def __init__( partition_key: Optional[Union[str, bytes]] = None, **kwargs ) -> None: - self._uamqp_transport = kwargs.pop("uamqp_transport", True) - if self._uamqp_transport: - from ._transport.uamqp_plugins import utils as transport_utils, constants as transport_constants - self._transport_utils = transport_utils - self._transport_constants = transport_constants - self._to_outgoing_amqp_message = self._transport_utils.to_outgoing_amqp_message - else: - # TODO: add pyamqp support, will have to switch default to pyamqp - pass + self._amqp_transport = kwargs.pop("amqp_transport") if partition_key and not isinstance( partition_key, (six.text_type, six.binary_type) @@ -534,8 +526,8 @@ def __init__( "partition_key to only be string type, they might fail to parse the non-string value." ) - self.max_size_in_bytes = max_size_in_bytes or self._transport_constants.MAX_MESSAGE_LENGTH_BYTES - self.message = self._transport_constants.BATCH_MESSAGE(data=[]) + self.max_size_in_bytes = max_size_in_bytes or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES + self.message = self._amqp_transport.BATCH_MESSAGE(data=[]) self._partition_id = partition_id self._partition_key = partition_key @@ -599,7 +591,9 @@ def add(self, event_data): :raise: :class:`ValueError`, when exceeding the size limit. """ - outgoing_event_data = transform_outbound_single_message(event_data, EventData, self._to_outgoing_amqp_message) + outgoing_event_data = transform_outbound_single_message( + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message + ) if self._partition_key: if ( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index ec44acd9b806..5702b7cd76c6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -7,14 +7,13 @@ from typing import Any, Union, TYPE_CHECKING, Dict, List, Optional, cast -from uamqp import constants - from .exceptions import ConnectError, EventHubError from .amqp import AmqpAnnotatedMessage from ._client_base import ClientBase from ._producer import EventHubProducer from ._constants import ALL_PARTITIONS from ._common import EventDataBatch, EventData +from ._transport._uamqp_transport import UamqpTransport if TYPE_CHECKING: from ._client_base import CredentialTypes @@ -93,18 +92,14 @@ def __init__( ): # type:(...) -> None self._uamqp_transport = kwargs.pop("uamqp_transport", True) - if self._uamqp_transport: - from ._transport import uamqp_plugins as transport_plugins - self._transport_plugins = transport_plugins - else: - # TODO: add pyamqp support, will have to switch if/else default - pass + self._amqp_transport = UamqpTransport() super(EventHubProducerClient, self).__init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=credential, network_tracing=kwargs.get("logging_enable"), + transport=self._amqp_transport, **kwargs ) self._producers = { @@ -139,7 +134,7 @@ def _get_max_message_size(self): self._producers[ # type: ignore ALL_PARTITIONS ]._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) def _start_producer(self, partition_id, send_timeout): @@ -370,6 +365,7 @@ def create_batch(self, **kwargs): max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, partition_key=partition_key, + amqp_transport=self._amqp_transport, ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py new file mode 100644 index 000000000000..c970a220d6af --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -0,0 +1,17 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from abc import ABC, abstractmethod + +class AmqpTransport(ABC): + + # define constants + BATCH_MESSAGE = None + MAX_MESSAGE_LENGTH_BYTES = None + + @abstractmethod + def to_outgoing_amqp_message(self, annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Transport Message. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py new file mode 100644 index 000000000000..328ed2ef4f7f --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -0,0 +1,116 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from uamqp import ( + BatchMessage, + constants, + MessageBodyType, + Message, +) +from uamqp.message import ( + MessageHeader, + MessageProperties, +) + +from ._base import AmqpTransport +from ..amqp._constants import AmqpMessageBodyType +from .transport_message_base import TransportMessageBase +from ..amqp._constants import AmqpMessageBodyType + + +class TransportMessage(TransportMessageBase, Message): + + @property + def body_type(self): + # type: () -> AmqpMessageBodyType + """The body type of the underlying AMQP message. + + :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType + """ + return UamqpTransport.AMQP_MESSAGE_BODY_TYPE_MAP.get( + self._body.type, AmqpMessageBodyType.VALUE # pylint: disable=protected-access + ) + + @property + def body(self): + """The body of the Message. The format may vary depending on the body type: + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, + the body could be bytes or Iterable[bytes]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, + the body could be List or Iterable[List]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, + the body could be any type. + + :rtype: Any + """ + return self.get_data() + +class UamqpTransport(AmqpTransport): + + # define constants + BATCH_MESSAGE = BatchMessage + MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + + AMQP_MESSAGE_BODY_TYPE_MAP = { + MessageBodyType.Data.value: AmqpMessageBodyType.DATA, + MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, + MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, + } + + def to_outgoing_amqp_message(self, annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Transport Message. + """ + message_header = None + if annotated_message.header: + message_header = MessageHeader() + message_header.delivery_count = annotated_message.header.delivery_count + message_header.time_to_live = annotated_message.header.time_to_live + message_header.first_acquirer = annotated_message.header.first_acquirer + message_header.durable = annotated_message.header.durable + message_header.priority = annotated_message.header.priority + + message_properties = None + if annotated_message.properties: + message_properties = MessageProperties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time else None, + absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) + if annotated_message.properties.absolute_expiry_time else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + encoding=annotated_message._encoding # pylint: disable=protected-access + ) + + amqp_body_type = annotated_message.body_type # pylint: disable=protected-access + amqp_body = annotated_message.body + if amqp_body_type == AmqpMessageBodyType.DATA: + amqp_body_type = MessageBodyType.Data + amqp_body = list(amqp_body) + elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: + amqp_body_type = MessageBodyType.Sequence + amqp_body = list(amqp_body) + else: + # amqp_body_type is type of AmqpMessageBodyType.VALUE + amqp_body_type = MessageBodyType.Value + + return Message( + body=amqp_body, + body_type=amqp_body_type, + header=message_header, + properties=message_properties, + application_properties=annotated_message.application_properties, + annotations=annotated_message.annotations, + delivery_annotations=annotated_message.delivery_annotations, + footer=annotated_message.footer + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py deleted file mode 100644 index 6505becaf384..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/constants.py +++ /dev/null @@ -1,17 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- - -from uamqp import BatchMessage, constants, MessageBodyType -from ...amqp._constants import AmqpMessageBodyType - -AMQP_MESSAGE_BODY_TYPE_MAP = { - MessageBodyType.Data.value: AmqpMessageBodyType.DATA, - MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, - MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, -} - -BATCH_MESSAGE = BatchMessage -MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py deleted file mode 100644 index 56807de7d5c7..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/message.py +++ /dev/null @@ -1,35 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- -from uamqp import Message as UamqpMessage -from .constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType -from ..transport_message_base import TransportMessageBase - -class TransportMessage(TransportMessageBase, UamqpMessage): - - @property - def body_type(self): - # type: () -> AmqpMessageBodyType - """The body type of the underlying AMQP message. - - :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType - """ - return AMQP_MESSAGE_BODY_TYPE_MAP.get( - self._body.type, AmqpMessageBodyType.VALUE # pylint: disable=protected-access - ) - - @property - def body(self): - """The body of the Message. The format may vary depending on the body type: - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. - - :rtype: Any - """ - return self.get_data() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py deleted file mode 100644 index 376941757ca2..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/uamqp_plugins/utils.py +++ /dev/null @@ -1,61 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- -from .constants import AmqpMessageBodyType -import uamqp - -def to_outgoing_amqp_message(annotated_message): - message_header = None - if annotated_message.header: - message_header = uamqp.message.MessageHeader() - message_header.delivery_count = annotated_message.header.delivery_count - message_header.time_to_live = annotated_message.header.time_to_live - message_header.first_acquirer = annotated_message.header.first_acquirer - message_header.durable = annotated_message.header.durable - message_header.priority = annotated_message.header.priority - - message_properties = None - if annotated_message.properties: - message_properties = uamqp.message.MessageProperties( - message_id=annotated_message.properties.message_id, - user_id=annotated_message.properties.user_id, - to=annotated_message.properties.to, - subject=annotated_message.properties.subject, - reply_to=annotated_message.properties.reply_to, - correlation_id=annotated_message.properties.correlation_id, - content_type=annotated_message.properties.content_type, - content_encoding=annotated_message.properties.content_encoding, - creation_time=int(annotated_message.properties.creation_time) - if annotated_message.properties.creation_time else None, - absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) - if annotated_message.properties.absolute_expiry_time else None, - group_id=annotated_message.properties.group_id, - group_sequence=annotated_message.properties.group_sequence, - reply_to_group_id=annotated_message.properties.reply_to_group_id, - encoding=annotated_message._encoding # pylint: disable=protected-access - ) - - amqp_body_type = annotated_message.body_type # pylint: disable=protected-access - amqp_body = annotated_message.body - if amqp_body_type == AmqpMessageBodyType.DATA: - amqp_body_type = uamqp.MessageBodyType.Data - amqp_body = list(amqp_body) - elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: - amqp_body_type = uamqp.MessageBodyType.Sequence - amqp_body = list(amqp_body) - else: - # amqp_body_type is type of AmqpMessageBodyType.VALUE - amqp_body_type = uamqp.MessageBodyType.Value - - return uamqp.message.Message( - body=amqp_body, - body_type=amqp_body_type, - header=message_header, - properties=message_properties, - application_properties=annotated_message.application_properties, - annotations=annotated_message.annotations, - delivery_annotations=annotated_message.delivery_annotations, - footer=annotated_message.footer - ) diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 55511e0d6f57..98db66d62cf1 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -2,7 +2,7 @@ import pytest import uamqp from packaging import version -from azure.eventhub._transport.uamqp_plugins.message import TransportMessage +from azure.eventhub._transport._uamqp_transport import TransportMessage, UamqpTransport from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common @@ -92,7 +92,7 @@ def test_sys_properties(): def test_event_data_batch(): - batch = EventDataBatch(max_size_in_bytes=110, partition_key="par") + batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=UamqpTransport()) batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" assert repr(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" From 267445eb06338e5237c533c9faf160ff8f2d20c1 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 4 May 2022 12:17:27 -0700 Subject: [PATCH 03/21] more changes --- .../azure/eventhub/_producer.py | 3 +- .../azure/eventhub/_producer_client.py | 8 +++-- .../azure/eventhub/_transport/_base.py | 27 ++++++++++++++++ .../eventhub/_transport/_uamqp_transport.py | 8 ++--- .../_transport/transport_message_base.py | 32 ------------------- 5 files changed, 39 insertions(+), 39 deletions(-) delete mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 75498fc0bf37..8e2e4ec2ced2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -117,6 +117,7 @@ def __init__(self, client, target, **kwargs): self._link_properties = { types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) } + self._amqp_transport = kwargs.pop("amqp_transport") def _create_handler(self, auth): # type: (JWTTokenAuth) -> None @@ -190,7 +191,7 @@ def _wrap_eventdata( ): # type: (...) -> Union[EventData, EventDataBatch] if isinstance(event_data, EventData): - outgoing_event_data = transform_outbound_single_message(event_data, EventData) + outgoing_event_data = transform_outbound_single_message(event_data, EventData, self._amqp_transport.to_outgoing_amqp_message) if partition_key: set_message_partition_key(outgoing_event_data.message, partition_key) wrapper_event_data = outgoing_event_data diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 5702b7cd76c6..157e0e67c73e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -92,14 +92,17 @@ def __init__( ): # type:(...) -> None self._uamqp_transport = kwargs.pop("uamqp_transport", True) - self._amqp_transport = UamqpTransport() + if self._uamqp_transport: + self._amqp_transport = UamqpTransport() + else: + raise NotImplementedError('pyamqp transport') super(EventHubProducerClient, self).__init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=credential, network_tracing=kwargs.get("logging_enable"), - transport=self._amqp_transport, + amqp_transport=self._amqp_transport, **kwargs ) self._producers = { @@ -175,6 +178,7 @@ def _create_producer(self, partition_id=None, send_timeout=None): partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport=self._amqp_transport, ) return handler diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index c970a220d6af..b74783a89361 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -4,6 +4,32 @@ # -------------------------------------------------------------------------------------------- from abc import ABC, abstractmethod +class TransportMessageBase: + """ + Abstract class that acts as a wrapper for the transport Message class. + """ + @property + @abstractmethod + def body_type(self): + """The body type of the underlying AMQP message. + + :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType + """ + + @property + @abstractmethod + def body(self): + """The body of the Message. The format may vary depending on the body type: + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, + the body could be bytes or Iterable[bytes]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, + the body could be List or Iterable[List]. + For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, + the body could be any type. + + :rtype: Any + """ + class AmqpTransport(ABC): # define constants @@ -14,4 +40,5 @@ class AmqpTransport(ABC): def to_outgoing_amqp_message(self, annotated_message): """ Converts an AmqpAnnotatedMessage into an Amqp Transport Message. + :rtype: TransportMessageBase """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 328ed2ef4f7f..66397571c157 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -13,13 +13,12 @@ MessageProperties, ) -from ._base import AmqpTransport +from ._base import AmqpTransport, TransportMessageBase from ..amqp._constants import AmqpMessageBodyType -from .transport_message_base import TransportMessageBase from ..amqp._constants import AmqpMessageBodyType -class TransportMessage(TransportMessageBase, Message): +class UamqpTransportMessage(TransportMessageBase, Message): @property def body_type(self): @@ -57,6 +56,7 @@ class UamqpTransport(AmqpTransport): MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, } + TRANSPORT_MESSAGE = UamqpTransportMessage def to_outgoing_amqp_message(self, annotated_message): """ @@ -104,7 +104,7 @@ def to_outgoing_amqp_message(self, annotated_message): # amqp_body_type is type of AmqpMessageBodyType.VALUE amqp_body_type = MessageBodyType.Value - return Message( + return UamqpTransportMessage( body=amqp_body, body_type=amqp_body_type, header=message_header, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py deleted file mode 100644 index 7d9d5e661cf1..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/transport_message_base.py +++ /dev/null @@ -1,32 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- -from abc import abstractmethod - -class TransportMessageBase: - """ - Abstract class that acts as a wrapper for the transport Message class. - """ - @property - @abstractmethod - def body_type(self): - """The body type of the underlying AMQP message. - - :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType - """ - - @property - @abstractmethod - def body(self): - """The body of the Message. The format may vary depending on the body type: - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. - - :rtype: Any - """ From 4415a1d0c58f40fd6307cb68cd2b1184944e6c71 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Thu, 2 Jun 2022 16:11:53 -0700 Subject: [PATCH 04/21] remove uamqp from producer --- .../azure-eventhub/azure/eventhub/_common.py | 5 +- .../azure/eventhub/_producer.py | 82 ++++------ .../azure/eventhub/_producer_client.py | 2 +- .../azure/eventhub/_transport/_base.py | 71 +++++++- .../eventhub/_transport/_uamqp_transport.py | 152 +++++++++++++++++- .../azure-eventhub/azure/eventhub/_utils.py | 45 ++---- .../azure/eventhub/exceptions.py | 16 +- .../tests/unittest/test_event_data.py | 2 +- 8 files changed, 272 insertions(+), 103 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 0905fb96354c..bc17f0fbf86c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -24,7 +24,6 @@ import six from ._utils import ( - set_message_partition_key, trace_message, utc_from_timestamp, transform_outbound_single_message, @@ -531,7 +530,7 @@ def __init__( self._partition_id = partition_id self._partition_key = partition_key - set_message_partition_key(self.message, self._partition_key) + self._amqp_transport.set_message_partition_key(self.message, self._partition_key) self._size = self.message.gather()[0].get_message_encoded_size() self._count = 0 @@ -604,7 +603,7 @@ def add(self, event_data): "The partition key of event_data does not match the partition key of this batch." ) if not outgoing_event_data.partition_key: - set_message_partition_key( + self._amqp_transport.set_message_partition_key( outgoing_event_data.message, self._partition_key ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 8e2e4ec2ced2..d67f1855234f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -18,23 +18,22 @@ TYPE_CHECKING, ) # pylint: disable=unused-import -from uamqp import types, constants, errors -from uamqp import SendClient - from azure.core.tracing import AbstractSpan -from .exceptions import _error_handler, OperationTimeoutError from ._common import EventData, EventDataBatch from ._client_base import ConsumerProducerMixin from ._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, ) from ._constants import TIMEOUT_SYMBOL +if TYPE_CHECKING: + from uamqp import constants, SendClient + from ._transport._base import AmqpTransport + _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: @@ -42,10 +41,10 @@ from ._producer_client import EventHubProducerClient -def _set_partition_key(event_datas, partition_key): - # type: (Iterable[EventData], AnyStr) -> Iterable[EventData] +def _set_partition_key(event_datas, partition_key, amqp_transport): + # type: (Iterable[EventData], AnyStr, AmqpTransport) -> Iterable[EventData] for ed in iter(event_datas): - set_message_partition_key(ed.message, partition_key) + amqp_transport.set_message_partition_key(ed.message, partition_key) yield ed @@ -83,6 +82,8 @@ class EventHubProducer( def __init__(self, client, target, **kwargs): # type: (EventHubProducerClient, str, Any) -> None + + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -97,13 +98,11 @@ def __init__(self, client, target, **kwargs): self._target = target self._partition = partition self._timeout = send_timeout - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None + self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None self._error = None self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, on_error=_error_handler # pylint: disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(retry_total=self._client._config.max_retries) self._reconnect_backoff = 1 self._name = "EHProducer-{}".format(uuid.uuid4()) self._unsent_events = [] # type: List[Any] @@ -114,57 +113,35 @@ def __init__(self, client, target, **kwargs): self._outcome = None # type: Optional[constants.MessageSendResult] self._condition = None # type: Optional[Exception] self._lock = threading.Lock() - self._link_properties = { - types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) - } - self._amqp_transport = kwargs.pop("amqp_transport") + self._link_properties = self._amqp_transport.create_link_properties(TIMEOUT_SYMBOL, int(self._timeout * 1000)) def _create_handler(self, auth): # type: (JWTTokenAuth) -> None - self._handler = SendClient( - self._target, + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout * 1000, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, - properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint: disable=protected-access + ), + msg_timeout=self._timeout * 1000, ) def _open_with_retry(self): # type: () -> None return self._do_retryable_operation(self._open, operation_need_param=False) - def _set_msg_timeout(self, timeout_time, last_exception): - # type: (Optional[float], Optional[Exception]) -> None - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - def _send_event_data(self, timeout_time=None, last_exception=None): # type: (Optional[float], Optional[Exception]) -> None if self._unsent_events: self._open() - self._set_msg_timeout(timeout_time, last_exception) - self._handler.queue_message(*self._unsent_events) # type: ignore - self._handler.wait() # type: ignore - self._unsent_events = self._handler.pending_messages # type: ignore - if self._outcome != constants.MessageSendResult.Ok: - if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("Send operation timed out") - if self._condition: - raise self._condition + self._amqp_transport.send_messages(self, timeout_time, last_exception, _LOGGER) def _send_event_data_with_retry(self, timeout=None): # type: (Optional[float]) -> None @@ -191,9 +168,11 @@ def _wrap_eventdata( ): # type: (...) -> Union[EventData, EventDataBatch] if isinstance(event_data, EventData): - outgoing_event_data = transform_outbound_single_message(event_data, EventData, self._amqp_transport.to_outgoing_amqp_message) + outgoing_event_data = transform_outbound_single_message( + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message + ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key(outgoing_event_data.message, partition_key) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: @@ -206,15 +185,16 @@ def _wrap_eventdata( raise ValueError( "The partition_key does not match the one of the EventDataBatch" ) - for event in event_data.message._body_gen: # pylint: disable=protected-access + for event in self._amqp_transport.get_batch_message_data(event_data.message): # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key(event_data, partition_key, self._amqp_transport) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome + wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access + event_data, partition_key, self._amqp_transport.to_outgoing_amqp_message + ) return wrapper_event_data def send( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 157e0e67c73e..3a63bddc00fc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -136,7 +136,7 @@ def _get_max_message_size(self): self._max_message_size_on_link = ( self._producers[ # type: ignore ALL_PARTITIONS - ]._handler.message_handler._link.peer_max_message_size + ]._handler.message_handler._link.peer_max_message_size # TODO: fix to fit pyamqp or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index b74783a89361..e5b6c13907e3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -6,7 +6,7 @@ class TransportMessageBase: """ - Abstract class that acts as a wrapper for the transport Message class. + Abstract class that acts as a wrapper for the Message class. """ @property @abstractmethod @@ -31,10 +31,19 @@ def body(self): """ class AmqpTransport(ABC): - + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ # define constants BATCH_MESSAGE = None MAX_MESSAGE_LENGTH_BYTES = None + IDLE_TIMEOUT_FACTOR = None + + PRODUCT_SYMBOL = None + VERSION_SYMBOL = None + FRAMEWORK_SYMBOL = None + PLATFORM_SYMBOL = None + USER_AGENT_SYMBOL = None @abstractmethod def to_outgoing_amqp_message(self, annotated_message): @@ -42,3 +51,61 @@ def to_outgoing_amqp_message(self, annotated_message): Converts an AmqpAnnotatedMessage into an Amqp Transport Message. :rtype: TransportMessageBase """ + + @abstractmethod + def create_retry_policy(self, retry_total): + """ + Creates and returns the error retry policy. + :param int retry_total: Max number of retries. + """ + + @abstractmethod + def create_link_properties(self, timeout_symbol, timeout): + """ + Creates and returns the link properties. + :param bytes timeout_symbol: The timeout symbol. + :param int timeout: The timeout to set as value. + """ + + @abstractmethod + def create_send_client(self, *, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword config: Optional. Client configuration. + """ + + @abstractmethod + def send_messages(self, producer, timeout_time, last_exception): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + """ + + @abstractmethod + def get_batch_message_data(self, batch_message): + """ + Gets the data body of the BatchMessage. + :param batch_message: BatchMessage to retrieve data body from. + """ + + @abstractmethod + def set_message_partition_key(self, message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 66397571c157..527a92f52cab 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -2,21 +2,57 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +import time +from typing import TYPE_CHECKING, Optional + from uamqp import ( BatchMessage, constants, MessageBodyType, Message, + types, + SendClient, ) from uamqp.message import ( MessageHeader, MessageProperties, ) +from uamqp.errors import ErrorPolicy, ErrorAction from ._base import AmqpTransport, TransportMessageBase from ..amqp._constants import AmqpMessageBodyType from ..amqp._constants import AmqpMessageBodyType +from .._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY_AMQP_SYMBOL, +) +from ..exceptions import OperationTimeoutError + +if TYPE_CHECKING: + import logging + +def _error_handler(error): + """ + Called internally when an event has failed to send so we + can parse the error to determine whether we should attempt + to retry sending the event again. + Returns the action to take according to error type. + :param error: The error received in the send attempt. + :type error: Exception + :rtype: ~uamqp.errors.ErrorAction + """ + if error.condition == b"com.microsoft:server-busy": + return ErrorAction(retry=True, backoff=4) + if error.condition == b"com.microsoft:timeout": + return ErrorAction(retry=True, backoff=2) + if error.condition == b"com.microsoft:operation-cancelled": + return ErrorAction(retry=True) + if error.condition == b"com.microsoft:container-close": + return ErrorAction(retry=True, backoff=4) + if error.condition in NO_RETRY_ERRORS: + return ErrorAction(retry=False) + return ErrorAction(retry=True) class UamqpTransportMessage(TransportMessageBase, Message): @@ -46,17 +82,26 @@ def body(self): return self.get_data() class UamqpTransport(AmqpTransport): - + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ # define constants BATCH_MESSAGE = BatchMessage MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES - AMQP_MESSAGE_BODY_TYPE_MAP = { MessageBodyType.Data.value: AmqpMessageBodyType.DATA, MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, } TRANSPORT_MESSAGE = UamqpTransportMessage + IDLE_TIMEOUT_FACTOR = 1000 # pyamqp = 1 + + # define symbols + PRODUCT_SYMBOL = types.AMQPSymbol("product") + VERSION_SYMBOL = types.AMQPSymbol("version") + FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") + PLATFORM_SYMBOL = types.AMQPSymbol("platform") + USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") def to_outgoing_amqp_message(self, annotated_message): """ @@ -114,3 +159,106 @@ def to_outgoing_amqp_message(self, annotated_message): delivery_annotations=annotated_message.delivery_annotations, footer=annotated_message.footer ) + + def create_retry_policy(self, retry_total): + """ + Creates the error retry policy. + :param retry_total: Max number of retries. + """ + return ErrorPolicy(max_retries=retry_total, on_error=_error_handler) + + def create_link_properties(self, timeout_symbol, timeout): + """ + Creates and returns the link properties. + :param bytes timeout_symbol: The timeout symbol. + :param int timeout: The timeout to set as value. + """ + return { + types.AMQPSymbol(timeout_symbol): types.AMQPLong(timeout) + } + + def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClient( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) + + def _set_msg_timeout(self, timeout_time, last_exception, logger): + # type: (Optional[float], Optional[Exception], logging.Logger) -> None + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + logger.info("%r send operation timed out. (%r)", self._name, error) + raise error + self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access + + def send_messages(self, producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + """ + # pylint: disable=protected-access + producer._unsent_events[0].on_send_complete = producer._on_outcome + self._set_msg_timeout(timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + producer._handler.wait() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + def get_batch_message_data(self, batch_message): + """ + Gets the data body of the BatchMessage. + :param batch_message: BatchMessage to retrieve data body from. + """ + return batch_message._body_gen # pylint:disable=protected-access + + def set_message_partition_key(self, message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> None + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + PROP_PARTITION_KEY_AMQP_SYMBOL + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 41011b691db5..b193340f0383 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -14,16 +14,12 @@ import six -from uamqp import types -from uamqp.message import MessageHeader - from azure.core.settings import settings from azure.core.tracing import SpanKind, Link from .amqp import AmqpAnnotatedMessage from ._version import VERSION from ._constants import ( - PROP_PARTITION_KEY_AMQP_SYMBOL, MAX_USER_AGENT_LENGTH, USER_AGENT_PREFIX, PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, @@ -33,9 +29,12 @@ PROP_TIMESTAMP, ) +# Python 3 Type Checking imports +from ._transport._base import AmqpTransport +from uamqp import types + if TYPE_CHECKING: # pylint: disable=ungrouped-imports - from uamqp import Message from azure.core.tracing import AbstractSpan from azure.core.credentials import AzureSasCredential from ._common import EventData @@ -77,8 +76,9 @@ def utc_from_timestamp(timestamp): return datetime.datetime.fromtimestamp(timestamp, tz=TZ_UTC) -def create_properties(user_agent=None): - # type: (Optional[str]) -> Dict[types.AMQPSymbol, str] +def create_properties( + user_agent: Optional[str] = None, *, amqp_transport: AmqpTransport +) -> Dict[types.AMQPSymbol, str]: """ Format the properties with which to instantiate the connection. This acts like a user agent over HTTP. @@ -86,14 +86,14 @@ def create_properties(user_agent=None): :rtype: dict """ properties = {} - properties[types.AMQPSymbol("product")] = USER_AGENT_PREFIX - properties[types.AMQPSymbol("version")] = VERSION + properties[amqp_transport.PRODUCT_SYMBOL] = USER_AGENT_PREFIX + properties[amqp_transport.VERSION_SYMBOL] = VERSION framework = "Python/{}.{}.{}".format( sys.version_info[0], sys.version_info[1], sys.version_info[2] ) - properties[types.AMQPSymbol("framework")] = framework + properties[amqp_transport.FRAMEWORK_SYMBOL] = framework platform_str = platform.platform() - properties[types.AMQPSymbol("platform")] = platform_str + properties[amqp_transport.PLATFORM_SYMBOL] = platform_str final_user_agent = "{}/{} {} ({})".format( USER_AGENT_PREFIX, VERSION, framework, platform_str @@ -108,31 +108,10 @@ def create_properties(user_agent=None): MAX_USER_AGENT_LENGTH, final_user_agent, len(final_user_agent) ) ) - properties[types.AMQPSymbol("user-agent")] = final_user_agent + properties[amqp_transport.USER_AGENT_SYMBOL] = final_user_agent return properties -def set_message_partition_key(message, partition_key): - # type: (Message, Optional[Union[bytes, str]]) -> None - """Set the partition key as an annotation on a uamqp message. - - :param ~uamqp.Message message: The message to update. - :param str partition_key: The partition key value. - :rtype: None - """ - if partition_key: - annotations = message.annotations - if annotations is None: - annotations = dict() - annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL - ] = partition_key # pylint:disable=protected-access - header = MessageHeader() - header.durable = True - message.annotations = annotations - message.header = header - - @contextmanager def send_context_manager(): span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 6d90033502f8..3d5d7885b4bc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -7,11 +7,8 @@ from uamqp import errors, compat -from ._constants import NO_RETRY_ERRORS - _LOGGER = logging.getLogger(__name__) - def _error_handler(error): """ Called internally when an event has failed to send so we @@ -24,17 +21,16 @@ def _error_handler(error): :rtype: ~uamqp.errors.ErrorAction """ if error.condition == b"com.microsoft:server-busy": - return errors.ErrorAction(retry=True, backoff=4) + return ErrorAction(retry=True, backoff=4) if error.condition == b"com.microsoft:timeout": - return errors.ErrorAction(retry=True, backoff=2) + return ErrorAction(retry=True, backoff=2) if error.condition == b"com.microsoft:operation-cancelled": - return errors.ErrorAction(retry=True) + return ErrorAction(retry=True) if error.condition == b"com.microsoft:container-close": - return errors.ErrorAction(retry=True, backoff=4) + return ErrorAction(retry=True, backoff=4) if error.condition in NO_RETRY_ERRORS: - return errors.ErrorAction(retry=False) - return errors.ErrorAction(retry=True) - + return ErrorAction(retry=False) + return ErrorAction(retry=True) class EventHubError(Exception): """Represents an error occurred in the client. diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 98db66d62cf1..62e504aca9c5 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -2,7 +2,7 @@ import pytest import uamqp from packaging import version -from azure.eventhub._transport._uamqp_transport import TransportMessage, UamqpTransport +from azure.eventhub._transport._uamqp_transport import UamqpTransportMessage as TransportMessage, UamqpTransport from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common From 403e85263f18a0cb3c4c8963aef87313a789c590 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Tue, 14 Jun 2022 16:59:09 -0700 Subject: [PATCH 05/21] move uamqp from consumer to transport --- .../azure/eventhub/_consumer.py | 93 +++++++------------ .../azure/eventhub/_consumer_client.py | 8 ++ .../azure/eventhub/_producer.py | 4 +- .../azure/eventhub/_transport/__init__.py | 4 + .../azure/eventhub/_transport/_base.py | 46 ++++++++- .../eventhub/_transport/_uamqp_transport.py | 91 ++++++++++++++++-- 6 files changed, 178 insertions(+), 68 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 2702d84a2f91..9edaee6693d0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -10,11 +10,6 @@ from collections import deque from typing import TYPE_CHECKING, Callable, Dict, Optional, Any -import uamqp -from uamqp import types, errors, utils -from uamqp import ReceiveClient, Source, Message - -from .exceptions import _error_handler from ._common import EventData from ._client_base import ConsumerProducerMixin from ._utils import create_properties, event_position_selector @@ -26,6 +21,7 @@ if TYPE_CHECKING: from typing import Deque + from uamqp import ReceiveClient, Message from uamqp.authentication import JWTTokenAuth from ._consumer_client import EventHubConsumerClient @@ -86,6 +82,7 @@ def __init__(self, client, source, **kwargs): self.stop = False # used by event processor self.handler_ready = False + self._amqp_transport = kwargs.pop("amqp_transport") self._on_event_received = kwargs[ "on_event_received" ] # type: Callable[[EventData], None] @@ -97,27 +94,22 @@ def __init__(self, client, source, **kwargs): self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, on_error=_error_handler # pylint:disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(retry_total=self._client._config.max_retries) self._reconnect_backoff = 1 - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] + link_properties = {} # type: Dict[bytes, int] self._error = None self._timeout = 0 - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None + self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None partition = self._source.split("/")[-1] self._partition = partition self._name = "EHConsumer-{}-partition{}".format(uuid.uuid4(), partition) if owner_level is not None: - self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong( - int(owner_level) - ) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong( - int(link_property_timeout_ms) - ) + ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) self._handler = None # type: Optional[ReceiveClient] self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties @@ -128,39 +120,30 @@ def __init__(self, client, source, **kwargs): def _create_handler(self, auth): # type: (JWTTokenAuth) -> None - source = Source(self._source) - if self._offset is not None: - source.set_filter( - event_position_selector(self._offset, self._offset_inclusive) - ) - desired_capabilities = None - if self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - - properties = create_properties( - self._client._config.user_agent # pylint:disable=protected-access + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) ) - self._handler = ReceiveClient( - source, + desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + link_credit=self._prefetch, link_properties=self._link_properties, - timeout=self._timeout, idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=properties, + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, - ) - - self._handler._streaming_receive = True # pylint:disable=protected-access - self._handler._message_received_callback = ( # pylint:disable=protected-access - self._message_received + streaming_receive=True, + message_received_callback=self._message_received, ) def _open_with_retry(self): @@ -170,11 +153,11 @@ def _open_with_retry(self): def _message_received(self, message): # type: (uamqp.Message) -> None # pylint:disable=protected-access - self._message_buffer.appendleft(message) + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access - message = self._message_buffer.pop() + message = self._message_buffer.popleft() event_data = EventData._from_message(message) self._last_received_event = event_data return event_data @@ -182,7 +165,6 @@ def _next_message_in_buffer(self): def _open(self): # type: () -> bool """Open the EventHubConsumer/EventHubProducer using the supplied connection. - """ # pylint: disable=protected-access if not self.running: @@ -190,17 +172,12 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open( - connection=self._client._conn_manager.get_connection( - self._client._address.hostname, auth - ) # pylint: disable=protected-access - ) - self.handler_ready = False + self._handler.open() + while not self._handler.client_ready(): + time.sleep(0.05) + self.handler_ready = True self.running = True - if not self.handler_ready: - if self._handler.client_ready(): # type: ignore - self.handler_ready = True return self.handler_ready def receive(self, batch=False, max_batch_size=300, max_wait_time=None): @@ -214,12 +191,12 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None): while retried_times <= max_retries: try: if self._open(): - self._handler.do_work() # type: ignore + self._handler.do_work() # type: ignore # TODO: for pyamqp, this will pass in batch. But, in the ReceiveClient._client_run, can pass (batch=self._link_credit) break except Exception as exception: # pylint: disable=broad-except if ( - isinstance(exception, uamqp.errors.LinkDetach) - and exception.condition == uamqp.constants.ErrorCodes.LinkStolen # pylint: disable=no-member + isinstance(exception, self._amqp_transport.AMQP_LINK_ERROR) + and exception.condition == self._amqp_transport.LINK_STOLEN_CONDITION # pylint: disable=no-member ): raise self._handle_exception(exception) if not self.running: # exit by close diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index 7d3364a7f70b..7e23ce7868a7 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -11,6 +11,7 @@ from ._constants import ALL_PARTITIONS from ._eventprocessor.event_processor import EventProcessor from ._eventprocessor.common import LoadBalancingStrategy +from ._transport._uamqp_transport import UamqpTransport if TYPE_CHECKING: @@ -144,6 +145,12 @@ def __init__( **kwargs # type: Any ): # type: (...) -> None + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + if self._uamqp_transport: + self._amqp_transport = UamqpTransport() + else: + raise NotImplementedError('pyamqp transport') + self._checkpoint_store = kwargs.pop("checkpoint_store", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", None) if self._load_balancing_interval is None: @@ -208,6 +215,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, ) return handler diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index d67f1855234f..4ad6da8848b2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -113,7 +113,7 @@ def __init__(self, client, target, **kwargs): self._outcome = None # type: Optional[constants.MessageSendResult] self._condition = None # type: Optional[Exception] self._lock = threading.Lock() - self._link_properties = self._amqp_transport.create_link_properties(TIMEOUT_SYMBOL, int(self._timeout * 1000)) + self._link_properties = self._amqp_transport.create_link_properties({TIMEOUT_SYMBOL: int(self._timeout * 1000)}) def _create_handler(self, auth): # type: (JWTTokenAuth) -> None @@ -130,7 +130,7 @@ def _create_handler(self, auth): properties=create_properties( self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint: disable=protected-access ), - msg_timeout=self._timeout * 1000, + msg_timeout=self._timeout * 1000, # extra passed in to pyamqp, but not used. should be used? ) def _open_with_retry(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index e5b6c13907e3..933b9d1a69ef 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -45,6 +45,10 @@ class AmqpTransport(ABC): PLATFORM_SYMBOL = None USER_AGENT_SYMBOL = None + # errors + AMQP_LINK_ERROR = None + LINK_STOLEN_CONDITION = None + @abstractmethod def to_outgoing_amqp_message(self, annotated_message): """ @@ -65,6 +69,7 @@ def create_link_properties(self, timeout_symbol, timeout): Creates and returns the link properties. :param bytes timeout_symbol: The timeout symbol. :param int timeout: The timeout to set as value. + :rtype: dict """ @abstractmethod @@ -82,7 +87,6 @@ def create_send_client(self, *, config, **kwargs): :keyword str client_name: Required. :keyword dict link_properties: Required. :keyword properties: Required. - :keyword config: Optional. Client configuration. """ @abstractmethod @@ -109,3 +113,43 @@ def set_message_partition_key(self, message, partition_key, **kwargs): :param str partition_key: The partition key value. :rtype: None """ + + @abstractmethod + def create_source(self, source, offset, filter): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes filter: Required. + """ + + @abstractmethod + def create_receive_client(self, *, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @abstractmethod + def open_receive_client(self, *, handler, client): + """ + Opens the receive client. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 527a92f52cab..bf22166209b6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import time +from tkinter import W from typing import TYPE_CHECKING, Optional from uamqp import ( @@ -12,12 +13,15 @@ Message, types, SendClient, + ReceiveClient, + Source, + utils, ) from uamqp.message import ( MessageHeader, MessageProperties, ) -from uamqp.errors import ErrorPolicy, ErrorAction +from uamqp.errors import ErrorPolicy, ErrorAction, LinkDetach from ._base import AmqpTransport, TransportMessageBase from ..amqp._constants import AmqpMessageBodyType @@ -103,6 +107,10 @@ class UamqpTransport(AmqpTransport): PLATFORM_SYMBOL = types.AMQPSymbol("platform") USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") + # define errors and conditions + AMQP_LINK_ERROR = LinkDetach + LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen + def to_outgoing_amqp_message(self, annotated_message): """ Converts an AmqpAnnotatedMessage into an Amqp Transport Message. @@ -167,15 +175,12 @@ def create_retry_policy(self, retry_total): """ return ErrorPolicy(max_retries=retry_total, on_error=_error_handler) - def create_link_properties(self, timeout_symbol, timeout): + def create_link_properties(self, link_properties): """ Creates and returns the link properties. - :param bytes timeout_symbol: The timeout symbol. - :param int timeout: The timeout to set as value. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. """ - return { - types.AMQPSymbol(timeout_symbol): types.AMQPLong(timeout) - } + return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argument """ @@ -262,3 +267,75 @@ def set_message_partition_key(self, message, partition_key, **kwargs): # pylint header.durable = True message.annotations = annotations message.header = header + + def create_source(self, source, offset, filter): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes filter: Required. + """ + source = Source(source) + if offset is not None: + source.set_filter(filter) + return source + + def create_receive_client(self, *, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) if symbol_array else None + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClient( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + def open_receive_client(self, *, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + handler.open(connection=client._conn_manager.get_connection( + client._address.hostname, auth + )) From edac33c7b29647c212eee109196fedf18cf02183 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Mon, 20 Jun 2022 10:59:28 -0700 Subject: [PATCH 06/21] receive with data/seq/val props added to amqp message --- .../azure-eventhub/azure/eventhub/_common.py | 10 +++-- .../azure/eventhub/_transport/_base.py | 27 ------------- .../eventhub/_transport/_uamqp_transport.py | 39 +------------------ .../azure/eventhub/amqp/_amqp_message.py | 11 +++++- .../tests/unittest/test_event_data.py | 6 +-- 5 files changed, 21 insertions(+), 72 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index bc17f0fbf86c..911ce414626f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -71,6 +71,9 @@ uuid.UUID, ]] +if TYPE_CHECKING: + from ._transport._base import AmqpTransport + _LOGGER = logging.getLogger(__name__) # event_data.encoded_size < 255, batch encode overhead is 5, >=256, overhead is 8 each @@ -205,8 +208,8 @@ def from_message_content(cls, content: bytes, content_type: str, **kwargs: Any) return event_data @classmethod - def _from_message(cls, message, raw_amqp_message=None): - # type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData + def _from_message(cls, message, raw_amqp_message=None, amqp_transport=None): + # type: (Message, Optional[AmqpAnnotatedMessage], AmqpTransport) -> EventData # pylint:disable=protected-access """Internal use only. @@ -219,7 +222,8 @@ def _from_message(cls, message, raw_amqp_message=None): event_data = cls(body="") event_data.message = message # pylint: disable=protected-access - event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else AmqpAnnotatedMessage(message=message) + event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else \ + AmqpAnnotatedMessage(message=message, amqp_transport=amqp_transport) return event_data #def _encode_message(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 933b9d1a69ef..4813814db1b9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -4,32 +4,6 @@ # -------------------------------------------------------------------------------------------- from abc import ABC, abstractmethod -class TransportMessageBase: - """ - Abstract class that acts as a wrapper for the Message class. - """ - @property - @abstractmethod - def body_type(self): - """The body type of the underlying AMQP message. - - :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType - """ - - @property - @abstractmethod - def body(self): - """The body of the Message. The format may vary depending on the body type: - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. - - :rtype: Any - """ - class AmqpTransport(ABC): """ Abstract class that defines a set of common methods needed by producer and consumer. @@ -145,7 +119,6 @@ def create_receive_client(self, *, config, **kwargs): :keyword message_received_callback: Required. :keyword timeout: Required. """ - @abstractmethod def open_receive_client(self, *, handler, client): """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index bf22166209b6..b49e8ab222bc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import time -from tkinter import W from typing import TYPE_CHECKING, Optional from uamqp import ( @@ -25,9 +24,8 @@ from ._base import AmqpTransport, TransportMessageBase from ..amqp._constants import AmqpMessageBodyType -from ..amqp._constants import AmqpMessageBodyType from .._constants import ( - NO_RETRY_ERRORS, + NO_RETRY_ERRORS, PROP_PARTITION_KEY_AMQP_SYMBOL, ) from ..exceptions import OperationTimeoutError @@ -58,33 +56,6 @@ def _error_handler(error): return ErrorAction(retry=False) return ErrorAction(retry=True) -class UamqpTransportMessage(TransportMessageBase, Message): - - @property - def body_type(self): - # type: () -> AmqpMessageBodyType - """The body type of the underlying AMQP message. - - :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType - """ - return UamqpTransport.AMQP_MESSAGE_BODY_TYPE_MAP.get( - self._body.type, AmqpMessageBodyType.VALUE # pylint: disable=protected-access - ) - - @property - def body(self): - """The body of the Message. The format may vary depending on the body type: - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. - - :rtype: Any - """ - return self.get_data() - class UamqpTransport(AmqpTransport): """ Class which defines uamqp-based methods used by the producer and consumer. @@ -92,12 +63,6 @@ class UamqpTransport(AmqpTransport): # define constants BATCH_MESSAGE = BatchMessage MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES - AMQP_MESSAGE_BODY_TYPE_MAP = { - MessageBodyType.Data.value: AmqpMessageBodyType.DATA, - MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, - MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, - } - TRANSPORT_MESSAGE = UamqpTransportMessage IDLE_TIMEOUT_FACTOR = 1000 # pyamqp = 1 # define symbols @@ -157,7 +122,7 @@ def to_outgoing_amqp_message(self, annotated_message): # amqp_body_type is type of AmqpMessageBodyType.VALUE amqp_body_type = MessageBodyType.Value - return UamqpTransportMessage( + return Message( body=amqp_body, body_type=amqp_body_type, header=message_header, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index a67d23fd0ed2..4468c8b82cc9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -52,8 +52,6 @@ def __init__(self, **kwargs): message = kwargs.pop("message", None) if message: self._from_amqp_message(message) - self._body = message.body - self._body_type = message.body_type return # manually constructed AMQPAnnotatedMessage @@ -172,6 +170,15 @@ def _from_amqp_message(self, message): self._annotations = message.annotations if message.annotations else {} self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} + if message.data: + self._body = message.data + self._body_type = AmqpMessageBodyType.DATA + elif message.sequence: + self._body = message.sequence + self._body_type = AmqpMessageBodyType.SEQUENCE + else: + self._body = message.value + self._body_type = AmqpMessageBodyType.VALUE @property def body(self): diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 62e504aca9c5..b486737c7230 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -2,7 +2,7 @@ import pytest import uamqp from packaging import version -from azure.eventhub._transport._uamqp_transport import UamqpTransportMessage as TransportMessage, UamqpTransport +from azure.eventhub._transport._uamqp_transport import UamqpTransport from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common @@ -71,7 +71,7 @@ def test_sys_properties(): properties.group_id = "group_id" properties.group_sequence = 1 properties.reply_to_group_id = "reply_to_group_id" - message = TransportMessage(properties=properties) + message = uamqp.message.Message(properties=properties) message.annotations = {_common.PROP_OFFSET: "@latest"} ed = EventData._from_message(message) # type: EventData @@ -109,7 +109,7 @@ def test_event_data_batch(): batch.add(EventData("A")) def test_event_data_from_message(): - message = TransportMessage('A') + message = uamqp.message.Message('A') event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None From 4f6deddc826d3e189ce8e02b4bf7f30d8c675625 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Tue, 21 Jun 2022 16:55:34 -0700 Subject: [PATCH 07/21] move uamqp credentials to transport --- .../azure-eventhub/azure/eventhub/__init__.py | 3 +- .../azure/eventhub/_client_base.py | 83 ++++--------------- .../azure-eventhub/azure/eventhub/_common.py | 10 +-- .../azure/eventhub/_consumer_client.py | 6 -- .../azure/eventhub/_producer_client.py | 8 -- .../azure/eventhub/_transport/_base.py | 5 +- .../eventhub/_transport/_uamqp_transport.py | 75 ++++++++++++++++- 7 files changed, 96 insertions(+), 94 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index e79284f5ae6b..670101d117ac 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -10,7 +10,8 @@ from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient -from ._client_base import EventHubSharedKeyCredential +# TODO in pyamqp: from ._client_base import EventHubSharedKeyCredential +from ._transport._uamqp_transport import EventHubSharedKeyCredential from ._eventprocessor.checkpoint_store import CheckpointStore from ._eventprocessor.common import CloseReason, LoadBalancingStrategy from ._eventprocessor.partition_context import PartitionContext diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 4ca87c3856a2..3ace9bd0cb84 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -10,15 +10,10 @@ import functools import collections from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union -from datetime import timedelta -try: - from urlparse import urlparse - from urllib import quote_plus # type: ignore -except ImportError: - from urllib.parse import urlparse, quote_plus +from urllib.parse import urlparse -from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils +from uamqp import AMQPClient, Message, authentication, constants, errors, compat import six from azure.core.credentials import ( AccessToken, @@ -29,6 +24,7 @@ from azure.core.pipeline.policies import RetryMode +from ._transport._uamqp_transport import UamqpTransport, EventhubAzureNamedKeyTokenCredential from .exceptions import _handle_exception, ClientClosedError, ConnectError from ._configuration import Configuration from ._utils import utc_from_timestamp, parse_sas_credential @@ -132,24 +128,6 @@ def _parse_conn_str(conn_str, **kwargs): ) -def _generate_sas_token(uri, policy, key, expiry=None): - # type: (str, str, str, Optional[timedelta]) -> AccessToken - """Create a shared access signature token as a string literal. - :returns: SAS token as string literal. - :rtype: str - """ - if not expiry: - expiry = timedelta(hours=1) # Default to 1 hour. - - abs_expiry = int(time.time()) + expiry.seconds - encoded_uri = quote_plus(uri).encode("utf-8") # pylint: disable=no-member - encoded_policy = quote_plus(policy).encode("utf-8") # pylint: disable=no-member - encoded_key = key.encode("utf-8") - - token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry) - return AccessToken(token=token, expires_on=abs_expiry) - - def _build_uri(address, entity): # type: (str, Optional[str]) -> str parsed = urlparse(address) @@ -169,46 +147,6 @@ def _get_backoff_time(retry_mode, backoff_factor, backoff_max, retried_times): return min(backoff_max, backoff_value) -class EventHubSharedKeyCredential(object): - """The shared access key credential used for authentication. - - :param str policy: The name of the shared access policy. - :param str key: The shared access key. - """ - - def __init__(self, policy, key): - # type: (str, str) -> None - self.policy = policy - self.key = key - self.token_type = b"servicebus.windows.net:sastoken" - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken - if not scopes: - raise ValueError("No token scope provided.") - return _generate_sas_token(scopes[0], self.policy, self.key) - - -class EventhubAzureNamedKeyTokenCredential(object): - """The named key credential used for authentication. - - :param credential: The AzureNamedKeyCredential that should be used. - :type credential: ~azure.core.credentials.AzureNamedKeyCredential - """ - - def __init__(self, azure_named_key_credential): - # type: (AzureNamedKeyCredential) -> None - self._credential = azure_named_key_credential - self.token_type = b"servicebus.windows.net:sastoken" - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken - if not scopes: - raise ValueError("No token scope provided.") - name, key = self._credential.named_key - return _generate_sas_token(scopes[0], name, key) - - class EventHubSASTokenCredential(object): """The shared access token credential used for authentication. @@ -264,6 +202,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if TYPE_CHECKING: from azure.core.credentials import TokenCredential + from ._transport._uamqp_transport import EventHubSharedKeyCredential # TODO: update when pyamqp added CredentialTypes = Union[ AzureSasCredential, AzureNamedKeyCredential, @@ -275,6 +214,12 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument class ClientBase(object): # pylint:disable=too-many-instance-attributes def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs): # type: (str, str, CredentialTypes, Any) -> None + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + if self._uamqp_transport: + self._amqp_transport = UamqpTransport() + else: + raise NotImplementedError('pyamqp transport') + self.eventhub_name = eventhub_name if not eventhub_name: raise ValueError("The eventhub name can not be None or empty.") @@ -284,7 +229,10 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredential(credential) elif isinstance(credential, AzureNamedKeyCredential): - self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore + if self._uamqp_transport: + self._credential = UamqpTransport.create_named_key_token_credential(credential) # type: ignore + else: + raise NotImplementedError('pyamqp named key token credential') else: self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) @@ -309,7 +257,8 @@ def _from_connection_string(conn_str, **kwargs): if token and token_expiry: kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry) elif policy and key: - kwargs["credential"] = EventHubSharedKeyCredential(policy, key) + # TODO: pyamqp by default here, else uamqp + kwargs["credential"] = UamqpTransport.create_shared_key_credential(policy, key) return kwargs def _create_auth(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 911ce414626f..bc17f0fbf86c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -71,9 +71,6 @@ uuid.UUID, ]] -if TYPE_CHECKING: - from ._transport._base import AmqpTransport - _LOGGER = logging.getLogger(__name__) # event_data.encoded_size < 255, batch encode overhead is 5, >=256, overhead is 8 each @@ -208,8 +205,8 @@ def from_message_content(cls, content: bytes, content_type: str, **kwargs: Any) return event_data @classmethod - def _from_message(cls, message, raw_amqp_message=None, amqp_transport=None): - # type: (Message, Optional[AmqpAnnotatedMessage], AmqpTransport) -> EventData + def _from_message(cls, message, raw_amqp_message=None): + # type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData # pylint:disable=protected-access """Internal use only. @@ -222,8 +219,7 @@ def _from_message(cls, message, raw_amqp_message=None, amqp_transport=None): event_data = cls(body="") event_data.message = message # pylint: disable=protected-access - event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else \ - AmqpAnnotatedMessage(message=message, amqp_transport=amqp_transport) + event_data._raw_amqp_message = raw_amqp_message if raw_amqp_message else AmqpAnnotatedMessage(message=message) return event_data #def _encode_message(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index 7e23ce7868a7..7897e680525e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -11,7 +11,6 @@ from ._constants import ALL_PARTITIONS from ._eventprocessor.event_processor import EventProcessor from ._eventprocessor.common import LoadBalancingStrategy -from ._transport._uamqp_transport import UamqpTransport if TYPE_CHECKING: @@ -145,11 +144,6 @@ def __init__( **kwargs # type: Any ): # type: (...) -> None - self._uamqp_transport = kwargs.pop("uamqp_transport", True) - if self._uamqp_transport: - self._amqp_transport = UamqpTransport() - else: - raise NotImplementedError('pyamqp transport') self._checkpoint_store = kwargs.pop("checkpoint_store", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", None) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 3a63bddc00fc..7c53936a0505 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -13,7 +13,6 @@ from ._producer import EventHubProducer from ._constants import ALL_PARTITIONS from ._common import EventDataBatch, EventData -from ._transport._uamqp_transport import UamqpTransport if TYPE_CHECKING: from ._client_base import CredentialTypes @@ -91,18 +90,11 @@ def __init__( **kwargs # type: Any ): # type:(...) -> None - self._uamqp_transport = kwargs.pop("uamqp_transport", True) - if self._uamqp_transport: - self._amqp_transport = UamqpTransport() - else: - raise NotImplementedError('pyamqp transport') - super(EventHubProducerClient, self).__init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=credential, network_tracing=kwargs.get("logging_enable"), - amqp_transport=self._amqp_transport, **kwargs ) self._producers = { diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 4813814db1b9..274319f421ef 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -38,11 +38,10 @@ def create_retry_policy(self, retry_total): """ @abstractmethod - def create_link_properties(self, timeout_symbol, timeout): + def create_link_properties(self, link_properties): """ Creates and returns the link properties. - :param bytes timeout_symbol: The timeout symbol. - :param int timeout: The timeout to set as value. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. :rtype: dict """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index b49e8ab222bc..96682dc4572a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -3,7 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import time +from datetime import timedelta from typing import TYPE_CHECKING, Optional +from urllib.parse import urlparse, quote_plus +from azure.core.credentials import AccessToken from uamqp import ( BatchMessage, @@ -22,7 +25,7 @@ ) from uamqp.errors import ErrorPolicy, ErrorAction, LinkDetach -from ._base import AmqpTransport, TransportMessageBase +from ._base import AmqpTransport from ..amqp._constants import AmqpMessageBodyType from .._constants import ( NO_RETRY_ERRORS, @@ -56,6 +59,65 @@ def _error_handler(error): return ErrorAction(retry=False) return ErrorAction(retry=True) + +def _generate_sas_token(uri, policy, key, expiry=None): + # type: (str, str, str, Optional[timedelta]) -> AccessToken + """Create a shared access signature token as a string literal. + :returns: SAS token as string literal. + :rtype: str + """ + if not expiry: + expiry = timedelta(hours=1) # Default to 1 hour. + + abs_expiry = int(time.time()) + expiry.seconds + encoded_uri = quote_plus(uri).encode("utf-8") # pylint: disable=no-member + encoded_policy = quote_plus(policy).encode("utf-8") # pylint: disable=no-member + encoded_key = key.encode("utf-8") + + token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry) + return AccessToken(token=token, expires_on=abs_expiry) + + +class EventHubSharedKeyCredential(object): + """The shared access key credential used for authentication. + + :param str policy: The name of the shared access policy. + :param str key: The shared access key. + """ + + def __init__(self, policy, key): + # type: (str, str) -> None + self.policy = policy + self.key = key + self.token_type = b"servicebus.windows.net:sastoken" + + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (str, Any) -> AccessToken + if not scopes: + raise ValueError("No token scope provided.") + return _generate_sas_token(scopes[0], self.policy, self.key) + + +class EventhubAzureNamedKeyTokenCredential(object): + """The named key credential used for authentication. + + :param credential: The AzureNamedKeyCredential that should be used. + :type credential: ~azure.core.credentials.AzureNamedKeyCredential + """ + + def __init__(self, azure_named_key_credential): + # type: (AzureNamedKeyCredential) -> None + self._credential = azure_named_key_credential + self.token_type = b"servicebus.windows.net:sastoken" + + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (str, Any) -> AccessToken + if not scopes: + raise ValueError("No token scope provided.") + name, key = self._credential.named_key + return _generate_sas_token(scopes[0], name, key) + + class UamqpTransport(AmqpTransport): """ Class which defines uamqp-based methods used by the producer and consumer. @@ -78,7 +140,7 @@ class UamqpTransport(AmqpTransport): def to_outgoing_amqp_message(self, annotated_message): """ - Converts an AmqpAnnotatedMessage into an Amqp Transport Message. + Converts an AmqpAnnotatedMessage into an Amqp Message. """ message_header = None if annotated_message.header: @@ -133,6 +195,14 @@ def to_outgoing_amqp_message(self, annotated_message): footer=annotated_message.footer ) + @classmethod + def create_named_key_token_credential(cls, credential): + return EventhubAzureNamedKeyTokenCredential(credential) + + @classmethod + def create_shared_key_credential(cls, policy, key): + return EventHubSharedKeyCredential(policy, key) + def create_retry_policy(self, retry_total): """ Creates the error retry policy. @@ -144,6 +214,7 @@ def create_link_properties(self, link_properties): """ Creates and returns the link properties. :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict """ return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} From fb53ae2f685173d85439e450cb8f3177bd8ff946 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Thu, 23 Jun 2022 16:09:52 -0700 Subject: [PATCH 08/21] move uamqp bits to transport in client base --- .../azure/eventhub/_client_base.py | 95 ++++++-------- .../azure-eventhub/azure/eventhub/_common.py | 2 +- .../azure/eventhub/_consumer.py | 6 +- .../azure/eventhub/_producer_client.py | 2 +- .../azure/eventhub/_transport/_base.py | 70 ++++++++++- .../eventhub/_transport/_uamqp_transport.py | 118 ++++++++++++++++-- .../azure-eventhub/azure/eventhub/_utils.py | 2 +- .../azure/eventhub/exceptions.py | 23 ---- 8 files changed, 216 insertions(+), 102 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 3ace9bd0cb84..21442b4d344d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -13,7 +13,7 @@ from urllib.parse import urlparse -from uamqp import AMQPClient, Message, authentication, constants, errors, compat +from uamqp import Message, authentication, constants, errors, compat import six from azure.core.credentials import ( AccessToken, @@ -237,9 +237,6 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) self._auto_reconnect = kwargs.get("auto_reconnect", True) - self._mgmt_target = "amqps://{}/{}".format( - self._address.hostname, self.eventhub_name - ) self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path) self._config = Configuration(**kwargs) self._debug = self._config.network_tracing @@ -273,32 +270,19 @@ def _create_auth(self): except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - auth = authentication.JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=True # TODO: discarded by pyamqp transport ) - auth.update_token() - return auth - return authentication.JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, - refresh_window=300, + config=self._config, + update_token=False ) def _close_connection(self): @@ -340,19 +324,19 @@ def _management_request(self, mgmt_msg, op_type): last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = self._create_auth() - mgmt_client = AMQPClient( - self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: - conn = self._conn_manager.get_connection( # pylint:disable=assignment-from-none - self._address.hostname, mgmt_auth - ) - mgmt_client.open(connection=conn) - mgmt_msg.application_properties["security_token"] = mgmt_auth.token - response = mgmt_client.mgmt_request( + mgmt_client.open() + while not mgmt_client.client_ready(): + time.sleep(0.05) + mgmt_msg.application_properties["security_token"] = self._amqp_transport.get_updated_token(mgmt_auth) + response = self._amqp_transport.mgmt_client_request( + mgmt_client, mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, + operation=constants.READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -365,21 +349,18 @@ def _management_request(self, mgmt_msg, op_type): if status_code < 400: return response if status_code in [401]: - raise errors.AuthenticationException( - "Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - if status_code in [404]: - raise ConnectError( - "Management connection failed. Status code: {}, Description: {!r}".format( - status_code, description - ) + raise self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + f"Management authentication failed. Status code: {status_code}, Description: {description!r}" ) - raise errors.AMQPConnectionError( - "Management request error. Status code: {}, Description: {!r}".format( - status_code, description + if status_code in [404]: # TODO: make sure the error surfaced is the same across pyamqp and uamqp + return self._amqp_transport.get_error( + self._amqp_transport.CONNECTION_ERROR, + f"Management connection failed. Status code: {status_code}, Description: {description!r}" ) + return self._amqp_transport.get_error( + self._amqp_transport.AMQP_CONNECTION_ERROR, + f"Management request error. Status code: {status_code}, Description: {description!r}" ) except Exception as exception: # pylint: disable=broad-except last_exception = _handle_exception(exception, self) @@ -390,7 +371,7 @@ def _management_request(self, mgmt_msg, op_type): if retried_times > self._config.max_retries: _LOGGER.info( "%r returns an exception %r", self._container_id, last_exception - ) + ) raise last_exception finally: mgmt_client.close() @@ -406,7 +387,7 @@ def _get_eventhub_properties(self): mgmt_msg = Message(application_properties={"name": self.eventhub_name}) response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION) output = {} - eh_info = response.get_data() # type: Dict[bytes, Any] + eh_info = response.value # type: Dict[bytes, Any] if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -415,7 +396,7 @@ def _get_eventhub_properties(self): output["partition_ids"] = [ p.decode("utf-8") for p in eh_info[b"partition_ids"] ] - return output + return output # TODO: pyamqp - might need to be indented? def _get_partition_ids(self): # type:() -> List[str] @@ -430,7 +411,7 @@ def _get_partition_properties(self, partition_id): } ) response = self._management_request(mgmt_msg, op_type=MGMT_PARTITION_OPERATION) - partition_info = response.get_data() # type: Dict[bytes, Any] + partition_info = response.value # type: Dict[bytes, Any] output = {} if partition_info: output["eventhub_name"] = partition_info[b"name"].decode("utf-8") @@ -481,17 +462,13 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open( - connection=self._client._conn_manager.get_connection( - self._client._address.hostname, auth - ) # pylint: disable=protected-access - ) + self._handler.open() while not self._handler.client_ready(): time.sleep(0.05) self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size + self._amqp_transport.get_link_max_message_size(self._handler) or constants.MAX_MESSAGE_LENGTH_BYTES - ) # pylint: disable=protected-access + ) self.running = True def _close_handler(self): @@ -504,8 +481,8 @@ def _close_connection(self): self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access def _handle_exception(self, exception): - if not self.running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") + if not self.running and isinstance(exception, self._amqp_transport.TIMEOUT_EXCEPTION): + exception = self._amqp_transport.get_error("Authorization timeout.") return _handle_exception(exception, self) def _do_retryable_operation(self, operation, timeout=None, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index bc17f0fbf86c..98e3c8a7f23c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -525,7 +525,7 @@ def __init__( "partition_key to only be string type, they might fail to parse the non-string value." ) - self.max_size_in_bytes = max_size_in_bytes or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES + self.max_size_in_bytes = max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES self.message = self._amqp_transport.BATCH_MESSAGE(data=[]) self._partition_id = partition_id self._partition_key = partition_key diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 9edaee6693d0..bbfc561d390c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -172,9 +172,9 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open() - while not self._handler.client_ready(): - time.sleep(0.05) + self._handler.open() # TODO: uamqp handler is not using the passed in connection anyway + while not self._handler.client_ready(): + time.sleep(0.05) self.handler_ready = True self.running = True diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 7c53936a0505..21a7a2d471cf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -129,7 +129,7 @@ def _get_max_message_size(self): self._producers[ # type: ignore ALL_PARTITIONS ]._handler.message_handler._link.peer_max_message_size # TODO: fix to fit pyamqp - or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES + or self._amqp_transport.MAX_FRAME_SIZE_BYTES ) def _start_producer(self, partition_id, send_timeout): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py index 274319f421ef..a71250a5d225 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -10,7 +10,7 @@ class AmqpTransport(ABC): """ # define constants BATCH_MESSAGE = None - MAX_MESSAGE_LENGTH_BYTES = None + MAX_FRAME_SIZE_BYTES = None IDLE_TIMEOUT_FACTOR = None PRODUCT_SYMBOL = None @@ -22,6 +22,9 @@ class AmqpTransport(ABC): # errors AMQP_LINK_ERROR = None LINK_STOLEN_CONDITION = None + MGMT_AUTH_EXCEPTION = None + CONNECTION_ERROR = None + AMQP_CONNECTION_ERROR = None @abstractmethod def to_outgoing_amqp_message(self, annotated_message): @@ -63,12 +66,13 @@ def create_send_client(self, *, config, **kwargs): """ @abstractmethod - def send_messages(self, producer, timeout_time, last_exception): + def send_messages(self, producer, timeout_time, last_exception, logger): """ Handles sending of event data messages. :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. :param int timeout_time: Timeout time. :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. """ @abstractmethod @@ -96,7 +100,7 @@ def create_source(self, source, offset, filter): :param int offset: Required. :param bytes filter: Required. """ - + @abstractmethod def create_receive_client(self, *, config, **kwargs): """ @@ -119,9 +123,67 @@ def create_receive_client(self, *, config, **kwargs): :keyword timeout: Required. """ @abstractmethod - def open_receive_client(self, *, handler, client): + def open_receive_client(self, *, handler, client, auth): """ Opens the receive client. :param ReceiveClient handler: The receive client. :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. """ + + @abstractmethod + def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @abstractmethod + def create_mgmt_client(self, address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @abstractmethod + def get_updated_token(self, mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @abstractmethod + def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @abstractmethod + def get_error(self, error, message, *, condition=None): + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ + + @abstractmethod + def get_link_max_message_size(self, handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 96682dc4572a..b461620c94dd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -2,9 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from lib2to3.pgen2 import token import time from datetime import timedelta -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union, Any from urllib.parse import urlparse, quote_plus from azure.core.credentials import AccessToken @@ -18,12 +19,15 @@ ReceiveClient, Source, utils, + authentication, + AMQPClient, + compat ) from uamqp.message import ( MessageHeader, MessageProperties, ) -from uamqp.errors import ErrorPolicy, ErrorAction, LinkDetach +from uamqp.errors import ErrorPolicy, ErrorAction, LinkDetach, AuthenticationException, AMQPConnectionError from ._base import AmqpTransport from ..amqp._constants import AmqpMessageBodyType @@ -31,10 +35,11 @@ NO_RETRY_ERRORS, PROP_PARTITION_KEY_AMQP_SYMBOL, ) -from ..exceptions import OperationTimeoutError +from ..exceptions import ConnectError, OperationTimeoutError if TYPE_CHECKING: import logging + from azure.core.credentials import AzureNamedKeyCredential def _error_handler(error): """ @@ -124,7 +129,7 @@ class UamqpTransport(AmqpTransport): """ # define constants BATCH_MESSAGE = BatchMessage - MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES IDLE_TIMEOUT_FACTOR = 1000 # pyamqp = 1 # define symbols @@ -137,6 +142,10 @@ class UamqpTransport(AmqpTransport): # define errors and conditions AMQP_LINK_ERROR = LinkDetach LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen + AUTH_EXCEPTION = AuthenticationException + CONNECTION_ERROR = ConnectError + AMQP_CONNECTION_ERROR = AMQPConnectionError + TIMEOUT_EXCEPTION = compat.TimeoutException def to_outgoing_amqp_message(self, annotated_message): """ @@ -244,8 +253,7 @@ def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argum **kwargs ) - def _set_msg_timeout(self, timeout_time, last_exception, logger): - # type: (Optional[float], Optional[Exception], logging.Logger) -> None + def _set_msg_timeout(self, producer, timeout_time, last_exception, logger): if not timeout_time: return remaining_time = timeout_time - time.time() @@ -254,9 +262,9 @@ def _set_msg_timeout(self, timeout_time, last_exception, logger): error = last_exception else: error = OperationTimeoutError("Send operation timed out") - logger.info("%r send operation timed out. (%r)", self._name, error) + logger.info("%r send operation timed out. (%r)", producer._name, error) raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access + producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access def send_messages(self, producer, timeout_time, last_exception, logger): """ @@ -264,10 +272,11 @@ def send_messages(self, producer, timeout_time, last_exception, logger): :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. :param int timeout_time: Timeout time. :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. """ # pylint: disable=protected-access producer._unsent_events[0].on_send_complete = producer._on_outcome - self._set_msg_timeout(timeout_time, last_exception, logger) + self._set_msg_timeout(producer, timeout_time, last_exception, logger) producer._handler.queue_message(*producer._unsent_events) # type: ignore producer._handler.wait() # type: ignore producer._unsent_events = producer._handler.pending_messages # type: ignore @@ -359,7 +368,7 @@ def create_receive_client(self, *, config, **kwargs): auto_complete=False, **kwargs ) - + # pylint:disable=protected-access client._streaming_receive = streaming_receive client._message_received_callback = (message_received_callback) return client @@ -372,6 +381,95 @@ def open_receive_client(self, *, handler, client, auth): :param auth: Auth. :rtype: bool """ + # pylint:disable=protected-access handler.open(connection=client._conn_manager.get_connection( client._address.hostname, auth )) + + def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token", False) + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + token_auth.update_token() # TODO: why don't we need to update in pyamqp? + return token_auth + + def create_mgmt_client(self, address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClient( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + def get_updated_token(self, mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + return mgmt_client.mgmt_request( + mgmt_msg, + op_type=operation_type, + **kwargs + ) + + def get_error(self, error, message, *, condition=None): # pylint: disable=unused-argument + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ + return error(message) + + def get_link_max_message_size(self, handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + """ + return handler.message_handler._link.peer_max_message_size # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index b193340f0383..65bf8fef78d6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -270,7 +270,7 @@ def transform_outbound_single_message(message, message_type, to_outgoing_amqp_me except AttributeError: # pylint: disable=protected-access # AmqpAnnotatedMessage is converted to uamqp/pyamqp.Message during sending - amqp_message = to_outgoing_amqp_message(message.raw_amqp_message) + amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( message=amqp_message, raw_amqp_message=message # type: ignore ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 3d5d7885b4bc..c4c5a83fb723 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -9,29 +9,6 @@ _LOGGER = logging.getLogger(__name__) -def _error_handler(error): - """ - Called internally when an event has failed to send so we - can parse the error to determine whether we should attempt - to retry sending the event again. - Returns the action to take according to error type. - - :param error: The error received in the send attempt. - :type error: Exception - :rtype: ~uamqp.errors.ErrorAction - """ - if error.condition == b"com.microsoft:server-busy": - return ErrorAction(retry=True, backoff=4) - if error.condition == b"com.microsoft:timeout": - return ErrorAction(retry=True, backoff=2) - if error.condition == b"com.microsoft:operation-cancelled": - return ErrorAction(retry=True) - if error.condition == b"com.microsoft:container-close": - return ErrorAction(retry=True, backoff=4) - if error.condition in NO_RETRY_ERRORS: - return ErrorAction(retry=False) - return ErrorAction(retry=True) - class EventHubError(Exception): """Represents an error occurred in the client. From 381e7cf577bfa8170a0dce4aa12d4e8c60e8673d Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Thu, 30 Jun 2022 14:03:44 -0700 Subject: [PATCH 09/21] remove uamqp imports from shared --- .../azure-eventhub/azure/eventhub/__init__.py | 4 +- .../azure/eventhub/_client_base.py | 22 +-- .../azure-eventhub/azure/eventhub/_common.py | 7 +- .../azure/eventhub/_configuration.py | 2 +- .../azure/eventhub/_connection_manager.py | 4 +- .../azure/eventhub/_constants.py | 17 ++- .../azure/eventhub/_consumer.py | 2 +- .../azure/eventhub/_producer.py | 6 +- .../eventhub/_transport/_uamqp_transport.py | 128 +++++++++++++++--- .../azure-eventhub/azure/eventhub/_utils.py | 8 +- .../azure/eventhub/amqp/_amqp_message.py | 2 +- .../azure/eventhub/amqp/_constants.py | 8 -- .../azure/eventhub/exceptions.py | 81 ----------- 13 files changed, 150 insertions(+), 141 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index 670101d117ac..a36aa0b27e8a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -2,12 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from uamqp import constants from ._common import EventData, EventDataBatch from ._version import VERSION __version__ = VERSION +from ._constants import TransportType from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient # TODO in pyamqp: from ._client_base import EventHubSharedKeyCredential @@ -20,8 +20,6 @@ EventHubConnectionStringProperties ) -TransportType = constants.TransportType - __all__ = [ "EventData", "EventDataBatch", diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 21442b4d344d..605a2c0cfd07 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -13,7 +13,6 @@ from urllib.parse import urlparse -from uamqp import Message, authentication, constants, errors, compat import six from azure.core.credentials import ( AccessToken, @@ -24,8 +23,8 @@ from azure.core.pipeline.policies import RetryMode -from ._transport._uamqp_transport import UamqpTransport, EventhubAzureNamedKeyTokenCredential -from .exceptions import _handle_exception, ClientClosedError, ConnectError +from ._transport._uamqp_transport import UamqpTransport +from .exceptions import ClientClosedError from ._configuration import Configuration from ._utils import utc_from_timestamp, parse_sas_credential from ._connection_manager import get_connection_manager @@ -36,7 +35,10 @@ MGMT_PARTITION_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, + READ_OPERATION, ) +if TYPE_CHECKING: + from uamqp import Message, authentication _LOGGER = logging.getLogger(__name__) @@ -318,7 +320,7 @@ def _backoff( raise last_exception def _management_request(self, mgmt_msg, op_type): - # type: (Message, bytes) -> Any + # type: ("Message", bytes) -> Any # pylint:disable=assignment-from-none retried_times = 0 last_exception = None @@ -335,7 +337,7 @@ def _management_request(self, mgmt_msg, op_type): response = self._amqp_transport.mgmt_client_request( mgmt_client, mgmt_msg, - operation=constants.READ_OPERATION, + operation=READ_OPERATION, operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, @@ -363,7 +365,7 @@ def _management_request(self, mgmt_msg, op_type): f"Management request error. Status code: {status_code}, Description: {description!r}" ) except Exception as exception: # pylint: disable=broad-except - last_exception = _handle_exception(exception, self) + last_exception = self._amqp_transport._handle_exception(exception, self) self._backoff( retried_times=retried_times, last_exception=last_exception ) @@ -384,7 +386,7 @@ def _add_span_request_attributes(self, span): def _get_eventhub_properties(self): # type:() -> Dict[str, Any] - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + mgmt_msg = self._amqp_transport.MESSAGE(application_properties={"name": self.eventhub_name}) response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION) output = {} eh_info = response.value # type: Dict[bytes, Any] @@ -404,7 +406,7 @@ def _get_partition_ids(self): def _get_partition_properties(self, partition_id): # type:(str) -> Dict[str, Any] - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.MESSAGE( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -467,7 +469,7 @@ def _open(self): time.sleep(0.05) self._max_message_size_on_link = ( self._amqp_transport.get_link_max_message_size(self._handler) - or constants.MAX_MESSAGE_LENGTH_BYTES + or self._amqp_transport.MAX_FRAME_SIZE_BYTES ) self.running = True @@ -483,7 +485,7 @@ def _close_connection(self): def _handle_exception(self, exception): if not self.running and isinstance(exception, self._amqp_transport.TIMEOUT_EXCEPTION): exception = self._amqp_transport.get_error("Authorization timeout.") - return _handle_exception(exception, self) + return self._amqp_transport._handle_exception(exception, self) # pylint: disable=protected-access def _do_retryable_operation(self, operation, timeout=None, **kwargs): # pylint:disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 98e3c8a7f23c..d4e5248dd1ea 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -33,7 +33,6 @@ PROP_SEQ_NUMBER, PROP_OFFSET, PROP_PARTITION_KEY, - PROP_PARTITION_KEY_AMQP_SYMBOL, PROP_TIMESTAMP, PROP_ABSOLUTE_EXPIRY_TIME, PROP_CONTENT_ENCODING, @@ -285,10 +284,8 @@ def partition_key(self): :rtype: bytes """ - try: - return self._raw_amqp_message.annotations[PROP_PARTITION_KEY_AMQP_SYMBOL] - except KeyError: - return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) + # TODO: I think just trying this is reasonable? Haven't seen a case where symbol is used to get. + return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property def properties(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index bf42d52bb761..af9cfdc781f5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -9,7 +9,7 @@ except ImportError: from urllib.parse import urlparse -from uamqp.constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT +from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT from azure.core.pipeline.policies import RetryMode diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 623a25ece678..eab34b955292 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -5,9 +5,9 @@ from typing import TYPE_CHECKING -from uamqp import Connection if TYPE_CHECKING: + from uamqp import Connection from uamqp.authentication import JWTTokenAuth try: @@ -17,7 +17,7 @@ class ConnectionManager(Protocol): def get_connection(self, host, auth): - # type: (str, 'JWTTokenAuth') -> Connection + # type: (str, 'JWTTokenAuth') -> "Connection" pass def close_connection(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py index 97a614c341b5..fc9bcdb7ae57 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py @@ -4,13 +4,12 @@ # -------------------------------------------------------------------------------------------- from __future__ import unicode_literals -from uamqp import types +from enum import Enum PROP_SEQ_NUMBER = b"x-opt-sequence-number" PROP_OFFSET = b"x-opt-offset" PROP_PARTITION_KEY = b"x-opt-partition-key" -PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) PROP_TIMESTAMP = b"x-opt-enqueued-time" PROP_LAST_ENQUEUED_SEQUENCE_NUMBER = b"last_enqueued_sequence_number" PROP_LAST_ENQUEUED_OFFSET = b"last_enqueued_offset" @@ -52,3 +51,17 @@ b"com.microsoft:precondition-failed", b"com.microsoft:argument-error", ) + +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 + +DEFAULT_AMQPS_PORT = 5671 +DEFAULT_AMQP_WSS_PORT = 443 +READ_OPERATION = b"READ" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index bbfc561d390c..83000880eb56 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -151,7 +151,7 @@ def _open_with_retry(self): self._do_retryable_operation(self._open, operation_need_param=False) def _message_received(self, message): - # type: (uamqp.Message) -> None + # type: (Message) -> None # pylint:disable=protected-access self._message_buffer.append(message) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 4ad6da8848b2..00c89777270f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -32,14 +32,12 @@ if TYPE_CHECKING: from uamqp import constants, SendClient + from uamqp.authentication import JWTTokenAuth # pylint: disable=ungrouped-imports from ._transport._base import AmqpTransport + from ._producer_client import EventHubProducerClient _LOGGER = logging.getLogger(__name__) -if TYPE_CHECKING: - from uamqp.authentication import JWTTokenAuth # pylint: disable=ungrouped-imports - from ._producer_client import EventHubProducerClient - def _set_partition_key(event_datas, partition_key, amqp_transport): # type: (Iterable[EventData], AnyStr, AmqpTransport) -> Iterable[EventData] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index b461620c94dd..86bbcc95e846 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -2,11 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from lib2to3.pgen2 import token + import time +import logging from datetime import timedelta from typing import TYPE_CHECKING, Optional, Union, Any -from urllib.parse import urlparse, quote_plus +from urllib.parse import quote_plus from azure.core.credentials import AccessToken from uamqp import ( @@ -21,26 +22,38 @@ utils, authentication, AMQPClient, - compat + compat, + errors, ) from uamqp.message import ( MessageHeader, MessageProperties, ) -from uamqp.errors import ErrorPolicy, ErrorAction, LinkDetach, AuthenticationException, AMQPConnectionError from ._base import AmqpTransport from ..amqp._constants import AmqpMessageBodyType from .._constants import ( NO_RETRY_ERRORS, - PROP_PARTITION_KEY_AMQP_SYMBOL, + PROP_PARTITION_KEY, +) + +from ..exceptions import ( + ConnectError, + EventDataError, + EventDataSendError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataError, + EventDataSendError, ) -from ..exceptions import ConnectError, OperationTimeoutError if TYPE_CHECKING: - import logging from azure.core.credentials import AzureNamedKeyCredential +_LOGGER = logging.getLogger(__name__) + def _error_handler(error): """ Called internally when an event has failed to send so we @@ -53,16 +66,16 @@ def _error_handler(error): :rtype: ~uamqp.errors.ErrorAction """ if error.condition == b"com.microsoft:server-busy": - return ErrorAction(retry=True, backoff=4) + return errors.ErrorAction(retry=True, backoff=4) if error.condition == b"com.microsoft:timeout": - return ErrorAction(retry=True, backoff=2) + return errors.ErrorAction(retry=True, backoff=2) if error.condition == b"com.microsoft:operation-cancelled": - return ErrorAction(retry=True) + return errors.ErrorAction(retry=True) if error.condition == b"com.microsoft:container-close": - return ErrorAction(retry=True, backoff=4) + return errors.ErrorAction(retry=True, backoff=4) if error.condition in NO_RETRY_ERRORS: - return ErrorAction(retry=False) - return ErrorAction(retry=True) + return errors.ErrorAction(retry=False) + return errors.ErrorAction(retry=True) def _generate_sas_token(uri, policy, key, expiry=None): @@ -131,6 +144,7 @@ class UamqpTransport(AmqpTransport): BATCH_MESSAGE = BatchMessage MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES IDLE_TIMEOUT_FACTOR = 1000 # pyamqp = 1 + MESSAGE = Message # define symbols PRODUCT_SYMBOL = types.AMQPSymbol("product") @@ -138,13 +152,14 @@ class UamqpTransport(AmqpTransport): FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") PLATFORM_SYMBOL = types.AMQPSymbol("platform") USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) # define errors and conditions - AMQP_LINK_ERROR = LinkDetach + AMQP_LINK_ERROR = errors.LinkDetach LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen - AUTH_EXCEPTION = AuthenticationException + AUTH_EXCEPTION = errors.AuthenticationException CONNECTION_ERROR = ConnectError - AMQP_CONNECTION_ERROR = AMQPConnectionError + AMQP_CONNECTION_ERROR = errors.AMQPConnectionError TIMEOUT_EXCEPTION = compat.TimeoutException def to_outgoing_amqp_message(self, annotated_message): @@ -217,7 +232,7 @@ def create_retry_policy(self, retry_total): Creates the error retry policy. :param retry_total: Max number of retries. """ - return ErrorPolicy(max_retries=retry_total, on_error=_error_handler) + return errors.ErrorPolicy(max_retries=retry_total, on_error=_error_handler) def create_link_properties(self, link_properties): """ @@ -262,7 +277,7 @@ def _set_msg_timeout(self, producer, timeout_time, last_exception, logger): error = last_exception else: error = OperationTimeoutError("Send operation timed out") - logger.info("%r send operation timed out. (%r)", producer._name, error) + logger.info("%r send operation timed out. (%r)", producer._name, error) # pylint: disable=protected-access raise error producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access @@ -306,7 +321,7 @@ def set_message_partition_key(self, message, partition_key, **kwargs): # pylint if annotations is None: annotations = {} annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL ] = partition_key header = MessageHeader() header.durable = True @@ -473,3 +488,78 @@ def get_link_max_message_size(self, handler): :param AMQPClient handler: Client to get remote max message size on link from. """ return handler.message_handler._link.peer_max_message_size # pylint: disable=protected-access + + def _create_eventhub_exception(self, exception): + if isinstance(exception, errors.AuthenticationException): + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.VendorLinkDetach): + error = ConnectError(str(exception), exception) + elif isinstance(exception, errors.LinkDetach): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.ConnectionClose): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.MessageHandlerError): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error + + + def _handle_exception( + self, exception, closable + ): # pylint:disable=too-many-branches, too-many-statements + try: # closable is a producer/consumer object + name = closable._name # pylint: disable=protected-access + except AttributeError: # closable is an client object + name = closable._container_id # pylint: disable=protected-access + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + closable._close_connection() # pylint:disable=protected-access + raise exception + elif isinstance(exception, EventHubError): + closable._close_handler() # pylint:disable=protected-access + raise exception + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + if isinstance(exception, errors.AuthenticationException): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.LinkDetach): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + elif isinstance(exception, errors.ConnectionClose): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.MessageHandlerError): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + else: # errors.AMQPConnectionError, compat.TimeoutException + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + return self._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 65bf8fef78d6..5fceeceedb9d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -31,10 +31,10 @@ # Python 3 Type Checking imports from ._transport._base import AmqpTransport -from uamqp import types if TYPE_CHECKING: # pylint: disable=ungrouped-imports + from uamqp import types from azure.core.tracing import AbstractSpan from azure.core.credentials import AzureSasCredential from ._common import EventData @@ -78,7 +78,7 @@ def utc_from_timestamp(timestamp): def create_properties( user_agent: Optional[str] = None, *, amqp_transport: AmqpTransport -) -> Dict[types.AMQPSymbol, str]: +) -> Dict["types.AMQPSymbol", str]: """ Format the properties with which to instantiate the connection. This acts like a user agent over HTTP. @@ -264,12 +264,12 @@ def transform_outbound_single_message(message, message_type, to_outgoing_amqp_me """ try: # pylint: disable=protected-access - # EventData.message stores uamqp/pyamqp.Message during sending + # TODO:EventData.message stores uamqp/pyamqp.Message during sending message.message = to_outgoing_amqp_message(message.raw_amqp_message) return message # type: ignore except AttributeError: # pylint: disable=protected-access - # AmqpAnnotatedMessage is converted to uamqp/pyamqp.Message during sending + # TODO:AmqpAnnotatedMessage is converted to uamqp/pyamqp.Message during sending amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( message=amqp_message, raw_amqp_message=message # type: ignore diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 4468c8b82cc9..e36071686c10 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -7,7 +7,7 @@ from typing import Optional, Any, cast, Mapping, Dict from ._amqp_utils import normalized_data_body, normalized_sequence_body -from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType +from ._constants import AmqpMessageBodyType from .._mixin import DictMixin diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py index 1e2e7d3b6577..05cd1b5ce08e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py @@ -5,7 +5,6 @@ # ------------------------------------------------------------------------- from enum import Enum -from uamqp import MessageBodyType from azure.core import CaseInsensitiveEnumMeta @@ -13,10 +12,3 @@ class AmqpMessageBodyType(str, Enum, metaclass=CaseInsensitiveEnumMeta): DATA = "data" SEQUENCE = "sequence" VALUE = "value" - - -AMQP_MESSAGE_BODY_TYPE_MAP = { - MessageBodyType.Data.value: AmqpMessageBodyType.DATA, - MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, - MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, -} diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index c4c5a83fb723..f686251e6e95 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -2,13 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import logging import six -from uamqp import errors, compat - -_LOGGER = logging.getLogger(__name__) - class EventHubError(Exception): """Represents an error occurred in the client. @@ -98,79 +93,3 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" - - -def _create_eventhub_exception(exception): - if isinstance(exception, errors.AuthenticationException): - error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.VendorLinkDetach): - error = ConnectError(str(exception), exception) - elif isinstance(exception, errors.LinkDetach): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.ConnectionClose): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.MessageHandlerError): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.AMQPConnectionError): - error_type = ( - AuthenticationError - if str(exception).startswith("Unable to open authentication session") - else ConnectError - ) - error = error_type(str(exception), exception) - elif isinstance(exception, compat.TimeoutException): - error = ConnectionLostError(str(exception), exception) - else: - error = EventHubError(str(exception), exception) - return error - - -def _handle_exception( - exception, closable -): # pylint:disable=too-many-branches, too-many-statements - try: # closable is a producer/consumer object - name = closable._name # pylint: disable=protected-access - except AttributeError: # closable is an client object - name = closable._container_id # pylint: disable=protected-access - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - closable._close_connection() # pylint:disable=protected-access - raise exception - elif isinstance(exception, EventHubError): - closable._close_handler() # pylint:disable=protected-access - raise exception - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - if isinstance(exception, errors.AuthenticationException): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.LinkDetach): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - elif isinstance(exception, errors.ConnectionClose): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.MessageHandlerError): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - else: # errors.AMQPConnectionError, compat.TimeoutException - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - return _create_eventhub_exception(exception) From 5586dbd574b9e2d6a0fbc53694821d837fef0eaa Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 6 Jul 2022 13:34:38 -0700 Subject: [PATCH 10/21] fixed receive bugs --- sdk/eventhub/azure-eventhub/azure/eventhub/_common.py | 9 +++++---- .../azure-eventhub/azure/eventhub/_producer.py | 2 +- .../azure/eventhub/_transport/_uamqp_transport.py | 7 ++++++- .../tests/livetest/synctests/test_negative.py | 11 +++++++---- .../tests/livetest/synctests/test_receive.py | 3 +-- .../tests/livetest/synctests/test_reconnect.py | 10 ++++++++-- .../tests/livetest/synctests/test_send.py | 9 +++++---- 7 files changed, 33 insertions(+), 18 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index d4e5248dd1ea..1b0b64edcd0d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -57,6 +57,7 @@ if TYPE_CHECKING: import datetime + from ._transport._base import AmqpTransport MessageContent = TypedDict("MessageContent", {"content": bytes, "content_type": str}) PrimitiveTypes = Optional[Union[ @@ -542,12 +543,12 @@ def __len__(self): return self._count @classmethod - def _from_batch(cls, batch_data, to_outgoing_amqp_message, partition_key=None): - # type: (Iterable[EventData], Callable, Optional[AnyStr]) -> EventDataBatch + def _from_batch(cls, batch_data, amqp_transport, partition_key=None): + # type: (Iterable[EventData], AmqpTransport, Optional[AnyStr]) -> EventDataBatch outgoing_batch_data = [ - transform_outbound_single_message(m, EventData, to_outgoing_amqp_message) for m in batch_data + transform_outbound_single_message(m, EventData, amqp_transport.to_outgoing_amqp_message) for m in batch_data ] - batch_data_instance = cls(partition_key=partition_key) + batch_data_instance = cls(partition_key=partition_key, amqp_transport=amqp_transport) batch_data_instance.message._body_gen = ( # pylint:disable=protected-access outgoing_batch_data ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 00c89777270f..23878d711cf6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -191,7 +191,7 @@ def _wrap_eventdata( event_data = _set_partition_key(event_data, partition_key, self._amqp_transport) event_data = _set_trace_message(event_data, span) wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access - event_data, partition_key, self._amqp_transport.to_outgoing_amqp_message + event_data, self._amqp_transport, partition_key=partition_key ) return wrapper_event_data diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 86bbcc95e846..9a725ce3dc28 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -366,7 +366,10 @@ def create_receive_client(self, *, config, **kwargs): source = kwargs.pop("source") symbol_array = kwargs.pop("desired_capabilities") - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) if symbol_array else None + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) retry_policy = kwargs.pop("retry_policy") network_trace = kwargs.pop("network_trace") link_credit = kwargs.pop("link_credit") @@ -467,8 +470,10 @@ def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): :keyword description_fields: mgmt status desc. """ operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") return mgmt_client.mgmt_request( mgmt_msg, + operation, op_type=operation_type, **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index d3b775381796..f018f70f994e 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -19,10 +19,12 @@ ) from azure.eventhub import EventHubConsumerClient from azure.eventhub import EventHubProducerClient +from azure.eventhub._transport._uamqp_transport import UamqpTransport +@pytest.mark.parametrize("amqp_transport", [UamqpTransport()], ) @pytest.mark.liveTest -def test_send_batch_with_invalid_hostname(invalid_hostname): +def test_send_batch_with_invalid_hostname(invalid_hostname, amqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") @@ -30,7 +32,7 @@ def test_send_batch_with_invalid_hostname(invalid_hostname): with client: with pytest.raises(ConnectError): batch = EventDataBatch() - batch.add(EventData("test data")) + batch.add(EventData("test data", amqp_transport=amqp_transport)) client.send_batch(batch) @@ -49,12 +51,13 @@ def on_event(partition_context, event): thread.join() +@pytest.mark.parametrize("amqp_transport", [UamqpTransport()], ) @pytest.mark.liveTest -def test_send_batch_with_invalid_key(invalid_key): +def test_send_batch_with_invalid_key(invalid_key, amqp_transport): client = EventHubProducerClient.from_connection_string(invalid_key) try: with pytest.raises(ConnectError): - batch = EventDataBatch() + batch = EventDataBatch(amqp_transport=amqp_transport) batch.add(EventData("test data")) client.send_batch(batch) finally: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 21d6e249581e..1a7f42630d52 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -90,12 +90,11 @@ def on_event(partition_context, event): "track_last_enqueued_event_properties": True}) thread.daemon = True thread.start() - time.sleep(10) + time.sleep(15) assert on_event.event.body_as_str() == expected_result thread.join() - @pytest.mark.liveTest def test_receive_owner_level(connstr_senders): def on_event(partition_context, event): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index 0abfa7a12d2f..78006ed1a188 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -15,9 +15,12 @@ EventData, EventHubSharedKeyCredential, EventHubProducerClient, - EventHubConsumerClient + EventHubConsumerClient, + amqp ) from azure.eventhub.exceptions import OperationTimeoutError +from azure.eventhub._utils import transform_outbound_single_message +from azure.eventhub._transport._uamqp_transport import UamqpTransport @pytest.mark.liveTest def test_send_with_long_interval_sync(live_eventhub, sleep): @@ -61,8 +64,10 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): assert list(received[0].body)[0] == b"A single event" +@pytest.mark.parametrize("amqp_transport", + [UamqpTransport()]) @pytest.mark.liveTest -def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers): +def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, amqp_transport): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10) with client: @@ -71,6 +76,7 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers): with sender: sender._open_with_retry() time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) sender._unsent_events = [ed.message] ed.message.on_send_complete = sender._on_outcome with pytest.raises((uamqp.errors.ConnectionClose, diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 0872b55ae85f..d4cfb44ea74c 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -20,6 +20,7 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) +from azure.eventhub._transport._uamqp_transport import UamqpTransport @pytest.mark.liveTest def test_send_with_partition_key(connstr_receivers): @@ -363,11 +364,11 @@ def test_send_list_wrong_data(connection_str, to_send, exception_type): client.send_batch(to_send) -@pytest.mark.parametrize("partition_id, partition_key", [("0", None), (None, "pk")]) -def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key): +@pytest.mark.parametrize("partition_id, partition_key, amqp_transport", [("0", None, UamqpTransport()), (None, "pk", UamqpTransport())], ) +def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, amqp_transport): # Use invalid_hostname because this is not a live test. client = EventHubProducerClient.from_connection_string(invalid_hostname) - batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key) + batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) with client: with pytest.raises(TypeError): - client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) + client.send_batch(batch, partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) From 678138ee7275625e712523907c2423874c01026b Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 6 Jul 2022 14:44:22 -0700 Subject: [PATCH 11/21] fix negative test --- .../tests/livetest/synctests/test_negative.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index f018f70f994e..a33943fd5fdf 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -22,7 +22,7 @@ from azure.eventhub._transport._uamqp_transport import UamqpTransport -@pytest.mark.parametrize("amqp_transport", [UamqpTransport()], ) +@pytest.mark.parametrize("amqp_transport", [UamqpTransport(), ], ) @pytest.mark.liveTest def test_send_batch_with_invalid_hostname(invalid_hostname, amqp_transport): if sys.platform.startswith('darwin'): @@ -31,8 +31,8 @@ def test_send_batch_with_invalid_hostname(invalid_hostname, amqp_transport): client = EventHubProducerClient.from_connection_string(invalid_hostname) with client: with pytest.raises(ConnectError): - batch = EventDataBatch() - batch.add(EventData("test data", amqp_transport=amqp_transport)) + batch = EventDataBatch(amqp_transport=amqp_transport) + batch.add(EventData("test data")) client.send_batch(batch) From 141d2c09a365b7883f77d66356b72166616b3d55 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Thu, 7 Jul 2022 16:07:50 -0700 Subject: [PATCH 12/21] update body of amqp message --- sdk/eventhub/azure-eventhub/azure/eventhub/_common.py | 6 +++--- .../azure-eventhub/azure/eventhub/amqp/_amqp_message.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 1b0b64edcd0d..4053d90ec7e0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -232,11 +232,11 @@ def _decode_non_data_body_as_str(self, encoding="UTF-8"): # pylint: disable=protected-access body = self.raw_amqp_message.body if self.body_type == AmqpMessageBodyType.VALUE: - if not body.data: + if not body: return "" - return str(decode_with_recurse(body.data, encoding)) + return str(decode_with_recurse(body, encoding)) - seq_list = [d for seq_section in body.data for d in seq_section] + seq_list = [d for seq_section in body for d in seq_section] return str(decode_with_recurse(seq_list, encoding)) @property diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index e36071686c10..334e2ec9c6cb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -171,10 +171,10 @@ def _from_amqp_message(self, message): self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} if message.data: - self._body = message.data + self._body = list(message.data) self._body_type = AmqpMessageBodyType.DATA elif message.sequence: - self._body = message.sequence + self._body = list(message.sequence) self._body_type = AmqpMessageBodyType.SEQUENCE else: self._body = message.value From 9db8be02679cfc5c78e81176998b8a64696490e9 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Tue, 12 Jul 2022 19:15:40 -0700 Subject: [PATCH 13/21] PYAMQP CLIENT UPDATED --- sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 551e610a9df4..d4825747facf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -658,7 +658,7 @@ def _client_run(self, **kwargs): :rtype: bool """ try: - self._connection.listen(wait=self._socket_timeout, **kwargs) + self._connection.listen(wait=self._socket_timeout, batch=self._link_credit, **kwargs) except ValueError: _logger.info("Timeout reached, closing receiver.") self._shutdown = True From 455b06af57de554c6ceb4c874c67cafe6cd4201c Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 20 Jul 2022 16:16:35 -0700 Subject: [PATCH 14/21] fix receive bugs/event data bugs --- .../azure/eventhub/_client_base.py | 2 +- .../azure-eventhub/azure/eventhub/_common.py | 10 +++++++-- .../azure/eventhub/_producer.py | 2 +- .../eventhub/_transport/_pyamqp_transport.py | 14 +++++------- .../eventhub/_transport/_uamqp_transport.py | 7 +++--- .../azure/eventhub/amqp/_amqp_message.py | 22 +++++++++++++++++-- 6 files changed, 40 insertions(+), 17 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 8493291e39aa..e5694b2e32b2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -285,7 +285,7 @@ def __init__( self, fully_qualified_namespace: str, eventhub_name: str, - credential: CredentialTypes, + credential: "CredentialTypes", **kwargs: Any, ) -> None: self._uamqp_transport = kwargs.pop("uamqp_transport", False) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 60622186f0d9..65eaf5511dfa 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -516,15 +516,21 @@ def __init__( self._partition_id = partition_id self._partition_key = partition_key - self._amqp_transport.set_message_partition_key( + self.message = self._amqp_transport.set_message_partition_key( self.message, self._partition_key ) self._size = self._amqp_transport.get_batch_message_encoded_size(self.message) self._count = 0 + self._internal_events: List[ + Union[EventData, AmqpAnnotatedMessage] + ] = [] # TODO: only used by uamqp def __repr__(self): # type: () -> str - batch_repr = f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, partition_key={self._partition_key!r}, event_count={self._count}" + batch_repr = ( + f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, " + f"partition_key={self._partition_key!r}, event_count={self._count}" + ) return f"EventDataBatch({batch_repr})" def __len__(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 65bac2975734..fe914c7eca7b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -118,7 +118,7 @@ def __init__( self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect self._retry_policy = self._amqp_transport.create_retry_policy( - retry_total=self._client._config.max_retries + config=self._client._config ) self._reconnect_backoff = 1 self._name = f"EHProducer-{uuid.uuid4()}" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index 84fc8ac865f0..7e78178a5eec 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -7,7 +7,6 @@ import logging from datetime import timedelta from typing import TYPE_CHECKING, Optional, Union, Any -from urllib.parse import quote_plus from azure.core.credentials import AccessToken from .._pyamqp import ( @@ -17,7 +16,6 @@ constants, AMQPClient, ReceiveClient, - types ) from .._pyamqp.message import Message, BatchMessage, Header, Properties from .._pyamqp.authentication import JWTTokenAuth @@ -37,7 +35,6 @@ from ..exceptions import ( ConnectError, - EventDataError, EventDataSendError, OperationTimeoutError, EventHubError, @@ -312,7 +309,7 @@ def set_message_partition_key( """Set the partition key as an annotation on a uamqp message. :param Message message: The message to update. :param str partition_key: The partition key value. - :rtype: None + :rtype: Message """ encoding = kwargs.pop("encoding", 'utf-8') if partition_key: @@ -327,7 +324,8 @@ def set_message_partition_key( PROP_PARTITION_KEY ] = partition_key # pylint:disable=protected-access header = Header(durable=True) - message._replace(message_annotations=annotations, header=header) + return message._replace(message_annotations=annotations, header=header) + return message def add_batch(self, batch_message, outgoing_event_data, event_data): # pylint: disable=unused-argument """ @@ -337,7 +335,7 @@ def add_batch(self, batch_message, outgoing_event_data, event_data): # pylint :param event_data: EventData to add to internal batch events. uamqp use only. :rtype: None """ - utils.add_batch(batch_message.message, outgoing_event_data) + utils.add_batch(batch_message.message, outgoing_event_data.message) def create_source(self, source, offset, selector): """ @@ -451,7 +449,7 @@ def create_mgmt_client(self, address, mgmt_auth, config): return AMQPClient( config.hostname, auth=mgmt_auth, - debug=config.network_tracing, + network_trace=config.network_tracing, transport_type=config.transport_type, http_proxy=config.http_proxy, custom_endpoint_address=config.custom_endpoint_address, @@ -478,7 +476,7 @@ def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): operation_type = kwargs.pop("operation_type") operation = kwargs.pop("operation") return mgmt_client.mgmt_request( - mgmt_msg, operation=operation.decode(), operation_type_type=operation_type.decode(), **kwargs + mgmt_msg, operation=operation.decode(), operation_type=operation_type.decode(), **kwargs ) def get_error(self, error, message, *, condition=None): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index 3d10e679cb3d..a379dd778bed 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -337,12 +337,12 @@ def send_messages(self, producer, timeout_time, last_exception, logger): # return batch_message._body_gen # pylint:disable=protected-access def set_message_partition_key(self, message, partition_key, **kwargs): # pylint:disable=unused-argument - # type: (Message, Optional[Union[bytes, str]], Any) -> None + # type: (Message, Optional[Union[bytes, str]], Any) -> Message """Set the partition key as an annotation on a uamqp message. :param ~uamqp.Message message: The message to update. :param str partition_key: The partition key value. - :rtype: None + :rtype: Message """ if partition_key: annotations = message.annotations @@ -355,6 +355,7 @@ def set_message_partition_key(self, message, partition_key, **kwargs): # pylint header.durable = True message.annotations = annotations message.header = header + return message def add_batch(self, batch_message, outgoing_event_data, event_data): """ @@ -379,7 +380,7 @@ def create_source(self, source, offset, selector): """ source = Source(source) if offset is not None: - source.set_filter(filter) + source.set_filter(selector) return source def create_receive_client(self, *, config, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 7919b086ee7d..e2652ad971bb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -5,6 +5,7 @@ # ------------------------------------------------------------------------- from typing import Optional, Any, cast, Mapping, Dict +from types import GeneratorType from ._amqp_utils import normalized_data_body, normalized_sequence_body from ._constants import AmqpMessageBodyType @@ -236,14 +237,31 @@ def _from_amqp_message(self, message): self._application_properties = message.application_properties if message.application_properties else {} if message.data: # TODO: body used to return generator object in uamqp. But returns a list rn. Ask Anna. - self._body = list(message.data) + self._body = message.data self._body_type = AmqpMessageBodyType.DATA elif message.sequence: - self._body = list(message.sequence) + self._body = message.sequence self._body_type = AmqpMessageBodyType.SEQUENCE else: self._body = message.value self._body_type = AmqpMessageBodyType.VALUE + #if message.data: + # # TODO: body used to return generator object in uamqp. But returns a list rn. Ask Anna. + # # below is def a hack. need to fix + # if isinstance(message.data, GeneratorType): + # self._body = list(message.data) + # else: + # self._body = message.data + # self._body_type = AmqpMessageBodyType.DATA + #elif message.sequence: + # if isinstance(message.data, GeneratorType): + # self._body = list(message.sequence) + # else: + # self._body = message.data + # self._body_type = AmqpMessageBodyType.SEQUENCE + #else: + # self._body = message.value + # self._body_type = AmqpMessageBodyType.VALUE @property From c10f67c625ba8e4f41ee1497f435b40253982931 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 20 Jul 2022 17:20:26 -0700 Subject: [PATCH 15/21] update amqp message body --- .../azure/eventhub/amqp/_amqp_message.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index e2652ad971bb..75a88ba41bed 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -236,7 +236,6 @@ def _from_amqp_message(self, message): self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} if message.data: - # TODO: body used to return generator object in uamqp. But returns a list rn. Ask Anna. self._body = message.data self._body_type = AmqpMessageBodyType.DATA elif message.sequence: @@ -245,23 +244,6 @@ def _from_amqp_message(self, message): else: self._body = message.value self._body_type = AmqpMessageBodyType.VALUE - #if message.data: - # # TODO: body used to return generator object in uamqp. But returns a list rn. Ask Anna. - # # below is def a hack. need to fix - # if isinstance(message.data, GeneratorType): - # self._body = list(message.data) - # else: - # self._body = message.data - # self._body_type = AmqpMessageBodyType.DATA - #elif message.sequence: - # if isinstance(message.data, GeneratorType): - # self._body = list(message.sequence) - # else: - # self._body = message.data - # self._body_type = AmqpMessageBodyType.SEQUENCE - #else: - # self._body = message.value - # self._body_type = AmqpMessageBodyType.VALUE @property From 87095b2768790da9fcc620c8baebfc8158c6c84e Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Mon, 25 Jul 2022 16:16:59 -0700 Subject: [PATCH 16/21] fixing bugs in tests --- .../azure-eventhub/azure/eventhub/__init__.py | 3 +- .../azure/eventhub/_client_base.py | 9 +- .../azure/eventhub/_connection_manager.py | 1 + .../azure/eventhub/_consumer.py | 4 +- .../azure/eventhub/_pyamqp/client.py | 2 +- .../eventhub/_transport/_pyamqp_transport.py | 70 ------------ .../eventhub/_transport/_uamqp_transport.py | 72 ------------ .../azure/eventhub/amqp/_amqp_message.py | 5 +- .../tests/livetest/synctests/test_auth.py | 40 +++++-- .../synctests/test_consumer_client.py | 57 +++++++--- .../tests/livetest/synctests/test_negative.py | 47 +++++--- .../livetest/synctests/test_properties.py | 39 +++++-- .../tests/livetest/synctests/test_receive.py | 52 ++++++--- .../livetest/synctests/test_reconnect.py | 79 ++++++++----- .../tests/livetest/synctests/test_send.py | 105 ++++++++++++------ .../tests/unittest/test_event_data.py | 23 ++-- 16 files changed, 308 insertions(+), 300 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index 1eb40ba60d64..5272d4fecd5f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -10,8 +10,7 @@ from ._constants import TransportType from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient -# TODO in pyamqp: from ._client_base import EventHubSharedKeyCredential -from ._transport._pyamqp_transport import EventHubSharedKeyCredential +from ._client_base import EventHubSharedKeyCredential from ._eventprocessor.checkpoint_store import CheckpointStore from ._eventprocessor.common import CloseReason, LoadBalancingStrategy from ._eventprocessor.partition_context import PartitionContext diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index e5694b2e32b2..4134e5be7159 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -292,7 +292,10 @@ def __init__( if not self._uamqp_transport: self._amqp_transport = PyamqpTransport() else: - self._amqp_transport = UamqpTransport() + try: + self._amqp_transport = UamqpTransport() + except ImportError: + raise ImportError("uamqp package is not installed") self.eventhub_name = eventhub_name if not eventhub_name: @@ -579,9 +582,9 @@ def _handle_exception(self, exception): "Authorization timeout.", condition=errors.ErrorCondition.InternalError, ) - return self._amqp_transport._handle_exception( + return self._amqp_transport._handle_exception( # pylint: disable=protected-access exception, self - ) # pylint: disable=protected-access + ) def _do_retryable_operation(self, operation, timeout=None, **kwargs): # pylint:disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index e76c17506d3e..cdcfbdddd3f4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -88,6 +88,7 @@ def close_connection(self): self._conn.close() self._conn = None + # TODO: fix and add uamqp stuff def reset_connection_if_broken(self): # type: () -> None with self._lock: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index bb58adecbeb9..62f9c48348f4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -201,8 +201,8 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None): try: if self._open(): # TODO: for pyamqp, this will pass in batch. But, in the ReceiveClient._client_run, - # can pass (batch=self._link_credit) - self._handler.do_work() # type: ignore + # can remove (batch=self._link_credit)? + self._handler.do_work(batch=self._prefetch) # type: ignore break except Exception as exception: # pylint: disable=broad-except if ( diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index d4825747facf..551e610a9df4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -658,7 +658,7 @@ def _client_run(self, **kwargs): :rtype: bool """ try: - self._connection.listen(wait=self._socket_timeout, batch=self._link_credit, **kwargs) + self._connection.listen(wait=self._socket_timeout, **kwargs) except ValueError: _logger.info("Timeout reached, closing receiver.") self._shutdown = True diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index 7e78178a5eec..c0a10f97abfd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -3,11 +3,8 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import time import logging -from datetime import timedelta from typing import TYPE_CHECKING, Optional, Union, Any -from azure.core.credentials import AccessToken from .._pyamqp import ( error as errors, @@ -40,71 +37,12 @@ EventHubError, AuthenticationError, ConnectionLostError, - EventDataError, EventDataSendError, ) -if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential - _LOGGER = logging.getLogger(__name__) -def _generate_sas_token(uri, policy, key, expiry=None): - # type: (str, str, str, Optional[timedelta]) -> AccessToken - """Create a shared access signature token as a string literal. - :returns: SAS token as string literal. - :rtype: str - """ - if not expiry: - expiry = timedelta(hours=1) # Default to 1 hour. - - abs_expiry = int(time.time()) + expiry.seconds - - token = utils.generate_sas_token(uri, policy, key, abs_expiry).encode() - return AccessToken(token=token, expires_on=abs_expiry) - - -class EventHubSharedKeyCredential(object): - """The shared access key credential used for authentication. - - :param str policy: The name of the shared access policy. - :param str key: The shared access key. - """ - - def __init__(self, policy, key): - # type: (str, str) -> None - self.policy = policy - self.key = key - self.token_type = b"servicebus.windows.net:sastoken" - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken - if not scopes: - raise ValueError("No token scope provided.") - return _generate_sas_token(scopes[0], self.policy, self.key) - - -class EventhubAzureNamedKeyTokenCredential(object): - """The named key credential used for authentication. - - :param credential: The AzureNamedKeyCredential that should be used. - :type credential: ~azure.core.credentials.AzureNamedKeyCredential - """ - - def __init__(self, azure_named_key_credential): - # type: (AzureNamedKeyCredential) -> None - self._credential = azure_named_key_credential - self.token_type = b"servicebus.windows.net:sastoken" - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken - if not scopes: - raise ValueError("No token scope provided.") - name, key = self._credential.named_key - return _generate_sas_token(scopes[0], name, key) - - class PyamqpTransport(AmqpTransport): """ Class which defines uamqp-based methods used by the producer and consumer. @@ -190,14 +128,6 @@ def to_outgoing_amqp_message(self, annotated_message): return Message(**message_dict) - @classmethod - def create_named_key_token_credential(cls, credential): - return EventhubAzureNamedKeyTokenCredential(credential) - - @classmethod - def create_shared_key_credential(cls, policy, key): - return EventHubSharedKeyCredential(policy, key) - def get_batch_message_encoded_size(self, message): """ Gets the batch message encoded size given an underlying Message. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index a379dd778bed..a9f3b2172f2f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -5,10 +5,7 @@ import time import logging -from datetime import timedelta from typing import TYPE_CHECKING, Optional, Union, Any -from urllib.parse import quote_plus -from azure.core.credentials import AccessToken from uamqp import ( BatchMessage, @@ -49,9 +46,6 @@ EventDataSendError, ) -if TYPE_CHECKING: - from azure.core.credentials import AzureNamedKeyCredential - _LOGGER = logging.getLogger(__name__) def _error_handler(error): @@ -78,64 +72,6 @@ def _error_handler(error): return errors.ErrorAction(retry=True) -def _generate_sas_token(uri, policy, key, expiry=None): - # type: (str, str, str, Optional[timedelta]) -> AccessToken - """Create a shared access signature token as a string literal. - :returns: SAS token as string literal. - :rtype: str - """ - if not expiry: - expiry = timedelta(hours=1) # Default to 1 hour. - - abs_expiry = int(time.time()) + expiry.seconds - encoded_uri = quote_plus(uri).encode("utf-8") # pylint: disable=no-member - encoded_policy = quote_plus(policy).encode("utf-8") # pylint: disable=no-member - encoded_key = key.encode("utf-8") - - token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry) - return AccessToken(token=token, expires_on=abs_expiry) - - -class EventHubSharedKeyCredential(object): - """The shared access key credential used for authentication. - - :param str policy: The name of the shared access policy. - :param str key: The shared access key. - """ - - def __init__(self, policy, key): - # type: (str, str) -> None - self.policy = policy - self.key = key - self.token_type = b"servicebus.windows.net:sastoken" - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken - if not scopes: - raise ValueError("No token scope provided.") - return _generate_sas_token(scopes[0], self.policy, self.key) - - -class EventhubAzureNamedKeyTokenCredential(object): - """The named key credential used for authentication. - - :param credential: The AzureNamedKeyCredential that should be used. - :type credential: ~azure.core.credentials.AzureNamedKeyCredential - """ - - def __init__(self, azure_named_key_credential): - # type: (AzureNamedKeyCredential) -> None - self._credential = azure_named_key_credential - self.token_type = b"servicebus.windows.net:sastoken" - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (str, Any) -> AccessToken - if not scopes: - raise ValueError("No token scope provided.") - name, key = self._credential.named_key - return _generate_sas_token(scopes[0], name, key) - - class UamqpTransport(AmqpTransport): """ Class which defines uamqp-based methods used by the producer and consumer. @@ -221,14 +157,6 @@ def to_outgoing_amqp_message(self, annotated_message): footer=annotated_message.footer ) - @classmethod - def create_named_key_token_credential(cls, credential): - return EventhubAzureNamedKeyTokenCredential(credential) - - @classmethod - def create_shared_key_credential(cls, policy, key): - return EventHubSharedKeyCredential(policy, key) - def get_batch_message_encoded_size(self, message): """ Gets the batch message encoded size given an underlying Message. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 75a88ba41bed..acd515a06be2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -236,16 +236,15 @@ def _from_amqp_message(self, message): self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} if message.data: - self._body = message.data + self._body = list(message.data) self._body_type = AmqpMessageBodyType.DATA elif message.sequence: - self._body = message.sequence + self._body = list(message.sequence) self._body_type = AmqpMessageBodyType.SEQUENCE else: self._body = message.value self._body_type = AmqpMessageBodyType.VALUE - @property def body(self): # type: () -> Any diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index c00ea84067ea..f00d034a02ed 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -13,18 +13,23 @@ from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_client_secret_credential(live_eventhub): +def test_client_secret_credential(live_eventhub, uamqp_transport): credential = EnvironmentCredential() producer_client = EventHubProducerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') batch.add(EventData(body='A single message')) @@ -50,11 +55,15 @@ def on_event(partition_context, event): assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_client_sas_credential(live_eventhub): +def test_client_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -67,7 +76,8 @@ def test_client_sas_credential(live_eventhub): token = credential.get_token(auth_uri).token producer_client = EventHubProducerClient(fully_qualified_namespace=hostname, eventhub_name=live_eventhub['event_hub'], - credential=EventHubSASTokenCredential(token, time.time() + 3000)) + credential=EventHubSASTokenCredential(token, time.time() + 3000), + uamqp_transport=uamqp_transport) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -77,7 +87,8 @@ def test_client_sas_credential(live_eventhub): # Finally let's do it with SAS token + conn str token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str, - eventhub_name=live_eventhub['event_hub']) + eventhub_name=live_eventhub['event_hub'], + uamqp_transport=uamqp_transport) with conn_str_producer_client: batch = conn_str_producer_client.create_batch(partition_id='0') @@ -85,11 +96,15 @@ def test_client_sas_credential(live_eventhub): conn_str_producer_client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_client_azure_sas_credential(live_eventhub): +def test_client_azure_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -110,14 +125,17 @@ def test_client_azure_sas_credential(live_eventhub): producer_client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_client_azure_named_key_credential(live_eventhub): +def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential = AzureNamedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) assert consumer_client.get_eventhub_properties() is not None diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 9e09cd156ef8..062edea4a4e9 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -8,12 +8,19 @@ from azure.eventhub._constants import ALL_PARTITIONS +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_no_partition(connstr_senders): +def test_receive_no_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', receive_timeout=1) + client = EventHubConsumerClient.from_connection_string( + connection_str, + consumer_group='$default', + receive_timeout=1, + uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -45,11 +52,15 @@ def on_event(partition_context, event): assert len([checkpoint for checkpoint in checkpoints if checkpoint["sequence_number"] == on_event.sequence_number]) > 0 +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_partition(connstr_senders): +def test_receive_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -73,17 +84,21 @@ def on_event(partition_context, event): assert on_event.eventhub_name == senders[0]._client.eventhub_name +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_load_balancing(connstr_senders): +def test_receive_load_balancing(connstr_senders, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - test code using multiple threads. Sometimes OSX aborts python process") connection_str, senders = connstr_senders cs = InMemoryCheckpointStore() client1 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) client2 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): pass @@ -105,13 +120,17 @@ def on_event(partition_context, event): assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 -def test_receive_batch_no_max_wait_time(connstr_senders): +@pytest.mark.parametrize("uamqp_transport", + [True, False]) +def test_receive_batch_no_max_wait_time(connstr_senders, uamqp_transport): '''Test whether callback is called when max_wait_time is None and max_batch_size has reached ''' connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) @@ -146,14 +165,16 @@ def on_event_batch(partition_context, event_batch): worker.join() -@pytest.mark.parametrize("max_wait_time, sleep_time, expected_result", - [(3, 10, []), - (3, 2, None), +@pytest.mark.parametrize("max_wait_time, sleep_time, expected_result, uamqp_transport", + [(3, 10, [], True), + (3, 2, None, True), + (3, 10, [], False), + (3, 2, None, False), ]) -def test_receive_batch_empty_with_max_wait_time(connection_str, max_wait_time, sleep_time, expected_result): +def test_receive_batch_empty_with_max_wait_time(connection_str, max_wait_time, sleep_time, expected_result, uamqp_transport): '''Test whether event handler is called when max_wait_time > 0 and no event is received ''' - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) def on_event_batch(partition_context, event_batch): on_event_batch.event_batch = event_batch @@ -168,13 +189,17 @@ def on_event_batch(partition_context, event_batch): worker.join() -def test_receive_batch_early_callback(connstr_senders): +@pytest.mark.parametrize("uamqp_transport", + [True, False]) +def test_receive_batch_early_callback(connstr_senders, uamqp_transport): ''' Test whether the callback is called once max_batch_size reaches and before max_wait_time reaches. ''' connection_str, senders = connstr_senders for _ in range(10): senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index a33943fd5fdf..303df40070e8 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -20,15 +20,17 @@ from azure.eventhub import EventHubConsumerClient from azure.eventhub import EventHubProducerClient from azure.eventhub._transport._uamqp_transport import UamqpTransport +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport -@pytest.mark.parametrize("amqp_transport", [UamqpTransport(), ], ) +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_send_batch_with_invalid_hostname(invalid_hostname, amqp_transport): +def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): batch = EventDataBatch(amqp_transport=amqp_transport) @@ -36,12 +38,15 @@ def test_send_batch_with_invalid_hostname(invalid_hostname, amqp_transport): client.send_batch(batch) +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_receive_with_invalid_hostname_sync(invalid_hostname): +def test_receive_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): def on_event(partition_context, event): pass - client = EventHubConsumerClient.from_connection_string(invalid_hostname, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + invalid_hostname, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event, )) @@ -51,10 +56,11 @@ def on_event(partition_context, event): thread.join() -@pytest.mark.parametrize("amqp_transport", [UamqpTransport()], ) +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_send_batch_with_invalid_key(invalid_key, amqp_transport): - client = EventHubProducerClient.from_connection_string(invalid_key) +def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): + client = EventHubProducerClient.from_connection_string(invalid_key, uamqp_transport=uamqp_transport) + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() try: with pytest.raises(ConnectError): batch = EventDataBatch(amqp_transport=amqp_transport) @@ -64,11 +70,12 @@ def test_send_batch_with_invalid_key(invalid_key, amqp_transport): client.close() +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_send_batch_to_invalid_partitions(connection_str): +def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): partitions = ["XYZ", "-1", "1000", "-"] for p in partitions: - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ConnectError): batch = client.create_batch(partition_id=p) @@ -78,11 +85,12 @@ def test_send_batch_to_invalid_partitions(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_send_batch_too_large_message(connection_str): +def test_send_batch_too_large_message(connection_str, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: data = EventData(b"A" * 1100000) batch = client.create_batch() @@ -92,9 +100,10 @@ def test_send_batch_too_large_message(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_send_batch_null_body(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_batch_null_body(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ValueError): data = EventData(None) @@ -105,20 +114,22 @@ def test_send_batch_null_body(connection_str): client.close() +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_create_batch_with_invalid_hostname_sync(invalid_hostname): +def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): client.create_batch(max_size_in_bytes=300) +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_create_batch_with_too_large_size_sync(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_create_batch_with_too_large_size_sync(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(ValueError): client.create_batch(max_size_in_bytes=5 * 1024 * 1024) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py index eb197eec44b0..b9349f60b250 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py @@ -11,58 +11,73 @@ from azure.eventhub.exceptions import AuthenticationError, ConnectError, EventHubError +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_get_properties(live_eventhub): +def test_get_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_eventhub_properties() assert properties['eventhub_name'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_get_properties_with_auth_error_sync(live_eventhub): +def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf=")) + EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf="), + uamqp_transport=uamqp_transport + ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential("invalid", live_eventhub['access_key']) + EventHubSharedKeyCredential("invalid", live_eventhub['access_key']), uamqp_transport=uamqp_transport ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_get_properties_with_connect_error(live_eventhub): +def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], "invalid", '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(ConnectError) as e: client.get_eventhub_properties() client = EventHubConsumerClient("invalid.servicebus.windows.net", live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(EventHubError) as e: # This can be either ConnectError or ConnectionLostError client.get_eventhub_properties() +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_get_partition_ids(live_eventhub): +def test_get_partition_ids(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: partition_ids = client.get_partition_ids() assert partition_ids == ['0', '1'] +@pytest.mark.parametrize("uamqp_transport", [True, False]) @pytest.mark.liveTest -def test_get_partition_properties(live_eventhub): +def test_get_partition_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_partition_properties('0') assert properties['eventhub_name'] == live_eventhub['event_hub'] \ diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 7b3e12c22d65..1a58255f90d8 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -14,8 +14,10 @@ from azure.eventhub.exceptions import EventHubError +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_end_of_stream(connstr_senders): +def test_receive_end_of_stream(connstr_senders, uamqp_transport): def on_event(partition_context, event): if partition_context.partition_id == "0": assert event.body_as_str() == "Receiving only a single event" @@ -29,7 +31,9 @@ def on_event(partition_context, event): assert ", partition_key: 0" in event_str on_event.called = False connection_str, senders = connstr_senders - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "@latest"}) @@ -43,14 +47,20 @@ def on_event(partition_context, event): thread.join() -@pytest.mark.parametrize("position, inclusive, expected_result", - [("offset", False, "Exclusive"), - ("offset", True, "Inclusive"), - ("sequence", False, "Exclusive"), - ("sequence", True, "Inclusive"), - ("enqueued_time", False, "Exclusive")]) +@pytest.mark.parametrize("position, inclusive, expected_result, uamqp_transport", + [("offset", False, "Exclusive", True), + ("offset", True, "Inclusive", True), + ("sequence", False, "Exclusive", True), + ("sequence", True, "Inclusive", True), + ("enqueued_time", False, "Exclusive", True), + ("offset", False, "Exclusive", False), + ("offset", True, "Inclusive", False), + ("sequence", False, "Exclusive", False), + ("sequence", True, "Inclusive", False), + ("enqueued_time", False, "Exclusive", False) + ]) @pytest.mark.liveTest -def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result): +def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result, uamqp_transport): def on_event(partition_context, event): assert partition_context.last_enqueued_event_properties.get('sequence_number') == event.sequence_number assert partition_context.last_enqueued_event_properties.get('offset') == event.offset @@ -69,7 +79,9 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders senders[0].send(EventData(b"Inclusive")) senders[1].send(EventData(b"Inclusive")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1", @@ -82,7 +94,9 @@ def on_event(partition_context, event): thread.join() senders[0].send(EventData(expected_result)) senders[1].send(EventData(expected_result)) - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client2 = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client2: thread = threading.Thread(target=client2.receive, args=(on_event,), kwargs={"starting_position": on_event.event_position, @@ -95,8 +109,10 @@ def on_event(partition_context, event): thread.join() +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_owner_level(connstr_senders): +def test_receive_owner_level(connstr_senders, uamqp_transport): def on_event(partition_context, event): pass def on_error(partition_context, error): @@ -104,8 +120,8 @@ def on_error(partition_context, error): on_error.error = None connection_str, senders = connstr_senders - client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) + client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) with client1, client2: thread1 = threading.Thread(target=client1.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "-1", @@ -127,9 +143,10 @@ def on_error(partition_context, error): assert isinstance(on_error.error, EventHubError) +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_over_websocket_sync(connstr_senders): - pytest.skip("websocket not supported") +def test_receive_over_websocket_sync(connstr_senders, uamqp_transport): app_prop = {"raw_prop": "raw_value"} content_type = "text/plain" message_id_base = "mess_id_sample_" @@ -143,7 +160,8 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', - transport_type=TransportType.AmqpOverWebsocket) + transport_type=TransportType.AmqpOverWebsocket, + uamqp_transport=uamqp_transport) event_list = [] for i in range(5): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index d9a3e9fa24d1..68671b51efea 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -12,6 +12,7 @@ from azure.eventhub._pyamqp import error, constants import uamqp +from uamqp import errors, compat from azure.eventhub import ( EventData, @@ -23,13 +24,18 @@ from azure.eventhub.exceptions import OperationTimeoutError from azure.eventhub._utils import transform_outbound_single_message from azure.eventhub._transport._uamqp_transport import UamqpTransport +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_with_long_interval_sync(live_eventhub, sleep): +def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): test_partition = "0" sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], + live_eventhub['access_key']), uamqp_transport=uamqp_transport + ) with sender: batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) @@ -68,59 +74,73 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): # TODO: fix and add pyamqp transport -@pytest.mark.parametrize("amqp_transport", - [UamqpTransport()]) +@pytest.mark.parametrize("uamqp_transport", + [True]) @pytest.mark.liveTest -def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, amqp_transport): +def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() # no retry, should just raise error - client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10, retry_total=0) + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, retry_total=0, uamqp_transport=uamqp_transport + ) with client: ed = EventData('data') sender = client._create_producer(partition_id='0') with sender: sender._open_with_retry() time.sleep(11) - sender._unsent_events = [ed.message] - with pytest.raises(error.AMQPConnectionError): - sender._send_event_data() - - # with retry, should work - client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10) - with client: - ed = EventData('data') - sender = client._create_producer(partition_id='0') - with sender: - sender._open_with_retry() - time.sleep(11) - ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) - sender._unsent_events = [ed.message] - ed.message.on_send_complete = sender._on_outcome + wrapped_ed = sender._wrap_eventdata(ed, None, None) + sender._unsent_events = [wrapped_ed.message] + if uamqp_transport: + sender._unsent_events[0].on_send_complete = sender._on_outcome with pytest.raises((uamqp.errors.ConnectionClose, - uamqp.errors.MessageHandlerError, OperationTimeoutError)): - # Mac may raise OperationTimeoutError or MessageHandlerError + uamqp.errors.MessageHandlerError, OperationTimeoutError)): + sender._send_event_data() + else: + with pytest.raises(error.AMQPConnectionError): sender._send_event_data() + if uamqp_transport: sender._send_event_data_with_retry() + # pyamqp - with retry, should work + if not uamqp_transport: + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, uamqp_transport=uamqp_transport + ) + with client: + ed = EventData('data') + sender = client._create_producer(partition_id='0') + with sender: + sender._open_with_retry() + time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed.message] + sender._send_event_data() + retry = 0 while retry < 3: try: - messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=10) + timeout = 10000 if uamqp_transport else 10 + messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=timeout) if messages: received_ed1 = EventData._from_message(messages[0]) assert received_ed1.body_as_str() == 'data' break - except TimeoutError: + except (compat.TimeoutException, TimeoutError): retry += 1 +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders): +def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string( conn_str=connection_str, consumer_group='$default', - idle_timeout=10 + idle_timeout=10, + uamqp_transport=uamqp_transport ) def on_event_received(event): @@ -135,7 +155,10 @@ def on_event_received(event): senders[0].send(ed) consumer._handler.do_work() - assert consumer._handler._connection.state == constants.ConnectionState.END + if uamqp_transport: + assert consumer._handler._connection._state == uamqp.c_uamqp.ConnectionState.DISCARDING + else: + assert consumer._handler._connection.state == constants.ConnectionState.END duration = 10 now_time = time.time() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 3c7d7b6b99ad..e43e55689d60 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -21,11 +21,14 @@ AmqpMessageProperties, ) from azure.eventhub._transport._uamqp_transport import UamqpTransport +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_with_partition_key(connstr_receivers): +def test_send_with_partition_key(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: @@ -50,12 +53,14 @@ def test_send_with_partition_key(connstr_receivers): found_partition_keys[event_data.partition_key] = index +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_and_receive_large_body_size(connstr_receivers): +def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: payload = 250 * 1024 batch = client.create_batch() @@ -70,10 +75,12 @@ def test_send_and_receive_large_body_size(connstr_receivers): assert len(list(received[0].body)[0]) == payload +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_amqp_annotated_message(connstr_receivers): +def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} @@ -109,7 +116,7 @@ def test_send_amqp_annotated_message(connstr_receivers): ) body_ed = """{"json_key": "json_val"}""" - prop_ed = {"raw_prop": "raw_value"} + prop_ed = {b"raw_prop": b"raw_value"} cont_type_ed = "text/plain" corr_id_ed = "corr_id" mess_id_ed = "mess_id" @@ -117,6 +124,7 @@ def test_send_amqp_annotated_message(connstr_receivers): event_data.content_type = cont_type_ed event_data.correlation_id = corr_id_ed event_data.message_id = mess_id_ed + event_data.properties = prop_ed batch = client.create_batch() batch.add(data_message) @@ -147,6 +155,7 @@ def check_values(event): assert event.correlation_id == corr_id_ed assert event.message_id == mess_id_ed assert event.content_type == cont_type_ed + assert event.properties == prop_ed assert event.body_type == AmqpMessageBodyType.DATA received_count["normal_msg"] += 1 elif raw_amqp_message.body_type == AmqpMessageBodyType.SEQUENCE: @@ -169,7 +178,8 @@ def on_event(partition_context, event): on_event.received = [] client = EventHubConsumerClient.from_connection_string(connection_str, - consumer_group='$default') + consumer_group='$default', + uamqp_transport=uamqp_transport) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1"}) @@ -185,12 +195,12 @@ def on_event(partition_context, event): assert received_count["normal_msg"] == 2 -@pytest.mark.parametrize("payload", - [b"", b"A single event"]) +@pytest.mark.parametrize("payload, uamqp_transport", + [(b"", True), (b"", False), (b"A single event", True), (b"A single event", False)]) @pytest.mark.liveTest -def test_send_and_receive_small_body(connstr_receivers, payload): +def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() batch.add(EventData(payload)) @@ -203,10 +213,12 @@ def test_send_and_receive_small_body(connstr_receivers, payload): assert list(received[0].body)[0] == payload +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_partition(connstr_receivers): +def test_send_partition(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() @@ -237,10 +249,12 @@ def test_send_partition(connstr_receivers): assert len(partition_0) + len(partition_1) == 2 +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_non_ascii(connstr_receivers): +def test_send_non_ascii(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch(partition_id="0") batch.add(EventData(u"é,è,à,ù,â,ê,î,ô,û")) @@ -257,13 +271,15 @@ def test_send_non_ascii(connstr_receivers): assert partition_0[1].body_as_json() == {"foo": u"漢字"} +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_multiple_partitions_with_app_prop(connstr_receivers): +def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: ed0 = EventData(b"Message 0") ed0.properties = app_prop @@ -285,11 +301,14 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers): assert partition_1[0].properties[b"raw_prop"] == b"raw_value" +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_over_websocket_sync(connstr_receivers): - pytest.skip("websocket not supported") +def test_send_over_websocket_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: batch = client.create_batch(partition_id="0") @@ -302,13 +321,17 @@ def test_send_over_websocket_sync(connstr_receivers): assert len(received) == 1 +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): +def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: event_data_batch = client.create_batch(max_size_in_bytes=100000) while True: @@ -326,10 +349,12 @@ def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): assert EventData._from_message(received[0]).properties[b"raw_prop"] == b"raw_value" +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_list(connstr_receivers): +def test_send_list(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)]) @@ -341,10 +366,12 @@ def test_send_list(connstr_receivers): assert received[0].body_as_str() == payload +@pytest.mark.parametrize("uamqp_transport", + [True, False]) @pytest.mark.liveTest -def test_send_list_partition(connstr_receivers): +def test_send_list_partition(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)], partition_id="0") @@ -353,22 +380,30 @@ def test_send_list_partition(connstr_receivers): assert received.body_as_str() == payload -@pytest.mark.parametrize("to_send, exception_type", - [([EventData("A"*1024)]*1100, ValueError), - ("any str", AttributeError) +@pytest.mark.parametrize("to_send, exception_type, uamqp_transport", + [([EventData("A"*1024)]*1100, ValueError, True), + ("any str", AttributeError, True), + ([EventData("A"*1024)]*1100, ValueError, False), + ("any str", AttributeError, False) ]) @pytest.mark.liveTest -def test_send_list_wrong_data(connection_str, to_send, exception_type): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(exception_type): client.send_batch(to_send) -@pytest.mark.parametrize("partition_id, partition_key, amqp_transport", [("0", None, UamqpTransport()), (None, "pk", UamqpTransport())], ) -def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, amqp_transport): +@pytest.mark.parametrize("partition_id, partition_key, uamqp_transport", [ + ("0", None, True), + (None, "pk", True), + ("0", None, False), + (None, "pk", False)] +) +def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_transport): # Use invalid_hostname because this is not a live test. - client = EventHubProducerClient.from_connection_string(invalid_hostname) + amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key, amqp_transport=amqp_transport) with client: with pytest.raises(TypeError): diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 455ccc7ef16b..747750caa861 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -1,7 +1,8 @@ import platform import pytest -from packaging import version -from azure.eventhub._transport._uamqp_transport import UamqpTransport +import uamqp +from azure.eventhub._transport._uamqp_transport import UamqpTransport +from azure.eventhub._transport._pyamqp_transport import PyamqpTransport from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common from azure.eventhub._pyamqp.message import Message, Properties @@ -109,24 +110,26 @@ def test_sys_properties(): assert ed.system_properties[_common.PROP_REPLY_TO_GROUP_ID] == properties.reply_to_group_id -def test_event_data_batch(): - batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=UamqpTransport()) +# TODO: see why pyamqp went from 99 to 87 +@pytest.mark.parametrize("amqp_transport, expected_result", + [(UamqpTransport(), 101), (PyamqpTransport(), 87)]) +def test_event_data_batch(amqp_transport, expected_result): + batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=amqp_transport) batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" assert repr(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" # TODO: uamqp uses 93 bytes for encode, while python amqp uses 99 bytes # we should understand why extra bytes are needed to encode the content and how it could be improved - assert batch.size_in_bytes == 99 and len(batch) == 1 + assert batch.size_in_bytes == expected_result and len(batch) == 1 with pytest.raises(ValueError): batch.add(EventData("A")) -# TODO: fix and add uamqp -def test_event_data_from_message(): - #message = uamqp.message.Message('A') - message = Message(data=b'A') +@pytest.mark.parametrize("message, expected_result", + [(uamqp.Message('A'), [b'A']), (Message(data=b'A'), [65])]) +def test_event_data_from_message(message, expected_result): event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None @@ -138,7 +141,7 @@ def test_event_data_from_message(): assert event.content_type == 'content_type' assert event.correlation_id == 'correlation_id' assert event.message_id == 'message_id' - assert event.body == b'A' + assert list(event.body) == expected_result def test_amqp_message_str_repr(): From 0e6272c2fcd781c1bebae21859df7c5a7f64377b Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 25 Jul 2022 18:49:05 -0700 Subject: [PATCH 17/21] fix bug with pyamqp transport sending --- .../eventhub/_transport/_pyamqp_transport.py | 28 ++++++++++--------- .../livetest/synctests/test_reconnect.py | 18 +++++++----- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index c0a10f97abfd..334a8ed27869 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------------------------- import logging +import time from typing import TYPE_CHECKING, Optional, Union, Any from .._pyamqp import ( @@ -218,19 +219,20 @@ def send_messages(self, producer, timeout_time, last_exception, logger): :param logger: Logger. """ # pylint: disable=protected-access - # TODO: see if this works too - # timeout = timeout_time - time.time() if timeout_time else 0 - # producer._handler.send_message(producer._unsent_events[0], timeout=timeout) - # producer._unsent_events = None - try: - producer._open() - producer._handler.send_message( - producer._unsent_events[0], timeout=timeout_time - ) - except self.TIMEOUT_EXCEPTION as exc: - raise OperationTimeoutError(message=str(exc), details=exc) - except Exception as exc: - raise producer._handle_exception(exc) + producer._open() + timeout = timeout_time - time.time() if timeout_time else 0 + producer._handler.send_message(producer._unsent_events[0], timeout=timeout) + producer._unsent_events = None + # TODO: figure out if we want to use below, and see if it affects error story + #try: + # producer._open() + # producer._handler.send_message( + # producer._unsent_events[0], timeout=timeout_time + # ) + #except self.TIMEOUT_EXCEPTION as exc: + # raise OperationTimeoutError(message=str(exc), details=exc) + #except Exception as exc: + # raise producer._handle_exception(exc) def set_message_partition_key( self, message, partition_key, **kwargs diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index 68671b51efea..b86b5cc0d0a2 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -12,14 +12,13 @@ from azure.eventhub._pyamqp import error, constants import uamqp -from uamqp import errors, compat +from uamqp import compat from azure.eventhub import ( EventData, EventHubSharedKeyCredential, EventHubProducerClient, EventHubConsumerClient, - amqp ) from azure.eventhub.exceptions import OperationTimeoutError from azure.eventhub._utils import transform_outbound_single_message @@ -75,14 +74,19 @@ def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): # TODO: fix and add pyamqp transport @pytest.mark.parametrize("uamqp_transport", - [True]) + [True, False]) @pytest.mark.liveTest def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() + if uamqp_transport: + amqp_transport = UamqpTransport() + retry_total = 3 + else: + amqp_transport = PyamqpTransport() + retry_total = 0 # no retry, should just raise error client = EventHubProducerClient.from_connection_string( - conn_str=connection_str, idle_timeout=10, retry_total=0, uamqp_transport=uamqp_transport + conn_str=connection_str, idle_timeout=10, retry_total=retry_total, uamqp_transport=uamqp_transport ) with client: ed = EventData('data') @@ -90,8 +94,8 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamq with sender: sender._open_with_retry() time.sleep(11) - wrapped_ed = sender._wrap_eventdata(ed, None, None) - sender._unsent_events = [wrapped_ed.message] + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed.message] if uamqp_transport: sender._unsent_events[0].on_send_complete = sender._on_outcome with pytest.raises((uamqp.errors.ConnectionClose, From dbd304400b6a8aa99e89614b1eaa02728374f2bc Mon Sep 17 00:00:00 2001 From: swathipil Date: Mon, 25 Jul 2022 19:16:35 -0700 Subject: [PATCH 18/21] adding things back for async unittests --- .../azure-eventhub/azure/eventhub/_utils.py | 26 +++++++++++++++++++ .../azure/eventhub/exceptions.py | 26 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 185eb2b81c35..49a6110db3cc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -27,8 +27,14 @@ PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, PROP_LAST_ENQUEUED_OFFSET, PROP_TIMESTAMP, + PROP_PARTITION_KEY ) +from uamqp import types +from uamqp.message import MessageHeader + +PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) + if TYPE_CHECKING: # pylint: disable=ungrouped-imports from ._transport._base import AmqpTransport @@ -117,6 +123,26 @@ def send_context_manager(): else: yield None +# TODO: delete after async unit tests have been refactored +def set_message_partition_key(message, partition_key): + # type: (Message, Optional[Union[bytes, str]]) -> None + """Set the partition key as an annotation on a uamqp message. + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = dict() + annotations[ + PROP_PARTITION_KEY_AMQP_SYMBOL + ] = partition_key # pylint:disable=protected-access + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + def trace_message(event, parent_span=None): # type: (EventData, Optional[AbstractSpan]) -> None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index f686251e6e95..1dbcedeab559 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import six +from uamqp import errors, compat class EventHubError(Exception): """Represents an error occurred in the client. @@ -93,3 +94,28 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" + +# TODO: delete when async unittests have been refactored +def _create_eventhub_exception(exception): + if isinstance(exception, errors.AuthenticationException): + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.VendorLinkDetach): + error = ConnectError(str(exception), exception) + elif isinstance(exception, errors.LinkDetach): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.ConnectionClose): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.MessageHandlerError): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error From 1f5cb736fc58cee1d53d675990553f2bac20717e Mon Sep 17 00:00:00 2001 From: swathipil Date: Tue, 26 Jul 2022 11:49:47 -0700 Subject: [PATCH 19/21] run tests w/ + w/o uamqp installed --- .../azure/eventhub/_client_base.py | 7 +- .../eventhub/_transport/_uamqp_transport.py | 983 +++++++++--------- .../azure-eventhub/azure/eventhub/_utils.py | 11 +- .../azure/eventhub/exceptions.py | 6 +- sdk/eventhub/azure-eventhub/tests/__init__.py | 4 + .../azure-eventhub/tests/_test_case.py | 11 + .../azure-eventhub/tests/livetest/__init__.py | 4 + .../tests/livetest/synctests/__init__.py | 4 + .../tests/livetest/synctests/test_auth.py | 10 +- .../synctests/test_consumer_client.py | 25 +- .../tests/livetest/synctests/test_negative.py | 24 +- .../livetest/synctests/test_properties.py | 13 +- .../tests/livetest/synctests/test_receive.py | 30 +- .../livetest/synctests/test_reconnect.py | 19 +- .../tests/livetest/synctests/test_send.py | 51 +- .../azure-eventhub/tests/unittest/__init__.py | 4 + .../tests/unittest/test_event_data.py | 114 +- 17 files changed, 703 insertions(+), 617 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/tests/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/tests/_test_case.py create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py create mode 100644 sdk/eventhub/azure-eventhub/tests/unittest/__init__.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 4134e5be7159..c939d6f2ff6a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -23,7 +23,10 @@ from azure.core.pipeline.policies import RetryMode -from ._transport._uamqp_transport import UamqpTransport +try: + from ._transport._uamqp_transport import UamqpTransport +except ImportError: + UamqpTransport = None from ._transport._pyamqp_transport import PyamqpTransport from .exceptions import ClientClosedError from ._configuration import Configuration @@ -294,7 +297,7 @@ def __init__( else: try: self._amqp_transport = UamqpTransport() - except ImportError: + except TypeError: raise ImportError("uamqp package is not installed") self.eventhub_name = eventhub_name diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index a9f3b2172f2f..f8a36f536806 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -7,25 +7,29 @@ import logging from typing import TYPE_CHECKING, Optional, Union, Any -from uamqp import ( - BatchMessage, - constants, - MessageBodyType, - Message, - types, - SendClient, - ReceiveClient, - Source, - utils, - authentication, - AMQPClient, - compat, - errors, -) -from uamqp.message import ( - MessageHeader, - MessageProperties, -) +try: + from uamqp import ( + BatchMessage, + constants, + MessageBodyType, + Message, + types, + SendClient, + ReceiveClient, + Source, + utils, + authentication, + AMQPClient, + compat, + errors, + ) + from uamqp.message import ( + MessageHeader, + MessageProperties, + ) + uamqp_installed = True +except ImportError: + uamqp_installed = False from ._base import AmqpTransport from ..amqp._constants import AmqpMessageBodyType @@ -48,487 +52,488 @@ _LOGGER = logging.getLogger(__name__) -def _error_handler(error): - """ - Called internally when an event has failed to send so we - can parse the error to determine whether we should attempt - to retry sending the event again. - Returns the action to take according to error type. - - :param error: The error received in the send attempt. - :type error: Exception - :rtype: ~uamqp.errors.ErrorAction - """ - if error.condition == b"com.microsoft:server-busy": - return errors.ErrorAction(retry=True, backoff=4) - if error.condition == b"com.microsoft:timeout": - return errors.ErrorAction(retry=True, backoff=2) - if error.condition == b"com.microsoft:operation-cancelled": +if uamqp_installed: + def _error_handler(error): + """ + Called internally when an event has failed to send so we + can parse the error to determine whether we should attempt + to retry sending the event again. + Returns the action to take according to error type. + + :param error: The error received in the send attempt. + :type error: Exception + :rtype: ~uamqp.errors.ErrorAction + """ + if error.condition == b"com.microsoft:server-busy": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition == b"com.microsoft:timeout": + return errors.ErrorAction(retry=True, backoff=2) + if error.condition == b"com.microsoft:operation-cancelled": + return errors.ErrorAction(retry=True) + if error.condition == b"com.microsoft:container-close": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition in NO_RETRY_ERRORS: + return errors.ErrorAction(retry=False) return errors.ErrorAction(retry=True) - if error.condition == b"com.microsoft:container-close": - return errors.ErrorAction(retry=True, backoff=4) - if error.condition in NO_RETRY_ERRORS: - return errors.ErrorAction(retry=False) - return errors.ErrorAction(retry=True) - - -class UamqpTransport(AmqpTransport): - """ - Class which defines uamqp-based methods used by the producer and consumer. - """ - # define constants - BATCH_MESSAGE = BatchMessage - MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES - IDLE_TIMEOUT_FACTOR = 1000 - MESSAGE = Message - - # define symbols - PRODUCT_SYMBOL = types.AMQPSymbol("product") - VERSION_SYMBOL = types.AMQPSymbol("version") - FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") - PLATFORM_SYMBOL = types.AMQPSymbol("platform") - USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") - PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) - - # define errors and conditions - AMQP_LINK_ERROR = errors.LinkDetach - LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen - AUTH_EXCEPTION = errors.AuthenticationException - CONNECTION_ERROR = ConnectError - AMQP_CONNECTION_ERROR = errors.AMQPConnectionError - TIMEOUT_EXCEPTION = compat.TimeoutException - - def to_outgoing_amqp_message(self, annotated_message): - """ - Converts an AmqpAnnotatedMessage into an Amqp Message. - :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. - :rtype: uamqp.Message - """ - message_header = None - if annotated_message.header: - message_header = MessageHeader() - message_header.delivery_count = annotated_message.header.delivery_count - message_header.time_to_live = annotated_message.header.time_to_live - message_header.first_acquirer = annotated_message.header.first_acquirer - message_header.durable = annotated_message.header.durable - message_header.priority = annotated_message.header.priority - - message_properties = None - if annotated_message.properties: - message_properties = MessageProperties( - message_id=annotated_message.properties.message_id, - user_id=annotated_message.properties.user_id, - to=annotated_message.properties.to, - subject=annotated_message.properties.subject, - reply_to=annotated_message.properties.reply_to, - correlation_id=annotated_message.properties.correlation_id, - content_type=annotated_message.properties.content_type, - content_encoding=annotated_message.properties.content_encoding, - creation_time=int(annotated_message.properties.creation_time) - if annotated_message.properties.creation_time else None, - absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) - if annotated_message.properties.absolute_expiry_time else None, - group_id=annotated_message.properties.group_id, - group_sequence=annotated_message.properties.group_sequence, - reply_to_group_id=annotated_message.properties.reply_to_group_id, - encoding=annotated_message._encoding # pylint: disable=protected-access - ) - - amqp_body_type = annotated_message.body_type # pylint: disable=protected-access - amqp_body = annotated_message.body - if amqp_body_type == AmqpMessageBodyType.DATA: - amqp_body_type = MessageBodyType.Data - amqp_body = list(amqp_body) - elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: - amqp_body_type = MessageBodyType.Sequence - amqp_body = list(amqp_body) - else: - # amqp_body_type is type of AmqpMessageBodyType.VALUE - amqp_body_type = MessageBodyType.Value - - return Message( - body=amqp_body, - body_type=amqp_body_type, - header=message_header, - properties=message_properties, - application_properties=annotated_message.application_properties, - annotations=annotated_message.annotations, - delivery_annotations=annotated_message.delivery_annotations, - footer=annotated_message.footer - ) - - def get_batch_message_encoded_size(self, message): - """ - Gets the batch message encoded size given an underlying Message. - :param uamqp.BatchMessage message: Message to get encoded size of. - :rtype: int - """ - return message.gather()[0].get_message_encoded_size() - - def get_message_encoded_size(self, message): - """ - Gets the message encoded size given an underlying Message. - :param uamqp.Message message: Message to get encoded size of. - :rtype: int - """ - return message.get_message_encoded_size() - - def get_remote_max_message_size(self, handler): - """ - Returns max peer message size. - :param AMQPClient handler: Client to get remote max message size on link from. - :rtype: int - """ - return handler.message_handler._link.peer_max_message_size # pylint:disable=protected-access - def create_retry_policy(self, config): - """ - Creates the error retry policy. - :param ~azure.eventhub._configuration.Configuration config: Configuration. - """ - return errors.ErrorPolicy(max_retries=config.max_retries, on_error=_error_handler) - def create_link_properties(self, link_properties): - """ - Creates and returns the link properties. - :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. - :rtype: dict - """ - return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} - - def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argument - """ - Creates and returns the uamqp SendClient. - :param ~azure.eventhub._configuration.Configuration config: The configuration. - - :keyword str target: Required. The target. - :keyword JWTTokenAuth auth: Required. - :keyword int idle_timeout: Required. - :keyword network_trace: Required. - :keyword retry_policy: Required. - :keyword keep_alive_interval: Required. - :keyword str client_name: Required. - :keyword dict link_properties: Required. - :keyword properties: Required. - """ - target = kwargs.pop("target") - retry_policy = kwargs.pop("retry_policy") - network_trace = kwargs.pop("network_trace") - - return SendClient( - target, - debug=network_trace, # pylint:disable=protected-access - error_policy=retry_policy, - **kwargs - ) - - def _set_msg_timeout(self, producer, timeout_time, last_exception, logger): - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception + class UamqpTransport(AmqpTransport): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + # define constants + BATCH_MESSAGE = BatchMessage + MAX_FRAME_SIZE_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + IDLE_TIMEOUT_FACTOR = 1000 + MESSAGE = Message + + # define symbols + PRODUCT_SYMBOL = types.AMQPSymbol("product") + VERSION_SYMBOL = types.AMQPSymbol("version") + FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") + PLATFORM_SYMBOL = types.AMQPSymbol("platform") + USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) + + # define errors and conditions + AMQP_LINK_ERROR = errors.LinkDetach + LINK_STOLEN_CONDITION = constants.ErrorCodes.LinkStolen + AUTH_EXCEPTION = errors.AuthenticationException + CONNECTION_ERROR = ConnectError + AMQP_CONNECTION_ERROR = errors.AMQPConnectionError + TIMEOUT_EXCEPTION = compat.TimeoutException + + def to_outgoing_amqp_message(self, annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message + """ + message_header = None + if annotated_message.header: + message_header = MessageHeader() + message_header.delivery_count = annotated_message.header.delivery_count + message_header.time_to_live = annotated_message.header.time_to_live + message_header.first_acquirer = annotated_message.header.first_acquirer + message_header.durable = annotated_message.header.durable + message_header.priority = annotated_message.header.priority + + message_properties = None + if annotated_message.properties: + message_properties = MessageProperties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time else None, + absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) + if annotated_message.properties.absolute_expiry_time else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + encoding=annotated_message._encoding # pylint: disable=protected-access + ) + + amqp_body_type = annotated_message.body_type # pylint: disable=protected-access + amqp_body = annotated_message.body + if amqp_body_type == AmqpMessageBodyType.DATA: + amqp_body_type = MessageBodyType.Data + amqp_body = list(amqp_body) + elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: + amqp_body_type = MessageBodyType.Sequence + amqp_body = list(amqp_body) else: - error = OperationTimeoutError("Send operation timed out") - logger.info("%r send operation timed out. (%r)", producer._name, error) # pylint: disable=protected-access - raise error - producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - - def send_messages(self, producer, timeout_time, last_exception, logger): - """ - Handles sending of event data messages. - :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. - :param int timeout_time: Timeout time. - :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. - :param logger: Logger. - """ - # pylint: disable=protected-access - producer._open() - producer._unsent_events[0].on_send_complete = producer._on_outcome - self._set_msg_timeout(producer, timeout_time, last_exception, logger) - producer._handler.queue_message(*producer._unsent_events) # type: ignore - producer._handler.wait() # type: ignore - producer._unsent_events = producer._handler.pending_messages # type: ignore - if producer._outcome != constants.MessageSendResult.Ok: - if producer._outcome == constants.MessageSendResult.Timeout: - producer._condition = OperationTimeoutError("Send operation timed out") - if producer._condition: - raise producer._condition - - # TODO: can delete this method, if data prop is added to uamqp.BatchMessage - #def get_batch_message_data(self, batch_message): - # """ - # Gets the data body of the BatchMessage. - # :param batch_message: BatchMessage to retrieve data body from. - # """ - # return batch_message._body_gen # pylint:disable=protected-access - - def set_message_partition_key(self, message, partition_key, **kwargs): # pylint:disable=unused-argument - # type: (Message, Optional[Union[bytes, str]], Any) -> Message - """Set the partition key as an annotation on a uamqp message. - - :param ~uamqp.Message message: The message to update. - :param str partition_key: The partition key value. - :rtype: Message - """ - if partition_key: - annotations = message.annotations - if annotations is None: - annotations = {} - annotations[ - UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid - ] = partition_key - header = MessageHeader() - header.durable = True - message.annotations = annotations - message.header = header - return message - - def add_batch(self, batch_message, outgoing_event_data, event_data): - """ - Add EventData to the data body of the BatchMessage. - :param batch_message: BatchMessage to add data to. - :param outgoing_event_data: Transformed EventData for sending. - :param event_data: EventData to add to internal batch events. uamqp use only. - :rtype: None - """ - batch_message._internal_events.append(event_data) - batch_message.message._body_gen.append( - outgoing_event_data - ) - - def create_source(self, source, offset, selector): - """ - Creates and returns the Source. - - :param str source: Required. - :param int offset: Required. - :param bytes selector: Required. - """ - source = Source(source) - if offset is not None: - source.set_filter(selector) - return source + # amqp_body_type is type of AmqpMessageBodyType.VALUE + amqp_body_type = MessageBodyType.Value + + return Message( + body=amqp_body, + body_type=amqp_body_type, + header=message_header, + properties=message_properties, + application_properties=annotated_message.application_properties, + annotations=annotated_message.annotations, + delivery_annotations=annotated_message.delivery_annotations, + footer=annotated_message.footer + ) - def create_receive_client(self, *, config, **kwargs): - """ - Creates and returns the receive client. - :param ~azure.eventhub._configuration.Configuration config: The configuration. - - :keyword str source: Required. The source. - :keyword str offset: Required. - :keyword str offset_inclusive: Required. - :keyword JWTTokenAuth auth: Required. - :keyword int idle_timeout: Required. - :keyword network_trace: Required. - :keyword retry_policy: Required. - :keyword str client_name: Required. - :keyword dict link_properties: Required. - :keyword properties: Required. - :keyword link_credit: Required. The prefetch. - :keyword keep_alive_interval: Required. Missing in pyamqp. - :keyword desired_capabilities: Required. - :keyword streaming_receive: Required. - :keyword message_received_callback: Required. - :keyword timeout: Required. - """ + def get_batch_message_encoded_size(self, message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return message.gather()[0].get_message_encoded_size() + + def get_message_encoded_size(self, message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message message: Message to get encoded size of. + :rtype: int + """ + return message.get_message_encoded_size() + + def get_remote_max_message_size(self, handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + return handler.message_handler._link.peer_max_message_size # pylint:disable=protected-access + + def create_retry_policy(self, config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + return errors.ErrorPolicy(max_retries=config.max_retries, on_error=_error_handler) + + def create_link_properties(self, link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} + + def create_send_client(self, *, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClient( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) - source = kwargs.pop("source") - symbol_array = kwargs.pop("desired_capabilities") - desired_capabilities = None - if symbol_array: - symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - retry_policy = kwargs.pop("retry_policy") - network_trace = kwargs.pop("network_trace") - link_credit = kwargs.pop("link_credit") - streaming_receive = kwargs.pop("streaming_receive") - message_received_callback = kwargs.pop("message_received_callback") - - client = ReceiveClient( - source, - debug=network_trace, # pylint:disable=protected-access - error_policy=retry_policy, - desired_capabilities=desired_capabilities, - prefetch=link_credit, - receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - **kwargs - ) - # pylint:disable=protected-access - client._streaming_receive = streaming_receive - client._message_received_callback = (message_received_callback) - return client - - def open_receive_client(self, *, handler, client, auth): - """ - Opens the receive client and returns ready status. - :param ReceiveClient handler: The receive client. - :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. - :param auth: Auth. - :rtype: bool - """ - # pylint:disable=protected-access - handler.open(connection=client._conn_manager.get_connection( - client._address.hostname, auth - )) + def _set_msg_timeout(self, producer, timeout_time, last_exception, logger): + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + logger.info("%r send operation timed out. (%r)", producer._name, error) # pylint: disable=protected-access + raise error + producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access + + def send_messages(self, producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + self._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + producer._handler.wait() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + # TODO: can delete this method, if data prop is added to uamqp.BatchMessage + #def get_batch_message_data(self, batch_message): + # """ + # Gets the data body of the BatchMessage. + # :param batch_message: BatchMessage to retrieve data body from. + # """ + # return batch_message._body_gen # pylint:disable=protected-access + + def set_message_partition_key(self, message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + return message + + def add_batch(self, batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + batch_message._internal_events.append(event_data) + batch_message.message._body_gen.append( + outgoing_event_data + ) - def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): - """ - Creates the JWTTokenAuth. - :param str auth_uri: The auth uri to pass to JWTTokenAuth. - :param get_token: The callback function used for getting and refreshing - tokens. It should return a valid jwt token each time it is called. - :param bytes token_type: Token type. - :param ~azure.eventhub._configuration.Configuration config: EH config. - - :keyword bool update_token: Required. Whether to update token. If not updating token, - then pass 300 to refresh_window. - """ - update_token = kwargs.pop("update_token") - refresh_window = 300 - if update_token: - refresh_window = 0 - - token_auth = authentication.JWTTokenAuth( - auth_uri, - auth_uri, - get_token, - token_type=token_type, - timeout=config.auth_timeout, - http_proxy=config.http_proxy, - transport_type=config.transport_type, - custom_endpoint_hostname=config.custom_endpoint_hostname, - port=config.connection_port, - verify=config.connection_verify, - refresh_window=refresh_window - ) - if update_token: - token_auth.update_token() # TODO: why don't we need to update in pyamqp? - return token_auth - - def create_mgmt_client(self, address, mgmt_auth, config): - """ - Creates and returns the mgmt AMQP client. - :param _Address address: Required. The Address. - :param JWTTokenAuth mgmt_auth: Auth for client. - :param ~azure.eventhub._configuration.Configuration config: The configuration. - """ + def create_source(self, source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + source = Source(source) + if offset is not None: + source.set_filter(selector) + return source + + def create_receive_client(self, *, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClient( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + def open_receive_client(self, *, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + handler.open(connection=client._conn_manager.get_connection( + client._address.hostname, auth + )) + + def create_token_auth(self, auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + token_auth.update_token() # TODO: why don't we need to update in pyamqp? + return token_auth + + def create_mgmt_client(self, address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClient( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) - mgmt_target = f"amqps://{address.hostname}{address.path}" - return AMQPClient( - mgmt_target, - auth=mgmt_auth, - debug=config.network_tracing - ) + def get_updated_token(self, mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return mgmt_client.mgmt_request( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) - def get_updated_token(self, mgmt_auth): - """ - Return updated auth token. - :param mgmt_auth: Auth. - """ - return mgmt_auth.token + def get_error(self, error, message, *, condition=None): # pylint: disable=unused-argument + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ + return error(message) - def mgmt_client_request(self, mgmt_client, mgmt_msg, **kwargs): - """ - Send mgmt request. - :param AMQP Client mgmt_client: Client to send request with. - :param str mgmt_msg: Message. - :keyword bytes operation: Operation. - :keyword operation_type: Op type. - :keyword status_code_field: mgmt status code. - :keyword description_fields: mgmt status desc. - """ - operation_type = kwargs.pop("operation_type") - operation = kwargs.pop("operation") - return mgmt_client.mgmt_request( - mgmt_msg, - operation, - op_type=operation_type, - **kwargs - ) - - def get_error(self, error, message, *, condition=None): # pylint: disable=unused-argument - """ - Gets error and passes in error message, and, if applicable, condition. - :param error: The error to raise. - :param str message: Error message. - :param condition: Optional error condition. Will not be used by uamqp. - """ - return error(message) - - def _create_eventhub_exception(self, exception): - if isinstance(exception, errors.AuthenticationException): - error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.VendorLinkDetach): - error = ConnectError(str(exception), exception) - elif isinstance(exception, errors.LinkDetach): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.ConnectionClose): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.MessageHandlerError): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.AMQPConnectionError): - error_type = ( - AuthenticationError - if str(exception).startswith("Unable to open authentication session") - else ConnectError - ) - error = error_type(str(exception), exception) - elif isinstance(exception, compat.TimeoutException): - error = ConnectionLostError(str(exception), exception) - else: - error = EventHubError(str(exception), exception) - return error - - - def _handle_exception( - self, exception, closable - ): # pylint:disable=too-many-branches, too-many-statements - try: # closable is a producer/consumer object - name = closable._name # pylint: disable=protected-access - except AttributeError: # closable is an client object - name = closable._container_id # pylint: disable=protected-access - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - closable._close_connection() # pylint:disable=protected-access - raise exception - elif isinstance(exception, EventHubError): - closable._close_handler() # pylint:disable=protected-access - raise exception - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: + def _create_eventhub_exception(self, exception): if isinstance(exception, errors.AuthenticationException): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.VendorLinkDetach): + error = ConnectError(str(exception), exception) elif isinstance(exception, errors.LinkDetach): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access + error = ConnectionLostError(str(exception), exception) elif isinstance(exception, errors.ConnectionClose): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access + error = ConnectionLostError(str(exception), exception) elif isinstance(exception, errors.MessageHandlerError): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - else: # errors.AMQPConnectionError, compat.TimeoutException - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - return self._create_eventhub_exception(exception) + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error + + + def _handle_exception( + self, exception, closable + ): # pylint:disable=too-many-branches, too-many-statements + try: # closable is a producer/consumer object + name = closable._name # pylint: disable=protected-access + except AttributeError: # closable is an client object + name = closable._container_id # pylint: disable=protected-access + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + closable._close_connection() # pylint:disable=protected-access + raise exception + elif isinstance(exception, EventHubError): + closable._close_handler() # pylint:disable=protected-access + raise exception + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + if isinstance(exception, errors.AuthenticationException): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.LinkDetach): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + elif isinstance(exception, errors.ConnectionClose): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.MessageHandlerError): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + else: # errors.AMQPConnectionError, compat.TimeoutException + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + return self._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 49a6110db3cc..90f6b4470feb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -30,10 +30,15 @@ PROP_PARTITION_KEY ) -from uamqp import types -from uamqp.message import MessageHeader +try: + from uamqp import types + from uamqp.message import MessageHeader + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) +except (ImportError, ModuleNotFoundError): + types = None + MessageHeader = None + PROP_PARTITION_KEY_AMQP_SYMBOL = None -PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) if TYPE_CHECKING: # pylint: disable=ungrouped-imports diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 1dbcedeab559..8001d97cea6d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -3,7 +3,11 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import six -from uamqp import errors, compat +try: + from uamqp import errors, compat +except ImportError: + errors = None + compat = None class EventHubError(Exception): """Represents an error occurred in the client. diff --git a/sdk/eventhub/azure-eventhub/tests/__init__.py b/sdk/eventhub/azure-eventhub/tests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/_test_case.py b/sdk/eventhub/azure-eventhub/tests/_test_case.py new file mode 100644 index 000000000000..2f77bbf23a0c --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/_test_case.py @@ -0,0 +1,11 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ + +def get_decorator(): + try: + import uamqp + except (ImportError, ModuleNotFoundError): + return [False] + return [True, False] diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index f00d034a02ed..e8d9c4dd8341 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -11,10 +11,12 @@ from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential from azure.eventhub._client_base import EventHubSASTokenCredential from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_client_secret_credential(live_eventhub, uamqp_transport): credential = EnvironmentCredential() @@ -56,7 +58,7 @@ def on_event(partition_context, event): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_client_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. @@ -97,7 +99,7 @@ def test_client_sas_credential(live_eventhub, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_client_azure_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. @@ -126,7 +128,7 @@ def test_client_azure_sas_credential(live_eventhub, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential = AzureNamedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 062edea4a4e9..e9320c83aed7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -6,10 +6,13 @@ from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.in_memory_checkpoint_store import InMemoryCheckpointStore from azure.eventhub._constants import ALL_PARTITIONS +from ..._test_case import get_decorator + +uamqp_transport_vals = get_decorator() @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_no_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders @@ -53,7 +56,7 @@ def on_event(partition_context, event): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders @@ -85,7 +88,7 @@ def on_event(partition_context, event): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_load_balancing(connstr_senders, uamqp_transport): if sys.platform.startswith('darwin'): @@ -121,7 +124,7 @@ def on_event(partition_context, event): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) def test_receive_batch_no_max_wait_time(connstr_senders, uamqp_transport): '''Test whether callback is called when max_wait_time is None and max_batch_size has reached ''' @@ -165,13 +168,11 @@ def on_event_batch(partition_context, event_batch): worker.join() -@pytest.mark.parametrize("max_wait_time, sleep_time, expected_result, uamqp_transport", - [(3, 10, [], True), - (3, 2, None, True), - (3, 10, [], False), - (3, 2, None, False), - ]) -def test_receive_batch_empty_with_max_wait_time(connection_str, max_wait_time, sleep_time, expected_result, uamqp_transport): +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +@pytest.mark.parametrize("max_wait_time, sleep_time, expected_result", + [(3, 10, []), + (3, 2, None)]) +def test_receive_batch_empty_with_max_wait_time(uamqp_transport, connection_str, max_wait_time, sleep_time, expected_result): '''Test whether event handler is called when max_wait_time > 0 and no event is received ''' client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) @@ -190,7 +191,7 @@ def on_event_batch(partition_context, event_batch): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) def test_receive_batch_early_callback(connstr_senders, uamqp_transport): ''' Test whether the callback is called once max_batch_size reaches and before max_wait_time reaches. ''' diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index 303df40070e8..c443a9e4986a 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -19,11 +19,17 @@ ) from azure.eventhub import EventHubConsumerClient from azure.eventhub import EventHubProducerClient -from azure.eventhub._transport._uamqp_transport import UamqpTransport +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() -@pytest.mark.parametrize("uamqp_transport", [True, False]) + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() @@ -38,7 +44,7 @@ def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): client.send_batch(batch) -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_receive_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): def on_event(partition_context, event): @@ -56,7 +62,7 @@ def on_event(partition_context, event): thread.join() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): client = EventHubProducerClient.from_connection_string(invalid_key, uamqp_transport=uamqp_transport) @@ -70,7 +76,7 @@ def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): partitions = ["XYZ", "-1", "1000", "-"] @@ -85,7 +91,7 @@ def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_too_large_message(connection_str, uamqp_transport): if sys.platform.startswith('darwin'): @@ -100,7 +106,7 @@ def test_send_batch_too_large_message(connection_str, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_send_batch_null_body(connection_str, uamqp_transport): client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) @@ -114,7 +120,7 @@ def test_send_batch_null_body(connection_str, uamqp_transport): client.close() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): if sys.platform.startswith('darwin'): @@ -126,7 +132,7 @@ def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transpo client.create_batch(max_size_in_bytes=300) -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_create_batch_with_too_large_size_sync(connection_str, uamqp_transport): client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py index b9349f60b250..678fccabc106 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py @@ -9,9 +9,12 @@ from azure.eventhub import EventHubSharedKeyCredential from azure.eventhub import EventHubConsumerClient from azure.eventhub.exceptions import AuthenticationError, ConnectError, EventHubError +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() -@pytest.mark.parametrize("uamqp_transport", [True, False]) + +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', @@ -22,7 +25,7 @@ def test_get_properties(live_eventhub, uamqp_transport): properties = client.get_eventhub_properties() assert properties['eventhub_name'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', @@ -40,7 +43,7 @@ def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], "invalid", '$default', @@ -59,7 +62,7 @@ def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): with pytest.raises(EventHubError) as e: # This can be either ConnectError or ConnectionLostError client.get_eventhub_properties() -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_partition_ids(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', @@ -71,7 +74,7 @@ def test_get_partition_ids(live_eventhub, uamqp_transport): assert partition_ids == ['0', '1'] -@pytest.mark.parametrize("uamqp_transport", [True, False]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest def test_get_partition_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 1a58255f90d8..7b38be683088 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -12,10 +12,13 @@ from azure.eventhub import EventData, TransportType, EventHubConsumerClient from azure.eventhub.exceptions import EventHubError +from ..._test_case import get_decorator + +uamqp_transport_vals = get_decorator() @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_end_of_stream(connstr_senders, uamqp_transport): def on_event(partition_context, event): @@ -47,20 +50,15 @@ def on_event(partition_context, event): thread.join() -@pytest.mark.parametrize("position, inclusive, expected_result, uamqp_transport", - [("offset", False, "Exclusive", True), - ("offset", True, "Inclusive", True), - ("sequence", False, "Exclusive", True), - ("sequence", True, "Inclusive", True), - ("enqueued_time", False, "Exclusive", True), - ("offset", False, "Exclusive", False), - ("offset", True, "Inclusive", False), - ("sequence", False, "Exclusive", False), - ("sequence", True, "Inclusive", False), - ("enqueued_time", False, "Exclusive", False) - ]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +@pytest.mark.parametrize("position, inclusive, expected_result", + [("offset", False, "Exclusive"), + ("offset", True, "Inclusive"), + ("sequence", False, "Exclusive"), + ("sequence", True, "Inclusive"), + ("enqueued_time", False, "Exclusive")]) @pytest.mark.liveTest -def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result, uamqp_transport): +def test_receive_with_event_position_sync(uamqp_transport, connstr_senders, position, inclusive, expected_result): def on_event(partition_context, event): assert partition_context.last_enqueued_event_properties.get('sequence_number') == event.sequence_number assert partition_context.last_enqueued_event_properties.get('offset') == event.offset @@ -110,7 +108,7 @@ def on_event(partition_context, event): thread.join() @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_owner_level(connstr_senders, uamqp_transport): def on_event(partition_context, event): @@ -144,7 +142,7 @@ def on_error(partition_context, error): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_over_websocket_sync(connstr_senders, uamqp_transport): app_prop = {"raw_prop": "raw_value"} diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index b86b5cc0d0a2..a5be04ad6d91 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -11,9 +11,6 @@ from azure.eventhub._pyamqp.client import ReceiveClient from azure.eventhub._pyamqp import error, constants -import uamqp -from uamqp import compat - from azure.eventhub import ( EventData, EventHubSharedKeyCredential, @@ -22,12 +19,19 @@ ) from azure.eventhub.exceptions import OperationTimeoutError from azure.eventhub._utils import transform_outbound_single_message -from azure.eventhub._transport._uamqp_transport import UamqpTransport +try: + import uamqp + from uamqp import compat + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from ..._test_case import get_decorator +uamqp_transport_vals = get_decorator() @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): test_partition = "0" @@ -72,9 +76,8 @@ def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport): assert list(received[0].body)[0] == b"A single event" -# TODO: fix and add pyamqp transport @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -136,7 +139,7 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamq @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index e43e55689d60..b9af8a0073f4 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -20,11 +20,17 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) -from azure.eventhub._transport._uamqp_transport import UamqpTransport +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None from azure.eventhub._transport._pyamqp_transport import PyamqpTransport +from ..._test_case import get_decorator + +uamqp_transport_vals = get_decorator() @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_with_partition_key(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -54,7 +60,7 @@ def test_send_with_partition_key(connstr_receivers, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport): if sys.platform.startswith('darwin'): @@ -76,7 +82,7 @@ def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -195,8 +201,9 @@ def on_event(partition_context, event): assert received_count["normal_msg"] == 2 -@pytest.mark.parametrize("payload, uamqp_transport", - [(b"", True), (b"", False), (b"A single event", True), (b"A single event", False)]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +@pytest.mark.parametrize("payload", + [(b""), (b"A single event")]) @pytest.mark.liveTest def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport): connection_str, receivers = connstr_receivers @@ -214,7 +221,7 @@ def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_partition(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -250,7 +257,7 @@ def test_send_partition(connstr_receivers, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_non_ascii(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -272,7 +279,7 @@ def test_send_non_ascii(connstr_receivers, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -302,7 +309,7 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transpo @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_over_websocket_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -322,7 +329,7 @@ def test_send_over_websocket_sync(connstr_receivers, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -350,7 +357,7 @@ def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamq @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_list(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -367,7 +374,7 @@ def test_send_list(connstr_receivers, uamqp_transport): @pytest.mark.parametrize("uamqp_transport", - [True, False]) + uamqp_transport_vals) @pytest.mark.liveTest def test_send_list_partition(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers @@ -380,12 +387,10 @@ def test_send_list_partition(connstr_receivers, uamqp_transport): assert received.body_as_str() == payload -@pytest.mark.parametrize("to_send, exception_type, uamqp_transport", - [([EventData("A"*1024)]*1100, ValueError, True), - ("any str", AttributeError, True), - ([EventData("A"*1024)]*1100, ValueError, False), - ("any str", AttributeError, False) - ]) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +@pytest.mark.parametrize("to_send, exception_type", + [([EventData("A"*1024)]*1100, ValueError), + ("any str", AttributeError)]) @pytest.mark.liveTest def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_transport): client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) @@ -394,12 +399,8 @@ def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_tra client.send_batch(to_send) -@pytest.mark.parametrize("partition_id, partition_key, uamqp_transport", [ - ("0", None, True), - (None, "pk", True), - ("0", None, False), - (None, "pk", False)] -) +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +@pytest.mark.parametrize("partition_id, partition_key", [("0", None), (None, "pk")]) def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_transport): # Use invalid_hostname because this is not a live test. amqp_transport = UamqpTransport() if uamqp_transport else PyamqpTransport() diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/unittest/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 747750caa861..b803a0cef781 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -1,11 +1,25 @@ +# -- coding: utf-8 -- +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + import platform import pytest -import uamqp -from azure.eventhub._transport._uamqp_transport import UamqpTransport +try: + import uamqp + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except ImportError: + UamqpTransport = None + pass from azure.eventhub._transport._pyamqp_transport import PyamqpTransport from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common from azure.eventhub._pyamqp.message import Message, Properties +from .._test_case import get_decorator + +uamqp_transport_vals = get_decorator() pytestmark = pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="This is ignored for PyPy") @@ -57,41 +71,44 @@ def test_app_properties(): assert event_data.properties["a"] == "b" -# TODO: fix and add uamqp -def test_sys_properties(): - #properties = uamqp.message.MessageProperties() - #properties.message_id = "message_id" - #properties.user_id = "user_id" - #properties.to = "to" - #properties.subject = "subject" - #properties.reply_to = "reply_to" - #properties.correlation_id = "correlation_id" - #properties.content_type = "content_type" - #properties.content_encoding = "content_encoding" - #properties.absolute_expiry_time = 1 - #properties.creation_time = 1 - #properties.group_id = "group_id" - #properties.group_sequence = 1 - #properties.reply_to_group_id = "reply_to_group_id" - #message = uamqp.message.Message(properties=properties) - #message.annotations = {_common.PROP_OFFSET: "@latest"} - properties = Properties( - message_id="message_id", - user_id="user_id", - to="to", - subject="subject", - reply_to="reply_to", - correlation_id="correlation_id", - content_type="content_type", - content_encoding="content_encoding", - absolute_expiry_time=1, - creation_time=1, - group_id="group_id", - group_sequence=1, - reply_to_group_id="reply_to_group_id" - ) - message_annotations = {_common.PROP_OFFSET: "@latest"} - message = Message(properties=properties, message_annotations=message_annotations) +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_sys_properties(uamqp_transport): + if uamqp_transport: + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + properties = Properties( + message_id="message_id", + user_id="user_id", + to="to", + subject="subject", + reply_to="reply_to", + correlation_id="correlation_id", + content_type="content_type", + content_encoding="content_encoding", + absolute_expiry_time=1, + creation_time=1, + group_id="group_id", + group_sequence=1, + reply_to_group_id="reply_to_group_id" + ) + message_annotations = {_common.PROP_OFFSET: "@latest"} + message = Message(properties=properties, message_annotations=message_annotations) ed = EventData._from_message(message) # type: EventData assert ed.system_properties[_common.PROP_OFFSET] == "@latest" @@ -111,9 +128,15 @@ def test_sys_properties(): # TODO: see why pyamqp went from 99 to 87 -@pytest.mark.parametrize("amqp_transport, expected_result", - [(UamqpTransport(), 101), (PyamqpTransport(), 87)]) -def test_event_data_batch(amqp_transport, expected_result): +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_event_data_batch(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + expected_result = 101 + else: + amqp_transport = PyamqpTransport() + expected_result = 87 batch = EventDataBatch(max_size_in_bytes=110, partition_key="par", amqp_transport=amqp_transport) batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" @@ -127,9 +150,14 @@ def test_event_data_batch(amqp_transport, expected_result): batch.add(EventData("A")) -@pytest.mark.parametrize("message, expected_result", - [(uamqp.Message('A'), [b'A']), (Message(data=b'A'), [65])]) -def test_event_data_from_message(message, expected_result): +@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +def test_event_data_from_message(uamqp_transport): + if uamqp_transport: + message = uamqp.Message('A') + expected_result = [b'A'] + else: + message = Message(data=b'A') + expected_result = [65] event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None From 73c98ee3dc3be82eb55542a6d7f9e9bc0c3f89c2 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 28 Jul 2022 14:08:58 -0700 Subject: [PATCH 20/21] add message backcompat --- .../azure-eventhub/azure/eventhub/_common.py | 130 +++++----- .../azure/eventhub/_producer.py | 8 +- .../eventhub/_pyamqp/_message_backcompat.py | 240 ++++++++++++++++++ .../eventhub/_transport/_pyamqp_transport.py | 14 +- .../eventhub/_transport/_uamqp_transport.py | 14 +- .../azure-eventhub/azure/eventhub/_utils.py | 13 +- .../azure/eventhub/amqp/_amqp_message.py | 171 +++---------- .../livetest/synctests/test_reconnect.py | 4 +- .../tests/unittest/test_event_data.py | 88 ++++++- 9 files changed, 454 insertions(+), 228 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 65eaf5511dfa..8ccb726aa55a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -2,9 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import json +import datetime import logging import uuid from typing import ( @@ -51,10 +52,11 @@ AmqpMessageHeader, AmqpMessageProperties, ) +from ._pyamqp.message import Message +from ._pyamqp._message_backcompat import LegacyMessage, LegacyBatchMessage +from ._transport._pyamqp_transport import PyamqpTransport if TYPE_CHECKING: - from ._pyamqp.message import Message - try: from uamqp import uamqp_Message except ImportError: @@ -122,20 +124,20 @@ def __init__( self._sys_properties = None # type: Optional[Dict[bytes, Any]] if body is None: raise ValueError("EventData cannot be None.") + self._uamqp_message = None # Internal usage only for transforming AmqpAnnotatedMessage to outgoing EventData self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore data_body=body, annotations={}, application_properties={} ) - self.message = None # amqp message to be set right before sending + self._message = None # amqp message to be set right before sending self._raw_amqp_message.header = AmqpMessageHeader() self._raw_amqp_message.properties = AmqpMessageProperties() self.message_id = None self.content_type = None self.correlation_id = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except try: # TODO: below call won't work b/c pyamqp.message.message doesn't have body_type @@ -165,8 +167,7 @@ def __repr__(self): event_repr += ", enqueued_time=" return f"EventData({event_repr})" - def __str__(self): - # type: () -> str + def __str__(self) -> str: try: body_str = self.body_as_str() except: # pylint: disable=bare-except @@ -189,9 +190,10 @@ def __str__(self): @classmethod def _from_message( - cls, message: Union["uamqp_Message", "Message"], raw_amqp_message=None - ): - # type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData + cls, + message: Union["uamqp_Message", Message], + raw_amqp_message: Optional[AmqpAnnotatedMessage] = None, + ) -> EventData: # pylint:disable=protected-access """Internal use only. @@ -202,7 +204,7 @@ def _from_message( :rtype: ~azure.eventhub.EventData """ event_data = cls(body="") - event_data.message = message + event_data._message = message # pylint: disable=protected-access event_data._raw_amqp_message = ( raw_amqp_message @@ -211,8 +213,7 @@ def _from_message( ) return event_data - def _decode_non_data_body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def _decode_non_data_body_as_str(self, encoding: str = "UTF-8") -> str: # pylint: disable=protected-access body = self.raw_amqp_message.body if self.body_type == AmqpMessageBodyType.VALUE: @@ -224,14 +225,21 @@ def _decode_non_data_body_as_str(self, encoding="UTF-8"): return str(decode_with_recurse(seq_list, encoding)) @property - def raw_amqp_message(self): - # type: () -> AmqpAnnotatedMessage + def message(self) -> LegacyMessage: + if not self._uamqp_message: + self._uamqp_message = LegacyMessage( + self._raw_amqp_message, + to_outgoing_amqp_message=PyamqpTransport().to_outgoing_amqp_message, + ) + return self._uamqp_message + + @property + def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" return self._raw_amqp_message @property - def sequence_number(self): - # type: () -> Optional[int] + def sequence_number(self) -> Optional[int]: """The sequence number of the event. :rtype: int @@ -239,8 +247,7 @@ def sequence_number(self): return self._raw_amqp_message.annotations.get(PROP_SEQ_NUMBER, None) @property - def offset(self): - # type: () -> Optional[str] + def offset(self) -> Optional[str]: """The offset of the event. :rtype: str @@ -251,8 +258,7 @@ def offset(self): return None @property - def enqueued_time(self): - # type: () -> Optional[datetime.datetime] + def enqueued_time(self) -> Optional[datetime.datetime]: """The enqueued timestamp of the event. :rtype: datetime.datetime @@ -263,8 +269,7 @@ def enqueued_time(self): return None @property - def partition_key(self): - # type: () -> Optional[bytes] + def partition_key(self) -> Optional[bytes]: """The partition key of the event. :rtype: bytes @@ -277,8 +282,7 @@ def partition_key(self): return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property - def properties(self): - # type: () -> Dict[Union[str, bytes], Any] + def properties(self) -> Dict[Union[str, bytes], Any]: """Application-defined properties on the event. :rtype: dict @@ -286,7 +290,7 @@ def properties(self): return self._raw_amqp_message.application_properties @properties.setter - def properties(self, value): + def properties(self, value: Dict[Union[str, bytes], Any]): # type: (Dict[Union[str, bytes], Any]) -> None """Application-defined properties on the event. @@ -296,8 +300,7 @@ def properties(self, value): self._raw_amqp_message.application_properties = properties @property - def system_properties(self): - # type: () -> Dict[bytes, Any] + def system_properties(self) -> Dict[bytes, Any]: """Metadata set by the Event Hubs Service associated with the event. An EventData could have some or all of the following meta data depending on the source @@ -335,8 +338,7 @@ def system_properties(self): return self._sys_properties @property - def body(self): - # type: () -> PrimitiveTypes + def body(self) -> PrimitiveTypes: """The body of the Message. The format may vary depending on the body type: For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, the body could be bytes or Iterable[bytes]. @@ -353,16 +355,14 @@ def body(self): raise ValueError("Event content empty.") @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ return self._raw_amqp_message.body_type - def body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def body_as_str(self, encoding: str = "UTF-8") -> str: """The content of the event as a string, if the data is of a compatible type. :param encoding: The encoding to use for decoding event data. @@ -383,8 +383,7 @@ def body_as_str(self, encoding="UTF-8"): except Exception as e: raise TypeError(f"Message data is not compatible with string type: {e}") - def body_as_json(self, encoding="UTF-8"): - # type: (str) -> Dict[str, Any] + def body_as_json(self, encoding: str = "UTF-8") -> Dict[str, Any]: """The content of the event loaded as a JSON object, if the data is compatible. :param encoding: The encoding to use for decoding event data. @@ -398,8 +397,7 @@ def body_as_json(self, encoding="UTF-8"): raise TypeError(f"Event data is not compatible with JSON type: {e}") @property - def content_type(self): - # type: () -> Optional[str] + def content_type(self) -> Optional[str]: """The content type descriptor. Optionally describes the payload of the message, with a descriptor following the format of RFC2045, Section 5, for example "application/json". @@ -413,15 +411,13 @@ def content_type(self): return self._raw_amqp_message.properties.content_type @content_type.setter - def content_type(self, value): - # type: (str) -> None + def content_type(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.content_type = value @property - def correlation_id(self): - # type: () -> Optional[str] + def correlation_id(self) -> Optional[str]: """The correlation identifier. Allows an application to specify a context for the message for the purposes of correlation, for example reflecting the MessageId of a message that is being replied to. @@ -435,15 +431,13 @@ def correlation_id(self): return self._raw_amqp_message.properties.correlation_id @correlation_id.setter - def correlation_id(self, value): - # type: (str) -> None + def correlation_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.correlation_id = value @property - def message_id(self): - # type: () -> Optional[str] + def message_id(self) -> Optional[str]: """The id to identify the message. The message identifier is an application-defined value that uniquely identifies the message and its payload. The identifier is a free-form string and can reflect a GUID or an identifier derived from the @@ -459,7 +453,7 @@ def message_id(self): return self._raw_amqp_message.properties.message_id @message_id.setter - def message_id(self, value): + def message_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.message_id = value @@ -512,33 +506,37 @@ def __init__( self.max_size_in_bytes = ( max_size_in_bytes or self._amqp_transport.MAX_FRAME_SIZE_BYTES ) - self.message = self._amqp_transport.BATCH_MESSAGE(data=[]) + self._message = self._amqp_transport.BATCH_MESSAGE(data=[]) self._partition_id = partition_id self._partition_key = partition_key - self.message = self._amqp_transport.set_message_partition_key( - self.message, self._partition_key + self._message = self._amqp_transport.set_message_partition_key( + self._message, self._partition_key ) - self._size = self._amqp_transport.get_batch_message_encoded_size(self.message) + self._size = self._amqp_transport.get_batch_message_encoded_size(self._message) self._count = 0 self._internal_events: List[ Union[EventData, AmqpAnnotatedMessage] ] = [] # TODO: only used by uamqp + self._uamqp_message = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: batch_repr = ( f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, " f"partition_key={self._partition_key!r}, event_count={self._count}" ) return f"EventDataBatch({batch_repr})" - def __len__(self): + def __len__(self) -> int: return self._count @classmethod - def _from_batch(cls, batch_data, amqp_transport, partition_key=None): - # type: (Iterable[EventData], AmqpTransport, Optional[AnyStr]) -> EventDataBatch + def _from_batch( + cls, + batch_data: Iterable[EventData], + amqp_transport: AmqpTransport, + partition_key: Optional[AnyStr] = None, + ) -> EventDataBatch: outgoing_batch_data = [ transform_outbound_single_message( m, EventData, amqp_transport.to_outgoing_amqp_message @@ -565,16 +563,21 @@ def _load_events(self, events): ) @property - def size_in_bytes(self): - # type: () -> int + def size_in_bytes(self) -> int: """The combined size of the events in the batch, in bytes. :rtype: int """ return self._size - def add(self, event_data): - # type: (Union[EventData, AmqpAnnotatedMessage]) -> None + @property + def message(self) -> LegacyBatchMessage: + if not self._uamqp_message: + message = AmqpAnnotatedMessage(message=Message(*self._message)) + self._uamqp_message = LegacyBatchMessage(message) + return self._uamqp_message + + def add(self, event_data: Union[EventData, AmqpAnnotatedMessage]) -> None: """Try to add an EventData to the batch. The total size of an added event is the sum of its body, properties, etc. @@ -601,12 +604,13 @@ def add(self, event_data): ) if not outgoing_event_data.partition_key: self._amqp_transport.set_message_partition_key( - outgoing_event_data.message, self._partition_key + outgoing_event_data._message, # pylint: disable=protected-access + self._partition_key, ) trace_message(outgoing_event_data) event_data_size = self._amqp_transport.get_message_encoded_size( - outgoing_event_data.message + outgoing_event_data._message # pylint: disable=protected-access ) # For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that # message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index fe914c7eca7b..878995065d93 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -53,7 +53,7 @@ def _set_partition_key( amqp_transport: "AmqpTransport", ) -> Iterable[EventData]: for ed in iter(event_datas): - amqp_transport.set_message_partition_key(ed.message, partition_key) + amqp_transport.set_message_partition_key(ed._message, partition_key) # pylint: disable=protected-access yield ed @@ -196,7 +196,7 @@ def _wrap_eventdata( ) if partition_key: self._amqp_transport.set_message_partition_key( - outgoing_event_data.message, partition_key + outgoing_event_data._message, partition_key # pylint: disable=protected-access ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) @@ -214,7 +214,7 @@ def _wrap_eventdata( ) for ( event - ) in event_data.message.data: # pylint: disable=protected-access + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: @@ -264,7 +264,7 @@ def send( wrapper_event_data = self._wrap_eventdata( event_data, child, partition_key ) - self._unsent_events = [wrapper_event_data.message] + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access child diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py new file mode 100644 index 000000000000..7eb1554931a4 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +# pylint: disable=too-many-lines +from typing import Callable +from enum import Enum + +from ._encode import encode_payload +from .utils import get_message_encoded_size +from .error import AMQPError +from .message import Header, Properties +#from uamqp import constants, errors + + +class MessageState(Enum): + WaitingToBeSent = 0 + WaitingForSendAck = 1 + SendComplete = 2 + SendFailed = 3 + ReceivedUnsettled = 4 + ReceivedSettled = 5 + + def __eq__(self, __o: object) -> bool: + try: + return self.value == __o.value + except AttributeError: + return super().__eq__(__o) + + +class MessageAlreadySettled(Exception): + pass + + +DONE_STATES = (MessageState.SendComplete, MessageState.SendFailed) +RECEIVE_STATES = (MessageState.ReceivedSettled, MessageState.ReceivedUnsettled) +PENDING_STATES = (MessageState.WaitingForSendAck, MessageState.WaitingToBeSent) + + +class LegacyMessage(object): + def __init__(self, message, **kwargs): + self._message = message + self.state = MessageState.SendComplete + self.idle_time = 0 + self.retries = 0 + self._settler = kwargs.get('settler') + self._encoding = kwargs.get('encoding') + self.delivery_no = kwargs.get('delivery_no') + self.delivery_tag = kwargs.get('delivery_tag') or None + self.on_send_complete = None + self.properties = LegacyMessageProperties(self._message.properties) + self.application_properties = self._message.application_properties + self.annotations = self._message.annotations + self.header = LegacyMessageHeader(self._message.header) + self.footer = self._message.footer + self.delivery_annotations = self._message.delivery_annotations + if self._settler: + self.state = MessageState.ReceivedUnsettled + elif self.delivery_no: + self.state = MessageState.ReceivedSettled + self._to_outgoing_amqp_message: Callable = kwargs.get('to_outgoing_amqp_message') + + def __str__(self): + return str(self._message) + + def _can_settle_message(self): + if self.state not in RECEIVE_STATES: + raise TypeError("Only received messages can be settled.") + if self.settled: + return False + return True + + @property + def settled(self): + if self.state == MessageState.ReceivedUnsettled: + return False + return True + + def get_message_encoded_size(self): + return get_message_encoded_size(self._to_outgoing_amqp_message(self._message)) + + def encode_message(self): + output = bytearray() + encode_payload(output, self._to_outgoing_amqp_message(self._message)) + return bytes(output) + + def get_data(self): + return self._message.body + + def gather(self): + if self.state in RECEIVE_STATES: + raise TypeError("Only new messages can be gathered.") + if not self._message: + raise ValueError("Message data already consumed.") + if self.state in DONE_STATES: + raise MessageAlreadySettled() + return [self] + + def get_message(self): + return self._to_outgoing_amqp_message(self._message) + + def accept(self): + if self._can_settle_message(): + self._settler.settle_messages(self.delivery_no, 'accepted') + self.state = MessageState.ReceivedSettled + return True + return False + + def reject(self, condition=None, description=None, info=None): + if self._can_settle_message(): + self._settler.settle_messages( + self.delivery_no, + 'rejected', + error=AMQPError( + condition=condition, + description=description, + info=info + ) + ) + self.state = MessageState.ReceivedSettled + return True + return False + + def release(self): + if self._can_settle_message(): + self._settler.settle_messages(self.delivery_no, 'released') + self.state = MessageState.ReceivedSettled + return True + return False + + def modify(self, failed, deliverable, annotations=None): + if self._can_settle_message(): + self._settler.settle_messages( + self.delivery_no, + 'modified', + delivery_failed=failed, + undeliverable_here=deliverable, + message_annotations=annotations, + ) + self.state = MessageState.ReceivedSettled + return True + return False + + +class LegacyBatchMessage(LegacyMessage): + batch_format = 0x80013700 + max_message_length = 1024 * 1024 + size_offset = 0 + + +class LegacyMessageProperties(object): + + def __init__(self, properties): + self.message_id = self._encode_property(properties.message_id) + self.user_id = self._encode_property(properties.user_id) + self.to = self._encode_property(properties.to) + self.subject = self._encode_property(properties.subject) + self.reply_to = self._encode_property(properties.reply_to) + self.correlation_id = self._encode_property(properties.correlation_id) + self.content_type = self._encode_property(properties.content_type) + self.content_encoding = self._encode_property(properties.content_encoding) + self.absolute_expiry_time = properties.absolute_expiry_time + self.creation_time = properties.creation_time + self.group_id = self._encode_property(properties.group_id) + self.group_sequence = properties.group_sequence + self.reply_to_group_id = self._encode_property(properties.reply_to_group_id) + + def __str__(self): + return str( + { + "message_id": self.message_id, + "user_id": self.user_id, + "to": self.to, + "subject": self.subject, + "reply_to": self.reply_to, + "correlation_id": self.correlation_id, + "content_type": self.content_type, + "content_encoding": self.content_encoding, + "absolute_expiry_time": self.absolute_expiry_time, + "creation_time": self.creation_time, + "group_id": self.group_id, + "group_sequence": self.group_sequence, + "reply_to_group_id": self.reply_to_group_id, + } + ) + + def _encode_property(self, value): + try: + return value.encode("UTF-8") + except AttributeError: + return value + + def get_properties_obj(self): + return Properties( + self.message_id, + self.user_id, + self.to, + self.subject, + self.reply_to, + self.correlation_id, + self.content_type, + self.content_encoding, + self.absolute_expiry_time, + self.creation_time, + self.group_id, + self.group_sequence, + self.reply_to_group_id + ) + + +class LegacyMessageHeader(object): + + def __init__(self, header): + self.delivery_count = header.delivery_count # or 0 + self.time_to_live = header.time_to_live + self.first_acquirer = header.first_acquirer + self.durable = header.durable + self.priority = header.priority + + def __str__(self): + return str( + { + "delivery_count": self.delivery_count, + "time_to_live": self.time_to_live, + "first_acquirer": self.first_acquirer, + "durable": self.durable, + "priority": self.priority, + } + ) + + def get_header_obj(self): + return Header( + self.durable, + self.priority, + self.time_to_live, + self.first_acquirer, + self.delivery_count + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index 334a8ed27869..66dce6e383f0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -117,16 +117,12 @@ def to_outgoing_amqp_message(self, annotated_message): "application_properties": annotated_message.application_properties, "message_annotations": annotated_message.annotations, "delivery_annotations": annotated_message.delivery_annotations, - "footer": annotated_message.footer, + "data": annotated_message._data_body, # pylint: disable=protected-access + "sequence": annotated_message._sequence_body, # pylint: disable=protected-access + "value": annotated_message._value_body, # pylint: disable=protected-access + "footer": annotated_message.footer } - if annotated_message.body_type == AmqpMessageBodyType.DATA: - message_dict["data"] = annotated_message.body - elif annotated_message.body_type == AmqpMessageBodyType.SEQUENCE: - message_dict["sequence"] = annotated_message.body - else: - message_dict["value"] = annotated_message.body - return Message(**message_dict) def get_batch_message_encoded_size(self, message): @@ -267,7 +263,7 @@ def add_batch(self, batch_message, outgoing_event_data, event_data): # pylint :param event_data: EventData to add to internal batch events. uamqp use only. :rtype: None """ - utils.add_batch(batch_message.message, outgoing_event_data.message) + utils.add_batch(batch_message._message, outgoing_event_data._message) # pylint: disable=protected-access def create_source(self, source, offset, selector): """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py index f8a36f536806..8dd8120beef5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -139,17 +139,17 @@ def to_outgoing_amqp_message(self, annotated_message): encoding=annotated_message._encoding # pylint: disable=protected-access ) - amqp_body_type = annotated_message.body_type # pylint: disable=protected-access - amqp_body = annotated_message.body + # pylint: disable=protected-access + amqp_body_type = annotated_message.body_type if amqp_body_type == AmqpMessageBodyType.DATA: amqp_body_type = MessageBodyType.Data - amqp_body = list(amqp_body) + amqp_body = list(annotated_message._data_body) elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: amqp_body_type = MessageBodyType.Sequence - amqp_body = list(amqp_body) + amqp_body = list(annotated_message._sequence_body) else: - # amqp_body_type is type of AmqpMessageBodyType.VALUE amqp_body_type = MessageBodyType.Value + amqp_body = annotated_message._value_body return Message( body=amqp_body, @@ -299,8 +299,8 @@ def add_batch(self, batch_message, outgoing_event_data, event_data): :rtype: None """ batch_message._internal_events.append(event_data) - batch_message.message._body_gen.append( - outgoing_event_data + batch_message._message._body_gen.append( # pylint: disable=protected-access + outgoing_event_data._message ) def create_source(self, source, offset, selector): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 90f6b4470feb..024e74042b29 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -232,23 +232,23 @@ def get_last_enqueued_event_properties(event_data): if event_data._last_enqueued_event_properties: return event_data._last_enqueued_event_properties - if event_data.message.delivery_annotations: - sequence_number = event_data.message.delivery_annotations.get( + if event_data._message.delivery_annotations: + sequence_number = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, None ) - enqueued_time_stamp = event_data.message.delivery_annotations.get( + enqueued_time_stamp = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_TIME_UTC, None ) if enqueued_time_stamp: enqueued_time_stamp = utc_from_timestamp(float(enqueued_time_stamp) / 1000) - retrieval_time_stamp = event_data.message.delivery_annotations.get( + retrieval_time_stamp = event_data._message.delivery_annotations.get( PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, None ) if retrieval_time_stamp: retrieval_time_stamp = utc_from_timestamp( float(retrieval_time_stamp) / 1000 ) - offset_bytes = event_data.message.delivery_annotations.get( + offset_bytes = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_OFFSET, None ) offset = offset_bytes.decode("UTF-8") if offset_bytes else None @@ -289,7 +289,8 @@ def transform_outbound_single_message(message, message_type, to_outgoing_amqp_me try: # pylint: disable=protected-access # EventData.message stores uamqp/pyamqp.Message during sending - message.message = to_outgoing_amqp_message(message.raw_amqp_message) + # pylint: disable=protected-access + message._message = to_outgoing_amqp_message(message.raw_amqp_message) return message # type: ignore except AttributeError: # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index acd515a06be2..79eaa16a603e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -4,79 +4,14 @@ # license information. # ------------------------------------------------------------------------- -from typing import Optional, Any, cast, Mapping, Dict -from types import GeneratorType +from __future__ import annotations +from typing import Optional, Any, cast, Mapping, Dict, Union from ._amqp_utils import normalized_data_body, normalized_sequence_body from ._constants import AmqpMessageBodyType from .._mixin import DictMixin -# TODO: remove below if mixin can be imported -#class DictMixin(object): -# def __setitem__(self, key, item): -# # type: (Any, Any) -> None -# self.__dict__[key] = item -# -# def __getitem__(self, key): -# # type: (Any) -> Any -# return self.__dict__[key] -# -# def __repr__(self): -# # type: () -> str -# return str(self) -# -# def __len__(self): -# # type: () -> int -# return len(self.keys()) -# -# def __delitem__(self, key): -# # type: (Any) -> None -# self.__dict__[key] = None -# -# def __eq__(self, other): -# # type: (Any) -> bool -# """Compare objects by comparing all attributes.""" -# if isinstance(other, self.__class__): -# return self.__dict__ == other.__dict__ -# return False -# -# def __ne__(self, other): -# # type: (Any) -> bool -# """Compare objects by comparing all attributes.""" -# return not self.__eq__(other) -# -# def __str__(self): -# # type: () -> str -# return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) -# -# def has_key(self, k): -# # type: (Any) -> bool -# return k in self.__dict__ -# -# def update(self, *args, **kwargs): -# # type: (Any, Any) -> None -# return self.__dict__.update(*args, **kwargs) -# -# def keys(self): -# # type: () -> list -# return [k for k in self.__dict__ if not k.startswith("_")] -# -# def values(self): -# # type: () -> list -# return [v for k, v in self.__dict__.items() if not k.startswith("_")] -# -# def items(self): -# # type: () -> list -# return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] -# -# def get(self, key, default=None): -# # type: (Any, Optional[Any]) -> Any -# if key in self.__dict__: -# return self.__dict__[key] -# return default - - class AmqpAnnotatedMessage(object): # pylint: disable=too-many-instance-attributes """ @@ -111,6 +46,9 @@ class AmqpAnnotatedMessage(object): def __init__(self, **kwargs): # type: (Any) -> None self._encoding = kwargs.pop("encoding", "UTF-8") + self._data_body = None + self._sequence_body = None + self._value_body = None # internal usage only for Event Hub received message message = kwargs.pop("message", None) @@ -126,16 +64,15 @@ def __init__(self, **kwargs): "or value_body being set as the body of the AmqpAnnotatedMessage." ) - self._body = None self._body_type = None if "data_body" in kwargs: - self._body = normalized_data_body(kwargs.get("data_body")) + self._data_body = normalized_data_body(kwargs.get("data_body")) self._body_type = AmqpMessageBodyType.DATA elif "sequence_body" in kwargs: - self._body = normalized_sequence_body(kwargs.get("sequence_body")) + self._sequence_body = normalized_sequence_body(kwargs.get("sequence_body")) self._body_type = AmqpMessageBodyType.SEQUENCE elif "value_body" in kwargs: - self._body = kwargs.get("value_body") + self._value_body = kwargs.get("value_body") self._body_type = AmqpMessageBodyType.VALUE header_dict = cast(Mapping, kwargs.get("header")) @@ -147,34 +84,16 @@ def __init__(self, **kwargs): self._annotations = kwargs.get("annotations") self._delivery_annotations = kwargs.get("delivery_annotations") - def __str__(self): + def __str__(self) -> str: if self._body_type == AmqpMessageBodyType.DATA: - output_str = "" - for data_section in self.body: - try: - output_str += data_section.decode(self._encoding) - except AttributeError: - output_str += str(data_section) - return output_str + return "".join(d.decode(self._encoding) for d in self._data_body) elif self._body_type == AmqpMessageBodyType.SEQUENCE: - output_str = "" - for sequence_section in self.body: - for d in sequence_section: - try: - output_str += d.decode(self._encoding) - except AttributeError: - output_str += str(d) - return output_str - else: - if not self.body: - return "" - try: - return self.body.decode(self._encoding) - except AttributeError: - return str(self.body) - - def __repr__(self): - # type: () -> str + return str(self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return str(self._value_body) + return "" + + def __repr__(self) -> str: # pylint: disable=bare-except message_repr = "body={}".format( str(self) @@ -226,7 +145,7 @@ def _from_amqp_message(self, message): ) if message.properties else None self._header = AmqpMessageHeader( delivery_count=message.header.delivery_count, - time_to_live=message.header.time_to_live, + time_to_live=message.header.ttl, first_acquirer=message.header.first_acquirer, durable=message.header.durable, priority=message.header.priority @@ -236,17 +155,17 @@ def _from_amqp_message(self, message): self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} self._application_properties = message.application_properties if message.application_properties else {} if message.data: - self._body = list(message.data) + self._data_body = list(message.data) self._body_type = AmqpMessageBodyType.DATA elif message.sequence: - self._body = list(message.sequence) + self._sequence_body = list(message.sequence) self._body_type = AmqpMessageBodyType.SEQUENCE else: - self._body = message.value + self._value_body = message.value self._body_type = AmqpMessageBodyType.VALUE @property - def body(self): + def body(self) -> Any: # type: () -> Any """The body of the Message. The format may vary depending on the body type: For ~azure.eventhub.AmqpMessageBodyType.DATA, the body could be bytes or Iterable[bytes] @@ -254,19 +173,23 @@ def body(self): For ~azure.eventhub.AmqpMessageBodyType.VALUE, the body could be any type. :rtype: Any """ - return self._body + if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return + return (i for i in self._data_body) + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return (i for i in self._sequence_body) + elif self._body_type == AmqpMessageBodyType.VALUE: + return self._value_body + return None @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ return self._body_type @property - def properties(self): - # type: () -> Optional[AmqpMessageProperties] + def properties(self) -> Optional[AmqpMessageProperties]: """ Properties to add to the message. :rtype: Optional[~azure.eventhub.amqp.AmqpMessageProperties] @@ -274,13 +197,11 @@ def properties(self): return self._properties @properties.setter - def properties(self, value): - # type: (AmqpMessageProperties) -> None + def properties(self, value: AmqpMessageProperties) -> None: self._properties = value @property - def application_properties(self): - # type: () -> Optional[Dict] + def application_properties(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific application properties. @@ -289,13 +210,11 @@ def application_properties(self): return self._application_properties @application_properties.setter - def application_properties(self, value): - # type: (Dict) -> None + def application_properties(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._application_properties = value @property - def annotations(self): - # type: () -> Optional[Dict] + def annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific message annotations. @@ -304,13 +223,11 @@ def annotations(self): return self._annotations @annotations.setter - def annotations(self, value): - # type: (Dict) -> None + def annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._annotations = value @property - def delivery_annotations(self): - # type: () -> Optional[Dict] + def delivery_annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Delivery-specific non-standard properties at the head of the message. Delivery annotations convey information from the sending peer to the receiving peer. @@ -320,13 +237,11 @@ def delivery_annotations(self): return self._delivery_annotations @delivery_annotations.setter - def delivery_annotations(self, value): - # type: (Dict) -> None + def delivery_annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._delivery_annotations = value @property - def header(self): - # type: () -> Optional[AmqpMessageHeader] + def header(self) -> Optional[AmqpMessageHeader]: """ The message header. :rtype: Optional[~azure.eventhub.amqp.AmqpMessageHeader] @@ -334,13 +249,11 @@ def header(self): return self._header @header.setter - def header(self, value): - # type: (AmqpMessageHeader) -> None + def header(self, value: AmqpMessageHeader) -> None: self._header = value @property - def footer(self): - # type: () -> Optional[Dict] + def footer(self) -> Optional[Dict[Any, Any]]: """ The message footer. @@ -349,10 +262,8 @@ def footer(self): return self._footer @footer.setter - def footer(self, value): - # type: (Dict) -> None + def footer(self, value: Optional[Dict[Any, Any]]) -> None: self._footer = value - # self._message.footer = value class AmqpMessageHeader(DictMixin): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index a5be04ad6d91..21481b89081d 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -98,7 +98,7 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamq sender._open_with_retry() time.sleep(11) ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) - sender._unsent_events = [ed.message] + sender._unsent_events = [ed._message] if uamqp_transport: sender._unsent_events[0].on_send_complete = sender._on_outcome with pytest.raises((uamqp.errors.ConnectionClose, @@ -122,7 +122,7 @@ def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamq sender._open_with_retry() time.sleep(11) ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) - sender._unsent_events = [ed.message] + sender._unsent_events = [ed._message] sender._send_event_data() retry = 0 diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index b803a0cef781..613b57aa49f3 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -14,9 +14,10 @@ UamqpTransport = None pass from azure.eventhub._transport._pyamqp_transport import PyamqpTransport -from azure.eventhub.amqp import AmqpAnnotatedMessage +from azure.eventhub.amqp import AmqpAnnotatedMessage, AmqpMessageHeader, AmqpMessageProperties from azure.eventhub import _common -from azure.eventhub._pyamqp.message import Message, Properties +from azure.eventhub._pyamqp.message import Message, Properties, Header +from azure.eventhub._utils import transform_outbound_single_message from .._test_case import get_decorator uamqp_transport_vals = get_decorator() @@ -153,11 +154,11 @@ def test_event_data_batch(uamqp_transport): @pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) def test_event_data_from_message(uamqp_transport): if uamqp_transport: - message = uamqp.Message('A') - expected_result = [b'A'] + amqp_transport = UamqpTransport() else: - message = Message(data=b'A') - expected_result = [65] + amqp_transport = PyamqpTransport() + annotated_message = AmqpAnnotatedMessage(data_body=b'A') + message = amqp_transport.to_outgoing_amqp_message(annotated_message) event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None @@ -169,7 +170,7 @@ def test_event_data_from_message(uamqp_transport): assert event.content_type == 'content_type' assert event.correlation_id == 'correlation_id' assert event.message_id == 'message_id' - assert list(event.body) == expected_result + assert list(event.body) == [b'A'] def test_amqp_message_str_repr(): @@ -177,3 +178,76 @@ def test_amqp_message_str_repr(): message = AmqpAnnotatedMessage(data_body=data_body) assert str(message) == 'A' assert 'AmqpAnnotatedMessage(body=A, body_type=data' in repr(message) + + +@pytest.mark.parametrize("uamqp_transport", + uamqp_transport_vals) +def test_amqp_message_from_message(uamqp_transport): + if uamqp_transport: + header = uamqp.message.MessageHeader() + header.delivery_count = 1 + header.time_to_live = 10000 + header.first_acquirer = True + header.durable = True + header.priority = 1 + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(header=header, properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + header = Header( + delivery_count=1, + ttl=10000, + first_acquirer=True, + durable=True, + priority=1 + ) + properties = Properties( + message_id="message_id", + user_id="user_id", + to="to", + subject="subject", + reply_to="reply_to", + correlation_id="correlation_id", + content_type="content_type", + content_encoding="content_encoding", + absolute_expiry_time=1, + creation_time=1, + group_id="group_id", + group_sequence=1, + reply_to_group_id="reply_to_group_id" + ) + message_annotations = {_common.PROP_OFFSET: "@latest"} + message = Message(properties=properties, header=header, message_annotations=message_annotations) + + amqp_message = AmqpAnnotatedMessage(message=message) + assert amqp_message.properties.message_id == message.properties.message_id + assert amqp_message.properties.user_id == message.properties.user_id + assert amqp_message.properties.to == message.properties.to + assert amqp_message.properties.subject == message.properties.subject + assert amqp_message.properties.reply_to == message.properties.reply_to + assert amqp_message.properties.correlation_id == message.properties.correlation_id + assert amqp_message.properties.content_type == message.properties.content_type + assert amqp_message.properties.absolute_expiry_time == message.properties.absolute_expiry_time + assert amqp_message.properties.creation_time == message.properties.creation_time + assert amqp_message.properties.group_id == message.properties.group_id + assert amqp_message.properties.group_sequence == message.properties.group_sequence + assert amqp_message.properties.reply_to_group_id == message.properties.reply_to_group_id + assert amqp_message.header.time_to_live == message.header.ttl + assert amqp_message.header.delivery_count == message.header.delivery_count + assert amqp_message.header.first_acquirer == message.header.first_acquirer + assert amqp_message.header.durable == message.header.durable + assert amqp_message.header.priority == message.header.priority + assert amqp_message.annotations == message.message_annotations From 8abde03ae27af975a6b919555ffad698cb693a60 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 28 Jul 2022 18:31:03 -0700 Subject: [PATCH 21/21] message backcompat fixes --- .../azure-eventhub/azure/eventhub/_common.py | 6 + .../eventhub/_pyamqp/_message_backcompat.py | 4 +- .../azure-eventhub/dev_requirements.txt | 3 +- sdk/eventhub/azure-eventhub/setup.py | 1 + .../tests/livetest/synctests/test_receive.py | 31 ++ .../tests/unittest/test_event_data.py | 407 ++++++++++++++++++ 6 files changed, 448 insertions(+), 4 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 8ccb726aa55a..22bb8ad60470 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -233,6 +233,12 @@ def message(self) -> LegacyMessage: ) return self._uamqp_message + # TODO: make message property mutable + @message.setter + def message(self, value: Union["uamqp_Message", Message]) -> None: + self._message = value + self._raw_amqp_message = AmqpAnnotatedMessage(message=value) + @property def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py index 7eb1554931a4..fd00604d282b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_message_backcompat.py @@ -50,10 +50,10 @@ def __init__(self, message, **kwargs): self.delivery_no = kwargs.get('delivery_no') self.delivery_tag = kwargs.get('delivery_tag') or None self.on_send_complete = None - self.properties = LegacyMessageProperties(self._message.properties) + self.properties = LegacyMessageProperties(self._message.properties) if self._message.properties else None self.application_properties = self._message.application_properties self.annotations = self._message.annotations - self.header = LegacyMessageHeader(self._message.header) + self.header = LegacyMessageHeader(self._message.header) if self._message.header else None self.footer = self._message.footer self.delivery_annotations = self._message.delivery_annotations if self._settler: diff --git a/sdk/eventhub/azure-eventhub/dev_requirements.txt b/sdk/eventhub/azure-eventhub/dev_requirements.txt index 9c91833e14d8..4035baa0ba70 100644 --- a/sdk/eventhub/azure-eventhub/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhub/dev_requirements.txt @@ -5,5 +5,4 @@ azure-mgmt-eventhub==10.0.0 azure-mgmt-resource==20.0.0 aiohttp>=3.0 websocket-client --e ../../../tools/azure-devtools --e ../../servicebus/azure-servicebus \ No newline at end of file +-e ../../../tools/azure-devtools \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/setup.py b/sdk/eventhub/azure-eventhub/setup.py index 8730981bc8ea..ee939c0934f8 100644 --- a/sdk/eventhub/azure-eventhub/setup.py +++ b/sdk/eventhub/azure-eventhub/setup.py @@ -70,6 +70,7 @@ packages=find_packages(exclude=exclude_packages), install_requires=[ "azure-core<2.0.0,>=1.14.0", + "uamqp>=1.5.1,<2.0.0", "typing-extensions>=4.0.1", ] ) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 7b38be683088..bf03cae60b37 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -9,9 +9,11 @@ import pytest import time import datetime +import uamqp from azure.eventhub import EventData, TransportType, EventHubConsumerClient from azure.eventhub.exceptions import EventHubError +from azure.eventhub._pyamqp.message import Properties from ..._test_case import get_decorator uamqp_transport_vals = get_decorator() @@ -107,6 +109,35 @@ def on_event(partition_context, event): thread.join() +# TODO: after fixing message property mutability, test +#@pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) +#@pytest.mark.liveTest +#def test_receive_modify_message_resend_sync(uamqp_transport, connstr_senders): +# received_modified = [False] +# def on_event(partition_context, event): +# message = event.message +# if message.properties.message_id == b'a1': +# message.properties.message_id = 'a2' +# senders[0].send(event) +# elif message.properties.message_id == b'a2': +# received_modified = [True] +# +# connection_str, senders = connstr_senders +# event = EventData("A", message_id='a1') +# senders[0].send(event) +# client = EventHubConsumerClient.from_connection_string( +# connection_str, consumer_group='$default', uamqp_transport=uamqp_transport +# ) +# with client: +# thread = threading.Thread(target=client.receive, args=(on_event,), +# kwargs={"partition_id": "0", "starting_position": "-1"}) +# thread.daemon = True +# thread.start() +# time.sleep(10) +# assert received_modified[0] +# thread.join() + + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_vals) @pytest.mark.liveTest diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 613b57aa49f3..3012739a05d6 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -251,3 +251,410 @@ def test_amqp_message_from_message(uamqp_transport): assert amqp_message.header.durable == message.header.durable assert amqp_message.header.priority == message.header.priority assert amqp_message.annotations == message.message_annotations + +# TODO: ADD MESSAGE BACKCOMPAT TESTS +#class EventDataMessageBackcompatTests: +# +# def test_message_backcompat_receive_and_delete_databody(): +# outgoing_event_data = EventData(body="hello") +# outgoing_event_data.application_properties = {'prop': 'test'} +# outgoing_event_data.session_id = "id_session" +# outgoing_event_data.message_id = "id_message" +# outgoing_event_data.time_to_live = timedelta(seconds=30) +# outgoing_event_data.content_type = "content type" +# outgoing_event_data.correlation_id = "correlation" +# outgoing_event_data.subject = "github" +# outgoing_event_data.partition_key = "id_session" +# outgoing_event_data.to = "forward to" +# outgoing_event_data.reply_to = "reply to" +# outgoing_event_data.reply_to_session_id = "reply to session" +# +# # TODO: Attribute shouldn't exist until after message has been sent. +# # with pytest.raises(AttributeError): +# # outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=True) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# assert outgoing_message.message +# with pytest.raises(TypeError): +# outgoing_message.message.accept() +# with pytest.raises(TypeError): +# outgoing_message.message.release() +# with pytest.raises(TypeError): +# outgoing_message.message.reject() +# with pytest.raises(TypeError): +# outgoing_message.message.modify(True, True) +# assert outgoing_message.message.state == uamqp.constants.MessageState.SendComplete +# assert outgoing_message.message.settled +# assert outgoing_message.message.delivery_annotations is None +# assert outgoing_message.message.delivery_no is None +# assert outgoing_message.message.delivery_tag is None +# assert outgoing_message.message.on_send_complete is None +# assert outgoing_message.message.footer is None +# assert outgoing_message.message.retries >= 0 +# assert outgoing_message.message.idle_time >= 0 +# with pytest.raises(Exception): +# outgoing_message.message.gather() +# assert isinstance(outgoing_message.message.encode_message(), bytes) +# assert outgoing_message.message.get_message_encoded_size() == 208 +# assert list(outgoing_message.message.get_data()) == [b'hello'] +# assert outgoing_message.message.application_properties == {'prop': 'test'} +# assert outgoing_message.message.get_message() # C instance. +# assert len(outgoing_message.message.annotations) == 1 +# assert list(outgoing_message.message.annotations.values())[0] == 'id_session' +# assert str(outgoing_message.message.header) == str({'delivery_count': None, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) +# assert outgoing_message.message.header.get_header_obj().delivery_count is None +# assert outgoing_message.message.properties.message_id == b'id_message' +# assert outgoing_message.message.properties.user_id is None +# assert outgoing_message.message.properties.to == b'forward to' +# assert outgoing_message.message.properties.subject == b'github' +# assert outgoing_message.message.properties.reply_to == b'reply to' +# assert outgoing_message.message.properties.correlation_id == b'correlation' +# assert outgoing_message.message.properties.content_type == b'content type' +# assert outgoing_message.message.properties.content_encoding is None +# assert outgoing_message.message.properties.absolute_expiry_time +# assert outgoing_message.message.properties.creation_time +# assert outgoing_message.message.properties.group_id == b'id_session' +# assert outgoing_message.message.properties.group_sequence is None +# assert outgoing_message.message.properties.reply_to_group_id == b'reply to session' +# assert outgoing_message.message.properties.get_properties_obj().message_id +# +# # TODO: Test updating message and resending +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert incoming_message.message.delivery_annotations == {} +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag is None +# assert incoming_message.message.on_send_complete is None +# assert incoming_message.message.footer is None +# assert incoming_message.message.retries >= 0 +# assert incoming_message.message.idle_time == 0 +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert isinstance(incoming_message.message.encode_message(), bytes) +# # TODO: Pyamqp has size at 266 +# # assert incoming_message.message.get_message_encoded_size() == 267 +# assert list(incoming_message.message.get_data()) == [b'hello'] +# assert incoming_message.message.application_properties == {b'prop': b'test'} +# assert incoming_message.message.get_message() # C instance. +# assert len(incoming_message.message.annotations) == 3 +# assert incoming_message.message.annotations[b'x-opt-enqueued-time'] > 0 +# assert incoming_message.message.annotations[b'x-opt-sequence-number'] > 0 +# assert incoming_message.message.annotations[b'x-opt-partition-key'] == b'id_session' +# # TODO: Pyamqp has header {'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None} +# # assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': True, 'durable': True, 'priority': 4}) +# assert incoming_message.message.header.get_header_obj().delivery_count == 0 +# assert incoming_message.message.properties.message_id == b'id_message' +# assert incoming_message.message.properties.user_id is None +# assert incoming_message.message.properties.to == b'forward to' +# assert incoming_message.message.properties.subject == b'github' +# assert incoming_message.message.properties.reply_to == b'reply to' +# assert incoming_message.message.properties.correlation_id == b'correlation' +# assert incoming_message.message.properties.content_type == b'content type' +# assert incoming_message.message.properties.content_encoding is None +# assert incoming_message.message.properties.absolute_expiry_time +# assert incoming_message.message.properties.creation_time +# assert incoming_message.message.properties.group_id == b'id_session' +# assert incoming_message.message.properties.group_sequence is None +# assert incoming_message.message.properties.reply_to_group_id == b'reply to session' +# assert incoming_message.message.properties.get_properties_obj().message_id +# assert not incoming_message.message.accept() +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# # TODO: Test updating message and resending +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_peek_lock_databody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = ServiceBusMessage( +# body="hello", +# application_properties={'prop': 'test'}, +# session_id="id_session", +# message_id="id_message", +# time_to_live=timedelta(seconds=30), +# content_type="content type", +# correlation_id="correlation", +# subject="github", +# partition_key="id_session", +# to="forward to", +# reply_to="reply to", +# reply_to_session_id="reply to session" +# ) +# +# # TODO: Attribute shouldn't exist until after message has been sent. +# # with pytest.raises(AttributeError): +# # outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=True) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# assert outgoing_message.message +# with pytest.raises(TypeError): +# outgoing_message.message.accept() +# with pytest.raises(TypeError): +# outgoing_message.message.release() +# with pytest.raises(TypeError): +# outgoing_message.message.reject() +# with pytest.raises(TypeError): +# outgoing_message.message.modify(True, True) +# assert outgoing_message.message.state == uamqp.constants.MessageState.SendComplete +# assert outgoing_message.message.settled +# assert outgoing_message.message.delivery_annotations is None +# assert outgoing_message.message.delivery_no is None +# assert outgoing_message.message.delivery_tag is None +# assert outgoing_message.message.on_send_complete is None +# assert outgoing_message.message.footer is None +# assert outgoing_message.message.retries >= 0 +# assert outgoing_message.message.idle_time >= 0 +# with pytest.raises(Exception): +# outgoing_message.message.gather() +# assert isinstance(outgoing_message.message.encode_message(), bytes) +# assert outgoing_message.message.get_message_encoded_size() == 208 +# assert list(outgoing_message.message.get_data()) == [b'hello'] +# assert outgoing_message.message.application_properties == {'prop': 'test'} +# assert outgoing_message.message.get_message() # C instance. +# assert len(outgoing_message.message.annotations) == 1 +# assert list(outgoing_message.message.annotations.values())[0] == 'id_session' +# assert str(outgoing_message.message.header) == str({'delivery_count': None, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) +# assert outgoing_message.message.header.get_header_obj().delivery_count is None +# assert outgoing_message.message.properties.message_id == b'id_message' +# assert outgoing_message.message.properties.user_id is None +# assert outgoing_message.message.properties.to == b'forward to' +# assert outgoing_message.message.properties.subject == b'github' +# assert outgoing_message.message.properties.reply_to == b'reply to' +# assert outgoing_message.message.properties.correlation_id == b'correlation' +# assert outgoing_message.message.properties.content_type == b'content type' +# assert outgoing_message.message.properties.content_encoding is None +# assert outgoing_message.message.properties.absolute_expiry_time +# assert outgoing_message.message.properties.creation_time +# assert outgoing_message.message.properties.group_id == b'id_session' +# assert outgoing_message.message.properties.group_sequence is None +# assert outgoing_message.message.properties.reply_to_group_id == b'reply to session' +# assert outgoing_message.message.properties.get_properties_obj().message_id +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.PEEK_LOCK, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled +# assert not incoming_message.message.settled +# assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag +# assert incoming_message.message.on_send_complete is None +# assert incoming_message.message.footer is None +# assert incoming_message.message.retries >= 0 +# assert incoming_message.message.idle_time == 0 +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert isinstance(incoming_message.message.encode_message(), bytes) +# # TODO: Pyamqp has size at 336 +# # assert incoming_message.message.get_message_encoded_size() == 334 +# assert list(incoming_message.message.get_data()) == [b'hello'] +# assert incoming_message.message.application_properties == {b'prop': b'test'} +# assert incoming_message.message.get_message() # C instance. +# assert len(incoming_message.message.annotations) == 4 +# assert incoming_message.message.annotations[b'x-opt-enqueued-time'] > 0 +# assert incoming_message.message.annotations[b'x-opt-sequence-number'] > 0 +# assert incoming_message.message.annotations[b'x-opt-partition-key'] == b'id_session' +# assert incoming_message.message.annotations[b'x-opt-locked-until'] +# # TODO: Pyamqp has header {'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None} +# # assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': True, 'durable': True, 'priority': 4}) +# assert str(incoming_message.message.header) == str({'delivery_count': 0, 'time_to_live': 30000, 'first_acquirer': None, 'durable': None, 'priority': None}) +# assert incoming_message.message.header.get_header_obj().delivery_count == 0 +# assert incoming_message.message.properties.message_id == b'id_message' +# assert incoming_message.message.properties.user_id is None +# assert incoming_message.message.properties.to == b'forward to' +# assert incoming_message.message.properties.subject == b'github' +# assert incoming_message.message.properties.reply_to == b'reply to' +# assert incoming_message.message.properties.correlation_id == b'correlation' +# assert incoming_message.message.properties.content_type == b'content type' +# assert incoming_message.message.properties.content_encoding is None +# assert incoming_message.message.properties.absolute_expiry_time +# assert incoming_message.message.properties.creation_time +# assert incoming_message.message.properties.group_id == b'id_session' +# assert incoming_message.message.properties.group_sequence is None +# assert incoming_message.message.properties.reply_to_group_id == b'reply to session' +# assert incoming_message.message.properties.get_properties_obj().message_id +# assert incoming_message.message.accept() +# # TODO: State isn't updated if settled correctly via the receiver. +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_receive_and_delete_valuebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(value_body={b"key": b"value"}) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert incoming_message.message.get_data() == {b"key": b"value"} +# assert not incoming_message.message.accept() +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_peek_lock_valuebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(value_body={b"key": b"value"}) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.PEEK_LOCK, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled +# assert not incoming_message.message.settled +# assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert incoming_message.message.get_data() == {b"key": b"value"} +# assert incoming_message.message.accept() +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_receive_and_delete_sequencebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(sequence_body=[1, 2, 3]) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.RECEIVE_AND_DELETE, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert list(incoming_message.message.get_data()) == [[1, 2, 3]] +# assert not incoming_message.message.accept() +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# @pytest.mark.liveTest +# @pytest.mark.live_test_only +# @CachedResourceGroupPreparer(name_prefix='servicebustest') +# @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') +# @ServiceBusQueuePreparer(name_prefix='servicebustest', dead_lettering_on_message_expiration=True) +# def test_message_backcompat_peek_lock_sequencebody(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): +# queue_name = servicebus_queue.name +# outgoing_message = AmqpAnnotatedMessage(sequence_body=[1, 2, 3]) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# sb_client = ServiceBusClient.from_connection_string( +# servicebus_namespace_connection_string, logging_enable=False) +# with sb_client.get_queue_sender(queue_name) as sender: +# sender.send_messages(outgoing_message) +# +# with pytest.raises(AttributeError): +# outgoing_message.message +# +# with sb_client.get_queue_receiver(queue_name, +# receive_mode=ServiceBusReceiveMode.PEEK_LOCK, +# max_wait_time=10) as receiver: +# batch = receiver.receive_messages() +# incoming_message = batch[0] +# assert incoming_message.message +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedUnsettled +# assert not incoming_message.message.settled +# assert incoming_message.message.delivery_annotations[b'x-opt-lock-token'] +# assert incoming_message.message.delivery_no >= 1 +# assert incoming_message.message.delivery_tag +# with pytest.raises(Exception): +# incoming_message.message.gather() +# assert list(incoming_message.message.get_data()) == [[1, 2, 3]] +# assert incoming_message.message.accept() +# assert incoming_message.message.state == uamqp.constants.MessageState.ReceivedSettled +# assert incoming_message.message.settled +# assert not incoming_message.message.release() +# assert not incoming_message.message.reject() +# assert not incoming_message.message.modify(True, True) +# +# # TODO: Add batch message backcompat tests