diff --git a/packages/smithy-http/pyproject.toml b/packages/smithy-http/pyproject.toml index bd9ed9e66..6a1966bcd 100644 --- a/packages/smithy-http/pyproject.toml +++ b/packages/smithy-http/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ [project.optional-dependencies] awscrt = [ - "awscrt>=0.23.10", + "awscrt>=0.27.2", ] aiohttp = [ "aiohttp>=3.11.12, <4.0", diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index 9f0b5a418..734565d01 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -3,29 +3,23 @@ # pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false # flake8: noqa: F811 import asyncio -from asyncio import Future as AsyncFuture -from collections import deque from collections.abc import AsyncGenerator, AsyncIterable -from concurrent.futures import Future as ConcurrentFuture from copy import deepcopy -from functools import partial -from io import BufferedIOBase, BytesIO +from io import BytesIO from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - # Both of these are types that essentially are "castable to bytes/memoryview" - # Unfortunately they're not exposed anywhere so we have to import them from - # _typeshed. - from _typeshed import ReadableBuffer, WriteableBuffer # pyright doesn't like optional imports. This is reasonable because if we use these # in type hints then they'd result in runtime errors. # TODO: add integ tests that import these without the dependendency installed - from awscrt import http as crt_http + from awscrt import http_asyncio as crt_http + from awscrt import http as crt_http_base from awscrt import io as crt_io try: - from awscrt import http as crt_http + from awscrt import http_asyncio as crt_http + from awscrt import http as crt_http_base from awscrt import io as crt_io HAS_CRT = True @@ -63,11 +57,11 @@ def _initialize_default_loop(self) -> "crt_io.ClientBootstrap": class AWSCRTHTTPResponse(http_aio_interfaces.HTTPResponse): - def __init__(self, *, status: int, fields: Fields, body: "CRTResponseBody") -> None: + def __init__(self, *, status: int, fields: Fields, stream: "crt_http.HttpClientStreamAsync") -> None: _assert_crt() self._status = status self._fields = fields - self._body = body + self._stream = stream @property def status(self) -> int: @@ -89,7 +83,7 @@ def reason(self) -> str | None: async def chunks(self) -> AsyncGenerator[bytes, None]: while True: - chunk = await self._body.next() + chunk = await self._stream.get_next_response_chunk() if chunk: yield chunk else: @@ -103,95 +97,6 @@ def __repr__(self) -> str: ) -class CRTResponseBody: - def __init__(self) -> None: - self._stream: crt_http.HttpClientStream | None = None - self._completion_future: AsyncFuture[int] | None = None - self._chunk_futures: deque[ConcurrentFuture[bytes]] = deque() - - # deque is thread safe and the crt is only going to be writing - # with one thread anyway, so we *shouldn't* need to gate this - # behind a lock. In an ideal world, the CRT would expose - # an interface that better matches python's async. - self._received_chunks: deque[bytes] = deque() - - def set_stream(self, stream: "crt_http.HttpClientStream") -> None: - if self._stream is not None: - raise SmithyHTTPException("Stream already set on AWSCRTHTTPResponse object") - self._stream = stream - concurrent_future: ConcurrentFuture[int] = stream.completion_future - self._completion_future = asyncio.wrap_future(concurrent_future) - self._completion_future.add_done_callback(self._on_complete) - self._stream.activate() - - def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback - # TODO: update back pressure window once CRT supports it - if self._chunk_futures: - future = self._chunk_futures.popleft() - future.set_result(chunk) - else: - self._received_chunks.append(chunk) - - async def next(self) -> bytes: - if self._completion_future is None: - raise SmithyHTTPException("Stream not set") - - # TODO: update backpressure window once CRT supports it - if self._received_chunks: - return self._received_chunks.popleft() - elif self._completion_future.done(): - return b"" - else: - future = ConcurrentFuture[bytes]() - self._chunk_futures.append(future) - return await asyncio.wrap_future(future) - - def _on_complete( - self, completion_future: AsyncFuture[int] - ) -> None: # pragma: crt-callback - for future in self._chunk_futures: - future.set_result(b"") - self._chunk_futures.clear() - - -class CRTResponseFactory: - def __init__(self, body: CRTResponseBody) -> None: - self._body = body - self._response_future = ConcurrentFuture[AWSCRTHTTPResponse]() - - def on_response( - self, status_code: int, headers: list[tuple[str, str]], **kwargs: Any - ) -> None: # pragma: crt-callback - fields = Fields() - for header_name, header_val in headers: - try: - fields[header_name].add(header_val) - except KeyError: - fields[header_name] = Field( - name=header_name, - values=[header_val], - kind=FieldPosition.HEADER, - ) - - self._response_future.set_result( - AWSCRTHTTPResponse( - status=status_code, - fields=fields, - body=self._body, - ) - ) - - async def await_response(self) -> AWSCRTHTTPResponse: - return await asyncio.wrap_future(self._response_future) - - def set_done_callback(self, stream: "crt_http.HttpClientStream") -> None: - stream.completion_future.add_done_callback(self._cancel) - - def _cancel(self, completion_future: ConcurrentFuture[int | Exception]) -> None: - if not self._response_future.done(): - self._response_future.cancel() - - ConnectionPoolKey = tuple[str, str, int | None] ConnectionPoolDict = dict[ConnectionPoolKey, "crt_http.HttpClientConnection"] @@ -236,45 +141,72 @@ async def send( :param request: The request including destination URI, fields, payload. :param request_config: Configuration specific to this request. """ - crt_request, crt_body = await self._marshal_request(request) + crt_request = self._marshal_request(request) connection = await self._get_connection(request.destination) - response_body = CRTResponseBody() - response_factory = CRTResponseFactory(response_body) + crt_stream = connection.request( crt_request, - response_factory.on_response, - response_body.on_body, - ) - response_factory.set_done_callback(crt_stream) - response_body.set_stream(crt_stream) - crt_stream.completion_future.add_done_callback( - partial(self._close_input_body, body=crt_body) + manual_write=True # allow manual stream write. ) - response = await response_factory.await_response() - if response.status != 200 and response.status >= 300: - await close(crt_body) + body = request.body + if isinstance(body, bytes | bytearray): + # If the body is already directly in memory, wrap in a BytesIO to hand + # off to CRT. + crt_body = BytesIO(body) + await crt_stream.write_data_async(crt_body, True) + else: + # If the body is async, or potentially very large, start up a task to read + # it into the intermediate object that CRT needs. By using + # asyncio.create_task we'll start the coroutine without having to + # explicitly await it. + + if not isinstance(body, AsyncIterable): + body = AsyncBytesReader(body) - return response + # Start the read task in the background. + read_task = asyncio.create_task( + self._consume_body_async(body, crt_stream)) - def _close_input_body( - self, future: ConcurrentFuture[int], *, body: "BufferableByteStream | BytesIO" - ) -> None: - if future.exception(timeout=0): - body.close() + # Keep track of the read task so that it doesn't get garbage colllected, + # and stop tracking it once it's done. + self._async_reads.add(read_task) + read_task.add_done_callback(self._async_reads.discard) + + return await self._await_response(crt_stream) + + async def _await_response( + self, stream: "crt_http.HttpClientStreamAsync" + ) -> AWSCRTHTTPResponse: + status_code = await stream.get_response_status_code() + headers = await stream.get_response_headers() + fields = Fields() + for header_name, header_val in headers: + try: + fields[header_name].add(header_val) + except KeyError: + fields[header_name] = Field( + name=header_name, + values=[header_val], + kind=FieldPosition.HEADER, + ) + return AWSCRTHTTPResponse( + status=status_code, + fields=fields, + stream=stream, + ) async def _create_connection( self, url: core_interfaces.URI - ) -> "crt_http.HttpClientConnection": + ) -> "crt_http.Http2ClientConnectionAsync": """Builds and validates connection to ``url``""" - connect_future = self._build_new_connection(url) - connection = await asyncio.wrap_future(connect_future) + connection = await self._build_new_connection(url) self._validate_connection(connection) return connection async def _get_connection( self, url: core_interfaces.URI - ) -> "crt_http.HttpClientConnection": + ) -> "crt_http.Http2ClientConnectionAsync": # TODO: Use CRT connection pooling instead of this basic kind connection_key = (url.scheme, url.host, url.port) connection = self._connections.get(connection_key) @@ -286,9 +218,9 @@ async def _get_connection( self._connections[connection_key] = connection return connection - def _build_new_connection( + async def _build_new_connection( self, url: core_interfaces.URI - ) -> ConcurrentFuture["crt_http.HttpClientConnection"]: + ) -> "crt_http.Http2ClientConnectionAsync": if url.scheme == "http": port = self._HTTP_PORT tls_connection_options = None @@ -304,17 +236,14 @@ def _build_new_connection( ) if url.port is not None: port = url.port - - connect_future: ConcurrentFuture[crt_http.HttpClientConnection] = ( - crt_http.HttpClientConnection.new( - bootstrap=self._client_bootstrap, - host_name=url.host, - port=port, - socket_options=self._socket_options, - tls_connection_options=tls_connection_options, - ) + # TODO: support HTTP/1,1 connections + return await crt_http.Http2ClientConnectionAsync.new( + bootstrap=self._client_bootstrap, + host_name=url.host, + port=port, + socket_options=self._socket_options, + tls_connection_options=tls_connection_options, ) - return connect_future def _validate_connection(self, connection: "crt_http.HttpClientConnection") -> None: """Validates an existing connection against the client config. @@ -326,16 +255,17 @@ def _validate_connection(self, connection: "crt_http.HttpClientConnection") -> N if force_http_2 and connection.version is not crt_http.HttpVersion.Http2: connection.close() negotiated = crt_http.HttpVersion(connection.version).name - raise SmithyHTTPException(f"HTTP/2 could not be negotiated: {negotiated}") + raise SmithyHTTPException( + f"HTTP/2 could not be negotiated: {negotiated}") def _render_path(self, url: core_interfaces.URI) -> str: path = url.path if url.path is not None else "/" query = f"?{url.query}" if url.query is not None else "" return f"{path}{query}" - async def _marshal_request( + def _marshal_request( self, request: http_aio_interfaces.HTTPRequest - ) -> tuple["crt_http.HttpRequest", "BufferableByteStream | BytesIO"]: + ) -> "crt_http_base.HttpRequest": """Create :py:class:`awscrt.http.HttpRequest` from :py:class:`smithy_http.aio.HTTPRequest`""" headers_list = [] @@ -355,139 +285,29 @@ async def _marshal_request( headers_list.append((fld.name, val)) path = self._render_path(request.destination) - headers = crt_http.HttpHeaders(headers_list) - - body = request.body - if isinstance(body, bytes | bytearray): - # If the body is already directly in memory, wrap in a BytesIO to hand - # off to CRT. - crt_body = BytesIO(body) - else: - # If the body is async, or potentially very large, start up a task to read - # it into the intermediate object that CRT needs. By using - # asyncio.create_task we'll start the coroutine without having to - # explicitly await it. - crt_body = BufferableByteStream() + headers = crt_http_base.HttpHeaders(headers_list) - if not isinstance(body, AsyncIterable): - body = AsyncBytesReader(body) - - # Start the read task in the background. - read_task = asyncio.create_task(self._consume_body_async(body, crt_body)) - - # Keep track of the read task so that it doesn't get garbage colllected, - # and stop tracking it once it's done. - self._async_reads.add(read_task) - read_task.add_done_callback(self._async_reads.discard) - - crt_request = crt_http.HttpRequest( + crt_request = crt_http_base.HttpRequest( method=request.method, path=path, headers=headers, - body_stream=crt_body, ) - return crt_request, crt_body + return crt_request async def _consume_body_async( - self, source: AsyncIterable[bytes], dest: "BufferableByteStream" + self, source: AsyncIterable[bytes], dest: "crt_http.HttpClientStreamAsync" ) -> None: try: async for chunk in source: - dest.write(chunk) + await dest.write_data_async(BytesIO(chunk), False) except Exception: - dest.close() raise finally: + await dest.write_data_async(BytesIO(b''), True) await close(source) - dest.end_stream() def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient": return AWSCRTHTTPClient( eventloop=self._eventloop, client_config=deepcopy(self._config), ) - - -# This is adapted from the transcribe streaming sdk -class BufferableByteStream(BufferedIOBase): - """A non-blocking bytes buffer.""" - - def __init__(self) -> None: - # We're always manipulating the front and back of the buffer, so a deque - # will be much more efficient than a list. - self._chunks: deque[bytes] = deque() - self._closed = False - self._done = False - - def read(self, size: int | None = -1) -> bytes: - if self._closed: - return b"" - - if len(self._chunks) == 0: - if self._done: - self.close() - return b"" - else: - # When the CRT recieves this, it'll try again - raise BlockingIOError("read") - - # We could compile all the chunks here instead of just returning - # the one, BUT the CRT will keep calling read until empty bytes - # are returned. So it's actually better to just return one chunk - # since combining them would have some potentially bad memory - # usage issues. - result = self._chunks.popleft() - if size is not None and size > 0: - remainder = result[size:] - result = result[:size] - if remainder: - self._chunks.appendleft(remainder) - - if self._done and len(self._chunks) == 0: - self.close() - - return result - - def read1(self, size: int = -1) -> bytes: - return self.read(size) - - def readinto(self, buffer: "WriteableBuffer") -> int: - if not isinstance(buffer, memoryview): - buffer = memoryview(buffer).cast("B") - - data = self.read(len(buffer)) # type: ignore - n = len(data) - buffer[:n] = data - return n - - def write(self, buffer: "ReadableBuffer") -> int: - if not isinstance(buffer, bytes): - raise ValueError( - f"Unexpected value written to BufferableByteStream. " - f"Only bytes are support but {type(buffer)} was provided." - ) - - if self._closed: - raise OSError("Stream is completed and doesn't support further writes.") - - if buffer: - self._chunks.append(buffer) - return len(buffer) - - @property - def closed(self) -> bool: - return self._closed - - def close(self) -> None: - self._closed = True - self._done = True - - # Clear out the remaining chunks so that they don't sit around in memory. - self._chunks.clear() - - def end_stream(self) -> None: - """End the stream, letting any remaining chunks be read before it is closed.""" - if len(self._chunks) == 0: - self.close() - else: - self._done = True