Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KeepAliveTimeout support to WebSocketMiddleware #57293

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Http/Http.Features/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable
Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.get -> bool
Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.set -> void
Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.get -> System.TimeSpan?
Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.set -> void
26 changes: 25 additions & 1 deletion src/Http/Http.Features/src/WebSocketAcceptContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,41 @@ namespace Microsoft.AspNetCore.Http;
public class WebSocketAcceptContext
{
private int _serverMaxWindowBits = 15;
private TimeSpan? _keepAliveTimeout;

/// <summary>
/// Gets or sets the subprotocol being negotiated.
/// </summary>
public virtual string? SubProtocol { get; set; }

/// <summary>
/// The interval to send pong frames. This is a heart-beat that keeps the connection alive.
/// The interval to send keep-alive frames. This is a heart-beat that keeps the connection alive.
/// </summary>
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
public virtual TimeSpan? KeepAliveInterval { get; set; }

/// <summary>
/// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted.
/// </summary>
/// <remarks>
/// <c>null</c> means use the value from <c>WebSocketOptions.KeepAliveTimeout</c>.
amcasey marked this conversation as resolved.
Show resolved Hide resolved
/// <see cref="Timeout.InfiniteTimeSpan"/> and <see cref="TimeSpan.Zero"/> are valid values and will disable the timeout.
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">
/// <see cref="TimeSpan"/> is less than <see cref="TimeSpan.Zero"/>.
/// </exception>
public TimeSpan? KeepAliveTimeout
{
get => _keepAliveTimeout;
set
{
if (value is not null && value != Timeout.InfiniteTimeSpan)
amcasey marked this conversation as resolved.
Show resolved Hide resolved
{
ArgumentOutOfRangeException.ThrowIfLessThan(value.Value, TimeSpan.Zero);
}
_keepAliveTimeout = value;
}
}

/// <summary>
/// Enables support for the 'permessage-deflate' WebSocket extension.<para />
/// Be aware that enabling compression over encrypted connections makes the application subject to CRIME/BREACH type attacks.
Expand Down
129 changes: 129 additions & 0 deletions src/Middleware/WebSockets/src/AbortStream.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Net.WebSockets;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.WebSockets;

/// <summary>
/// Used in WebSocketMiddleware to wrap the HttpContext.Request.Body stream
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
/// so that we can call HttpContext.Abort when the stream is disposed and the WebSocket is in the Aborted state.
/// The Stream provided by Kestrel (and maybe other servers) noops in Dispose as it doesn't know whether it's a graceful close or not
/// and can result in truncated responses if in the graceful case.
///
/// This handles explicit WebSocket.Abort calls as well as the Keep-Alive timeout setting Aborted and disposing the stream.
/// </summary>
/// <remarks>
/// Workaround for https://github.com/dotnet/runtime/issues/44272
/// </remarks>
internal sealed class AbortStream : Stream
{
private readonly Stream _innerStream;
private readonly HttpContext _httpContext;

public WebSocket? WebSocket { get; set; }

public AbortStream(HttpContext httpContext, Stream innerStream)
{
_innerStream = innerStream;
_httpContext = httpContext;
}

public override bool CanRead => _innerStream.CanRead;

public override bool CanSeek => _innerStream.CanSeek;

public override bool CanWrite => _innerStream.CanWrite;

public override bool CanTimeout => _innerStream.CanTimeout;

public override long Length => _innerStream.Length;

public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }

public override void Flush()
{
_innerStream.Flush();
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _innerStream.ReadAsync(buffer, cancellationToken);
}

public override int Read(byte[] buffer, int offset, int count)
{
return _innerStream.Read(buffer, offset, count);
}

public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
{
return _innerStream.BeginRead(buffer, offset, count, callback, state);
}

public override int EndRead(IAsyncResult asyncResult)
{
return _innerStream.EndRead(asyncResult);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
{
return _innerStream.BeginWrite(buffer, offset, count, callback, state);
}

public override void EndWrite(IAsyncResult asyncResult)
{
_innerStream.EndWrite(asyncResult);
}

public override long Seek(long offset, SeekOrigin origin)
{
return _innerStream.Seek(offset, origin);
}

public override void SetLength(long value)
{
_innerStream.SetLength(value);
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _innerStream.WriteAsync(buffer, cancellationToken);
}

public override void Write(byte[] buffer, int offset, int count)
{
_innerStream.Write(buffer, offset, count);
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
return _innerStream.FlushAsync(cancellationToken);
}

public override ValueTask DisposeAsync()
{
return _innerStream.DisposeAsync();
}

protected override void Dispose(bool disposing)
{
// Currently, if ManagedWebSocket sets the Aborted state it calls Stream.Dispose after
if (WebSocket?.State == WebSocketState.Aborted)
{
_httpContext.Abort();
}
_innerStream.Dispose();
}
}
2 changes: 2 additions & 0 deletions src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.get -> System.TimeSpan
Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.set -> void
80 changes: 0 additions & 80 deletions src/Middleware/WebSockets/src/ServerWebSocket.cs

This file was deleted.

9 changes: 7 additions & 2 deletions src/Middleware/WebSockets/src/WebSocketMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
bool serverContextTakeover = true;
int serverMaxWindowBits = 15;
TimeSpan keepAliveInterval = _options.KeepAliveInterval;
TimeSpan keepAliveTimeout = _options.KeepAliveTimeout;
if (acceptContext != null)
{
subProtocol = acceptContext.SubProtocol;
enableCompression = acceptContext.DangerousEnableCompression;
serverContextTakeover = !acceptContext.DisableServerContextTakeover;
serverMaxWindowBits = acceptContext.ServerMaxWindowBits;
keepAliveInterval = acceptContext.KeepAliveInterval ?? keepAliveInterval;
keepAliveTimeout = acceptContext.KeepAliveTimeout ?? keepAliveTimeout;
}

#pragma warning disable CS0618 // Type or member is obsolete
Expand Down Expand Up @@ -208,15 +210,18 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
// Disable request timeout, if there is one, after the websocket has been accepted
_context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();

var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
var abortStream = new AbortStream(_context, opaqueTransport);
var wrappedSocket = WebSocket.CreateFromStream(abortStream, new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = keepAliveInterval,
KeepAliveTimeout = keepAliveTimeout,
SubProtocol = subProtocol,
DangerousDeflateOptions = deflateOptions
});

