diff --git a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt index 556a81c6fed3..998915f4d291 100644 --- a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt @@ -1,3 +1,5 @@ #nullable enable Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.get -> bool Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.set -> void +Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.get -> System.TimeSpan? +Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.set -> void diff --git a/src/Http/Http.Features/src/WebSocketAcceptContext.cs b/src/Http/Http.Features/src/WebSocketAcceptContext.cs index 400ae09732c5..acb6602c4522 100644 --- a/src/Http/Http.Features/src/WebSocketAcceptContext.cs +++ b/src/Http/Http.Features/src/WebSocketAcceptContext.cs @@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.Http; public class WebSocketAcceptContext { private int _serverMaxWindowBits = 15; + private TimeSpan? _keepAliveTimeout; /// /// Gets or sets the subprotocol being negotiated. @@ -18,10 +19,36 @@ public class WebSocketAcceptContext public virtual string? SubProtocol { get; set; } /// - /// The interval to send pong frames. This is a heart-beat that keeps the connection alive. + /// The interval to send keep-alive frames. This is a heart-beat that keeps the connection alive. /// + /// + /// May be either a Ping or a Pong frame, depending on if is set. + /// public virtual TimeSpan? KeepAliveInterval { get; set; } + /// + /// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted. + /// + /// + /// null means use the value from WebSocketOptions.KeepAliveTimeout. + /// and are valid values and will disable the timeout. + /// + /// + /// is less than . + /// + public TimeSpan? KeepAliveTimeout + { + get => _keepAliveTimeout; + set + { + if (value is not null && value != Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThan(value.Value, TimeSpan.Zero); + } + _keepAliveTimeout = value; + } + } + /// /// Enables support for the 'permessage-deflate' WebSocket extension. /// Be aware that enabling compression over encrypted connections makes the application subject to CRIME/BREACH type attacks. diff --git a/src/Middleware/WebSockets/src/AbortStream.cs b/src/Middleware/WebSockets/src/AbortStream.cs new file mode 100644 index 000000000000..1fbcce4b8a83 --- /dev/null +++ b/src/Middleware/WebSockets/src/AbortStream.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.WebSockets; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.WebSockets; + +/// +/// Used in to wrap the .Request.Body stream +/// so that we can call when the stream is disposed and the WebSocket is in the state. +/// The Stream provided by Kestrel (and maybe other servers) noops in Dispose as it doesn't know whether it's a graceful close or not +/// and can result in truncated responses if in the graceful case. +/// +/// This handles explicit calls as well as the Keep-Alive timeout setting and disposing the stream. +/// +/// +/// Workaround for https://github.com/dotnet/runtime/issues/44272 +/// +internal sealed class AbortStream : Stream +{ + private readonly Stream _innerStream; + private readonly HttpContext _httpContext; + + public WebSocket? WebSocket { get; set; } + + public AbortStream(HttpContext httpContext, Stream innerStream) + { + _innerStream = innerStream; + _httpContext = httpContext; + } + + public override bool CanRead => _innerStream.CanRead; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanWrite => _innerStream.CanWrite; + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override long Length => _innerStream.Length; + + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + + public override void Flush() + { + _innerStream.Flush(); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.ReadAsync(buffer, cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, count); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + return _innerStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _innerStream.EndRead(asyncResult); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + return _innerStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _innerStream.EndWrite(asyncResult); + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _innerStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _innerStream.SetLength(value); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.WriteAsync(buffer, cancellationToken); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _innerStream.Write(buffer, offset, count); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _innerStream.FlushAsync(cancellationToken); + } + + public override ValueTask DisposeAsync() + { + return _innerStream.DisposeAsync(); + } + + protected override void Dispose(bool disposing) + { + // Currently, if ManagedWebSocket sets the Aborted state it calls Stream.Dispose after + if (WebSocket?.State == WebSocketState.Aborted) + { + _httpContext.Abort(); + } + _innerStream.Dispose(); + } +} diff --git a/src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt b/src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..913e2d3dc9f3 100644 --- a/src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.get -> System.TimeSpan +Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.set -> void diff --git a/src/Middleware/WebSockets/src/ServerWebSocket.cs b/src/Middleware/WebSockets/src/ServerWebSocket.cs deleted file mode 100644 index 70be31cb0459..000000000000 --- a/src/Middleware/WebSockets/src/ServerWebSocket.cs +++ /dev/null @@ -1,80 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Net.WebSockets; -using Microsoft.AspNetCore.Http; - -namespace Microsoft.AspNetCore.WebSockets; - -/// -/// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted -/// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket. -/// -internal sealed class ServerWebSocket : WebSocket -{ - private readonly WebSocket _wrappedSocket; - private readonly HttpContext _context; - - internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context) - { - ArgumentNullException.ThrowIfNull(wrappedSocket); - ArgumentNullException.ThrowIfNull(context); - - _wrappedSocket = wrappedSocket; - _context = context; - } - - public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus; - - public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription; - - public override WebSocketState State => _wrappedSocket.State; - - public override string? SubProtocol => _wrappedSocket.SubProtocol; - - public override void Abort() - { - _wrappedSocket.Abort(); - _context.Abort(); - } - - public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) - { - return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); - } - - public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) - { - return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); - } - - public override void Dispose() - { - _wrappedSocket.Dispose(); - } - - public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) - { - return _wrappedSocket.ReceiveAsync(buffer, cancellationToken); - } - - public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) - { - return _wrappedSocket.ReceiveAsync(buffer, cancellationToken); - } - - public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) - { - return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); - } - - public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) - { - return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); - } - - public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) - { - return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); - } -} diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs index 10a28cb0c667..a1dedab5beb0 100644 --- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs +++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs @@ -141,6 +141,7 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) bool serverContextTakeover = true; int serverMaxWindowBits = 15; TimeSpan keepAliveInterval = _options.KeepAliveInterval; + TimeSpan keepAliveTimeout = _options.KeepAliveTimeout; if (acceptContext != null) { subProtocol = acceptContext.SubProtocol; @@ -148,6 +149,7 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) serverContextTakeover = !acceptContext.DisableServerContextTakeover; serverMaxWindowBits = acceptContext.ServerMaxWindowBits; keepAliveInterval = acceptContext.KeepAliveInterval ?? keepAliveInterval; + keepAliveTimeout = acceptContext.KeepAliveTimeout ?? keepAliveTimeout; } #pragma warning disable CS0618 // Type or member is obsolete @@ -208,15 +210,18 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) // Disable request timeout, if there is one, after the websocket has been accepted _context.Features.Get()?.DisableTimeout(); - var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions() + var abortStream = new AbortStream(_context, opaqueTransport); + var wrappedSocket = WebSocket.CreateFromStream(abortStream, new WebSocketCreationOptions() { IsServer = true, KeepAliveInterval = keepAliveInterval, + KeepAliveTimeout = keepAliveTimeout, SubProtocol = subProtocol, DangerousDeflateOptions = deflateOptions }); - return new ServerWebSocket(wrappedSocket, _context); + abortStream.WebSocket = wrappedSocket; + return wrappedSocket; } public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders) diff --git a/src/Middleware/WebSockets/src/WebSocketOptions.cs b/src/Middleware/WebSockets/src/WebSocketOptions.cs index 314184b7c953..1ea7a187be52 100644 --- a/src/Middleware/WebSockets/src/WebSocketOptions.cs +++ b/src/Middleware/WebSockets/src/WebSocketOptions.cs @@ -8,6 +8,8 @@ namespace Microsoft.AspNetCore.Builder; /// public class WebSocketOptions { + private TimeSpan _keepAliveTimeout = Timeout.InfiniteTimeSpan; + /// /// Constructs the class with default values. /// @@ -18,11 +20,40 @@ public WebSocketOptions() } /// - /// Gets or sets the frequency at which to send Ping/Pong keep-alive control frames. + /// The interval to send keep-alive frames. This is a heart-beat that keeps the connection alive. /// The default is two minutes. /// + /// + /// May be either a Ping or a Pong frame, depending on if is set. + /// public TimeSpan KeepAliveInterval { get; set; } + /// + /// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted. + /// + /// + /// Default value is . + /// and will disable the timeout. + /// + /// + /// is less than . + /// + public TimeSpan KeepAliveTimeout + { + get + { + return _keepAliveTimeout; + } + set + { + if (value != Timeout.InfiniteTimeSpan) + { + ArgumentOutOfRangeException.ThrowIfLessThan(value, TimeSpan.Zero); + } + _keepAliveTimeout = value; + } + } + /// /// Gets or sets the size of the protocol buffer used to receive and parse frames. /// The default is 4kb. diff --git a/src/Middleware/WebSockets/test/UnitTests/AddWebSocketsTests.cs b/src/Middleware/WebSockets/test/UnitTests/AddWebSocketsTests.cs index f505505cf0da..1c92c44ab580 100644 --- a/src/Middleware/WebSockets/test/UnitTests/AddWebSocketsTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/AddWebSocketsTests.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.AspNetCore.Builder; @@ -17,6 +17,7 @@ public void AddWebSocketsConfiguresOptions() serviceCollection.AddWebSockets(o => { o.KeepAliveInterval = TimeSpan.FromSeconds(1000); + o.KeepAliveTimeout = TimeSpan.FromSeconds(1234); o.AllowedOrigins.Add("someString"); }); @@ -24,7 +25,14 @@ public void AddWebSocketsConfiguresOptions() var socketOptions = services.GetRequiredService>().Value; Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval); + Assert.Equal(TimeSpan.FromSeconds(1234), socketOptions.KeepAliveTimeout); Assert.Single(socketOptions.AllowedOrigins); Assert.Equal("someString", socketOptions.AllowedOrigins[0]); } + + [Fact] + public void ThrowsForBadOptions() + { + Assert.Throws(() => new WebSocketOptions() { KeepAliveTimeout = TimeSpan.FromMicroseconds(-1) }); + } } diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs index 39e86eb6b579..81d93afb897d 100644 --- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs @@ -795,6 +795,57 @@ public async Task AcceptingWebSocketRequestDisablesTimeout() } } + [Fact] + public async Task PingTimeoutCancelsReceiveAsync() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + try + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + await webSocket.ReceiveAsync(new byte[1], cancellationToken: default); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + finally + { + tcs.TrySetResult(); + } + }, + o => + { + o.KeepAliveInterval = TimeSpan.FromMilliseconds(1); + o.KeepAliveTimeout = TimeSpan.FromMilliseconds(1); + })) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")); + uri.Scheme = "http"; + + // Craft a valid WebSocket Upgrade request + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + request.Headers.Connection.Clear(); + request.Headers.Connection.Add("Upgrade"); + request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket")); + request.Headers.Add(HeaderNames.SecWebSocketVersion, "13"); + // SecWebSocketKey required to be 16 bytes + request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None)); + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + + var ex = await Assert.ThrowsAnyAsync(() => tcs.Task); + Assert.True(ex is ConnectionAbortedException or WebSocketException, ex.GetType().FullName); + } + } + } + } + internal sealed class HttpRequestTimeoutFeature : IHttpRequestTimeoutFeature { public bool Enabled { get; private set; } = true; diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2WebSocketInteropTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2WebSocketInteropTests.cs index aaa385c64375..834c17ff3612 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2WebSocketInteropTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http2WebSocketInteropTests.cs @@ -94,6 +94,75 @@ public async Task HttpVersionNegotationWorks(string scheme, string clientVersion await wsClient.CloseAsync(WebSocketCloseStatus.NormalClosure, "Client closed", default); } + [Fact] + public async Task PingTimeoutCancelsReceiveAsync() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var hostBuilder = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + { + ConfigureKestrel(webHostBuilder, "https", HttpProtocols.Http2); + webHostBuilder.ConfigureServices(AddTestLogging) + .Configure(app => + { + app.UseWebSockets(new WebSocketOptions() + { + KeepAliveInterval = TimeSpan.FromMilliseconds(1), + KeepAliveTimeout = TimeSpan.FromMilliseconds(1), + }); + app.Run(async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var ws = await context.WebSockets.AcceptWebSocketAsync(); + var bytes = new byte[1024]; + + try + { + var result = await ws.ReceiveAsync(bytes, default); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + finally + { + tcs.TrySetResult(); + } + }); + }); + }); + using var host = await hostBuilder.StartAsync().DefaultTimeout(); + + var url = host.MakeUrl("wss"); + + var handler = new HttpClientHandler(); + handler.ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator; + var pauseSendHandler = new PauseSendHandler(handler); + using var client = new HttpClient(pauseSendHandler); + + var wsClient = new ClientWebSocket(); + wsClient.Options.HttpVersion = Version.Parse("2.0"); + wsClient.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + wsClient.Options.CollectHttpResponseDetails = true; + await wsClient.ConnectAsync(new Uri(url), client, default); + Assert.Equal(HttpStatusCode.OK, wsClient.HttpStatusCode); + + // Prevent Pong replies so we can test the server timing out + // It's fine if some Pongs were already sent before this is set + pauseSendHandler.PauseSend = true; + + var ex = await Assert.ThrowsAnyAsync(() => tcs.Task); + Assert.True(ex is WebSocketException || ex is TaskCanceledException, ex.GetType().FullName); + + // Unblock Send + pauseSendHandler.PauseSend = false; + + // Call any websocket method that tries networking so we have something to await to check that the client connection closed. + await Assert.ThrowsAnyAsync(() => wsClient.ReceiveAsync(new byte[1], default)); + + Assert.Equal(WebSocketState.Aborted, wsClient.State); + } + private static HttpClient CreateClient() { var handler = new HttpClientHandler(); @@ -116,4 +185,22 @@ private static void ConfigureKestrel(IWebHostBuilder webHostBuilder, string sche }); }); } + + public sealed class PauseSendHandler : DelegatingHandler + { + public bool PauseSend { get; set; } + + public PauseSendHandler(HttpClientHandler handler) : base(handler) + { + } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + while (PauseSend) + { + await Task.Delay(1); + } + return await base.SendAsync(request, cancellationToken); + } + } }