return new ServerWebSocket(wrappedSocket, _context);
abortStream.WebSocket = wrappedSocket;
return wrappedSocket;
}

public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
Expand Down
28 changes: 28 additions & 0 deletions src/Middleware/WebSockets/src/WebSocketOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace Microsoft.AspNetCore.Builder;
/// </summary>
public class WebSocketOptions
{
private TimeSpan _keepAliveTimeout = Timeout.InfiniteTimeSpan;

/// <summary>
/// Constructs the <see cref="WebSocketOptions"/> class with default values.
/// </summary>
Expand All @@ -23,6 +25,32 @@ public WebSocketOptions()
/// </summary>
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
public TimeSpan KeepAliveInterval { get; set; }

/// <summary>
/// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted.
/// </summary>
/// <remarks>
/// Default value is <see cref="Timeout.InfiniteTimeSpan"/>.
/// <see cref="Timeout.InfiniteTimeSpan"/> and <see cref="TimeSpan.Zero"/> will disable the timeout.
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">
/// <see cref="TimeSpan"/> is less than <see cref="TimeSpan.Zero"/>.
/// </exception>
public TimeSpan KeepAliveTimeout
{
get
{
return _keepAliveTimeout;
}
set
{
if (value != Timeout.InfiniteTimeSpan)
{
ArgumentOutOfRangeException.ThrowIfLessThan(value, TimeSpan.Zero);
}
_keepAliveTimeout = value;
}
}

/// <summary>
/// Gets or sets the size of the protocol buffer used to receive and parse frames.
/// The default is 4kb.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.AspNetCore.Builder;
Expand All @@ -17,14 +17,22 @@ public void AddWebSocketsConfiguresOptions()
serviceCollection.AddWebSockets(o =>
{
o.KeepAliveInterval = TimeSpan.FromSeconds(1000);
o.KeepAliveTimeout = TimeSpan.FromSeconds(1234);
o.AllowedOrigins.Add("someString");
});

var services = serviceCollection.BuildServiceProvider();
var socketOptions = services.GetRequiredService<IOptions<WebSocketOptions>>().Value;

Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval);
Assert.Equal(TimeSpan.FromSeconds(1234), socketOptions.KeepAliveTimeout);
Assert.Single(socketOptions.AllowedOrigins);
Assert.Equal("someString", socketOptions.AllowedOrigins[0]);
}

[Fact]
public void ThrowsForBadOptions()
{
Assert.Throws<ArgumentOutOfRangeException>(() => new WebSocketOptions() { KeepAliveTimeout = TimeSpan.FromMicroseconds(-1) });
}
halter73 marked this conversation as resolved.
Show resolved Hide resolved
}
Loading
Loading