diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs new file mode 100644 index 0000000000000..8f2c768fd4bc9 --- /dev/null +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; + +namespace System.Net.WebSockets +{ + /// + /// Central repository for default values used in WebSocket settings. Not all settings are relevant + /// to or configurable by all WebSocket implementations. + /// + internal static partial class WebSocketDefaults + { + public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.Zero; + public static readonly TimeSpan DefaultClientKeepAliveInterval = TimeSpan.FromSeconds(30); + + public static readonly TimeSpan DefaultKeepAliveTimeout = Timeout.InfiniteTimeSpan; + } +} diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs index e087677be4608..c11524bdeef9b 100644 --- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -38,27 +38,26 @@ internal static partial class WebSocketValidate SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"); internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates) + => ThrowIfInvalidState(currentState, isDisposed, innerException: null, validStates ?? []); + + internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, WebSocketState[]? validStates = null) { - string validStatesText = string.Empty; + if (validStates is not null && Array.IndexOf(validStates, currentState) == -1) + { + string invalidStateMessage = SR.Format( + SR.net_WebSockets_InvalidState, currentState, string.Join(", ", validStates)); + + throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException); + } - if (validStates != null && validStates.Length > 0) + if (innerException is not null) { - foreach (WebSocketState validState in validStates) - { - if (currentState == validState) - { - // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. - ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket)); - return; - } - } - - validStatesText = string.Join(", ", validStates); + Debug.Assert(currentState == WebSocketState.Aborted); + throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException); } - throw new WebSocketException( - WebSocketError.InvalidState, - SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText)); + // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. + ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket)); } internal static void ValidateSubprotocol(string subProtocol) diff --git a/src/libraries/Common/tests/TestUtilities/TestEventListener.cs b/src/libraries/Common/tests/TestUtilities/TestEventListener.cs index 8cb70ee3cbd8c..b83a165906164 100644 --- a/src/libraries/Common/tests/TestUtilities/TestEventListener.cs +++ b/src/libraries/Common/tests/TestUtilities/TestEventListener.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.Tracing; using System.IO; using System.Text; @@ -31,6 +32,7 @@ public sealed class TestEventListener : EventListener "Private.InternalDiagnostics.System.Net.Sockets", "Private.InternalDiagnostics.System.Net.Security", "Private.InternalDiagnostics.System.Net.Quic", + "Private.InternalDiagnostics.System.Net.WebSockets", "Private.InternalDiagnostics.System.Net.Http.WinHttpHandler", "Private.InternalDiagnostics.System.Net.HttpListener", "Private.InternalDiagnostics.System.Net.Mail", @@ -41,19 +43,24 @@ public sealed class TestEventListener : EventListener private readonly Action _writeFunc; private readonly HashSet _sourceNames; + private readonly bool _enableActivityId; // Until https://github.com/dotnet/runtime/issues/63979 is solved. private List _eventSources = new List(); public TestEventListener(TextWriter output, params string[] sourceNames) - : this(str => output.WriteLine(str), sourceNames) + : this(output.WriteLine, sourceNames) { } public TestEventListener(ITestOutputHelper output, params string[] sourceNames) - : this(str => output.WriteLine(str), sourceNames) + : this(output.WriteLine, sourceNames) { } public TestEventListener(Action writeFunc, params string[] sourceNames) + : this(writeFunc, enableActivityId: false, sourceNames) + { } + + public TestEventListener(Action writeFunc, bool enableActivityId, params string[] sourceNames) { List eventSources = _eventSources; @@ -61,16 +68,14 @@ public TestEventListener(Action writeFunc, params string[] sourceNames) { _writeFunc = writeFunc; _sourceNames = new HashSet(sourceNames); + _enableActivityId = enableActivityId; _eventSources = null; } // eventSources were populated in the base ctor and are now owned by this thread, enable them now. foreach (EventSource eventSource in eventSources) { - if (_sourceNames.Contains(eventSource.Name)) - { - EnableEvents(eventSource, EventLevel.LogAlways); - } + EnableEventSource(eventSource); } } @@ -90,20 +95,42 @@ protected override void OnEventSourceCreated(EventSource eventSource) } // Second pass called after our ctor, allow logging for specified source names. + EnableEventSource(eventSource); + } + + private void EnableEventSource(EventSource eventSource) + { if (_sourceNames.Contains(eventSource.Name)) { EnableEvents(eventSource, EventLevel.LogAlways); } + else if (_enableActivityId && eventSource.Name == "System.Threading.Tasks.TplEventSource") + { + EnableEvents(eventSource, EventLevel.LogAlways, (EventKeywords)0x80 /* TasksFlowActivityIds */); + } } protected override void OnEventWritten(EventWrittenEventArgs eventData) { - StringBuilder sb = new StringBuilder(). + StringBuilder sb = new StringBuilder(); + #if NET || NETSTANDARD2_1_OR_GREATER - Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}[{eventData.EventName}] "); -#else - Append($"[{eventData.EventName}] "); + sb.Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}"); + if (_enableActivityId) + { + if (eventData.ActivityId != Guid.Empty) + { + string activityId = ActivityHelpers.ActivityPathString(eventData.ActivityId); + sb.Append($" {activityId} {new string('-', activityId.Length / 2 - 1 )} "); + } + else + { + sb.Append(" / "); + } + } #endif + sb.Append($"[{eventData.EventName}] "); + for (int i = 0; i < eventData.Payload?.Count; i++) { if (i > 0) @@ -116,4 +143,103 @@ protected override void OnEventWritten(EventWrittenEventArgs eventData) } catch { } } + + // From https://gist.github.com/MihaZupan/cc63ee68b4146892f2e5b640ed57bc09 + private static class ActivityHelpers + { + private enum NumberListCodes : byte + { + End = 0x0, + LastImmediateValue = 0xA, + PrefixCode = 0xB, + MultiByte1 = 0xC, + } + + public static unsafe bool IsActivityPath(Guid guid) + { + uint* uintPtr = (uint*)&guid; + uint sum = uintPtr[0] + uintPtr[1] + uintPtr[2] + 0x599D99AD; + return ((sum & 0xFFF00000) == (uintPtr[3] & 0xFFF00000)); + } + + public static unsafe string ActivityPathString(Guid guid) + => IsActivityPath(guid) ? CreateActivityPathString(guid) : guid.ToString(); + + internal static unsafe string CreateActivityPathString(Guid guid) + { + Debug.Assert(IsActivityPath(guid)); + + StringBuilder sb = new StringBuilder(); + + byte* bytePtr = (byte*)&guid; + byte* endPtr = bytePtr + 12; + char separator = '/'; + while (bytePtr < endPtr) + { + uint nibble = (uint)(*bytePtr >> 4); + bool secondNibble = false; + NextNibble: + if (nibble == (uint)NumberListCodes.End) + { + break; + } + if (nibble <= (uint)NumberListCodes.LastImmediateValue) + { + sb.Append('/').Append(nibble); + if (!secondNibble) + { + nibble = (uint)(*bytePtr & 0xF); + secondNibble = true; + goto NextNibble; + } + bytePtr++; + continue; + } + else if (nibble == (uint)NumberListCodes.PrefixCode) + { + if (!secondNibble) + { + nibble = (uint)(*bytePtr & 0xF); + } + else + { + bytePtr++; + if (endPtr <= bytePtr) + { + break; + } + nibble = (uint)(*bytePtr >> 4); + } + if (nibble < (uint)NumberListCodes.MultiByte1) + { + return guid.ToString(); + } + separator = '$'; + } + Debug.Assert((uint)NumberListCodes.MultiByte1 <= nibble); + uint numBytes = nibble - (uint)NumberListCodes.MultiByte1; + uint value = 0; + if (!secondNibble) + { + value = (uint)(*bytePtr & 0xF); + } + bytePtr++; + numBytes++; + if (endPtr < bytePtr + numBytes) + { + break; + } + for (int i = (int)numBytes - 1; 0 <= i; --i) + { + value = (value << 8) + bytePtr[i]; + } + sb.Append(separator).Append(value); + + bytePtr += numBytes; + } + + sb.Append('/'); + return sb.ToString(); + } + } } diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index edb4eb043bcb0..88bf41f23fbce 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -42,6 +42,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index a309737d6917d..8265edd7e9369 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -27,6 +27,7 @@ + @@ -47,6 +48,7 @@ + @@ -57,7 +59,6 @@ - diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 59096fc864d3a..aa8164d1099c2 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -122,6 +122,13 @@ public TimeSpan KeepAliveInterval set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public TimeSpan KeepAliveTimeout + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public WebSocketDeflateOptions? DangerousDeflateOptions { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index f78882f8b005d..dc7155e11cf58 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -14,7 +14,8 @@ namespace System.Net.WebSockets public sealed class ClientWebSocketOptions { private bool _isReadOnly; // After ConnectAsync is called the options cannot be modified. - private TimeSpan _keepAliveInterval = WebSocket.DefaultKeepAliveInterval; + private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultClientKeepAliveInterval; + private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout; private bool _useDefaultCredentials; private ICredentials? _credentials; private IWebProxy? _proxy; @@ -171,6 +172,12 @@ public void AddSubProtocol(string subProtocol) subprotocols.Add(subProtocol); } + /// + /// The keep-alive interval to use, or or to disable keep-alives. + /// If is set, then PING messages are sent and peer's PONG responses are expected, otherwise, + /// unsolicited PONG messages are used as a keep-alive heartbeat. + /// The default is (typically 30 seconds). + /// [UnsupportedOSPlatform("browser")] public TimeSpan KeepAliveInterval { @@ -188,6 +195,28 @@ public TimeSpan KeepAliveInterval } } + /// + /// The timeout to use when waiting for the peer's PONG in response to us sending a PING; or or + /// to disable waiting for peer's response, and use an unsolicited PONG as a Keep-Alive heartbeat instead. + /// The default is . + /// + [UnsupportedOSPlatform("browser")] + public TimeSpan KeepAliveTimeout + { + get => _keepAliveTimeout; + set + { + ThrowIfReadOnly(); + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(value), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, + Timeout.InfiniteTimeSpan.ToString())); + } + _keepAliveTimeout = value; + } + } + /// /// Gets or sets the options for the per-message-deflate extension. /// When present, the options are sent to the server during the handshake phase. If the server diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 3301bfead64c7..96e091663ff79 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -198,6 +198,7 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio IsServer = false, SubProtocol = subprotocol, KeepAliveInterval = options.KeepAliveInterval, + KeepAliveTimeout = options.KeepAliveTimeout, DangerousDeflateOptions = negotiatedDeflateOptions }); _negotiatedDeflateOptions = negotiatedDeflateOptions; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs index 0aa83697a9de7..8d0a89b320d61 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.IO; +using System.Net.Sockets; +using System.Net.Test.Common; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -17,6 +20,8 @@ public AbortTest_Loopback(ITestOutputHelper output) : base(output) { } protected virtual Version HttpVersion => Net.HttpVersion.Version11; + public static object[][] AbortClient_MemberData = ToMemberData(Enum.GetValues(), UseSsl_Values, /* verifySendReceive */ Bool_Values); + [Theory] [MemberData(nameof(AbortClient_MemberData))] public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive) @@ -64,6 +69,8 @@ public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool use timeoutCts.Token); } + public static object[][] ServerPrematureEos_MemberData = ToMemberData(Enum.GetValues(), UseSsl_Values); + [Theory] [MemberData(nameof(ServerPrematureEos_MemberData))] public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl) @@ -146,34 +153,6 @@ await SendServerResponseAndEosAsync( protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func serverFunc, CancellationToken cancellationToken) => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2 - private static readonly bool[] Bool_Values = new[] { false, true }; - private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false }; - - public static IEnumerable AbortClient_MemberData() - { - foreach (var abortType in Enum.GetValues()) - { - foreach (var useSsl in UseSsl_Values) - { - foreach (var verifySendReceive in Bool_Values) - { - yield return new object[] { abortType, useSsl, verifySendReceive }; - } - } - } - } - - public static IEnumerable ServerPrematureEos_MemberData() - { - foreach (var serverEosType in Enum.GetValues()) - { - foreach (var useSsl in UseSsl_Values) - { - yield return new object[] { serverEosType, useSsl }; - } - } - } - public enum AbortType { Abort, @@ -187,7 +166,7 @@ public enum ServerEosType AfterSomeData } - private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, + protected static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken) { var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs index b2137a7faa7a2..7a39f2423cad8 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs @@ -142,6 +142,25 @@ public static void KeepAliveInterval_Roundtrips() AssertExtensions.Throws("value", () => cws.Options.KeepAliveInterval = TimeSpan.MinValue); } + [ConditionalFact(nameof(WebSocketsSupported))] + [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] + public static void KeepAliveTimeout_Roundtrips() + { + var cws = new ClientWebSocket(); + Assert.True(cws.Options.KeepAliveTimeout == Timeout.InfiniteTimeSpan); + + cws.Options.KeepAliveTimeout = TimeSpan.Zero; + Assert.Equal(TimeSpan.Zero, cws.Options.KeepAliveTimeout); + + cws.Options.KeepAliveTimeout = TimeSpan.MaxValue; + Assert.Equal(TimeSpan.MaxValue, cws.Options.KeepAliveTimeout); + + cws.Options.KeepAliveTimeout = Timeout.InfiniteTimeSpan; + Assert.Equal(Timeout.InfiniteTimeSpan, cws.Options.KeepAliveTimeout); + + AssertExtensions.Throws("value", () => cws.Options.KeepAliveTimeout = TimeSpan.MinValue); + } + [ConditionalFact(nameof(WebSocketsSupported))] [SkipOnPlatform(TestPlatforms.Browser, "Certificates not supported on browser")] public void RemoteCertificateValidationCallback_Roundtrips() diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs index 0dc1775b57345..4e4fb4b3d87c7 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketTestBase.cs @@ -2,14 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Net.Http; +using System.Reflection; using System.Threading; using System.Threading.Tasks; +using TestUtilities; using Xunit; using Xunit.Abstractions; -using System.Net.Http; -using System.Diagnostics; namespace System.Net.WebSockets.Client.Tests { @@ -23,6 +25,19 @@ public class ClientWebSocketTestBase new object[] { o[0], true } }).ToArray(); + public static readonly bool[] Bool_Values = new[] { false, true }; + public static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false }; + public static readonly object[][] UseSsl_MemberData = ToMemberData(UseSsl_Values); + + public static object[][] ToMemberData(IEnumerable data) + => data.Select(a => new object[] { a }).ToArray(); + + public static object[][] ToMemberData(IEnumerable dataA, IEnumerable dataB) + => dataA.SelectMany(a => dataB.Select(b => new object[] { a, b })).ToArray(); + + public static object[][] ToMemberData(IEnumerable dataA, IEnumerable dataB, IEnumerable dataC) + => dataA.SelectMany(a => dataB.SelectMany(b => dataC.Select(c => new object[] { a, b, c }))).ToArray(); + public const int TimeOutMilliseconds = 30000; public const int CloseDescriptionMaxLength = 123; public readonly ITestOutputHelper _output; diff --git a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs index fb73485fc7fe1..14affae6bd39e 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs @@ -264,7 +264,7 @@ public async Task CloseOutputAsync_ClientInitiated_CanReceive_CanClose(Uri serve [ActiveIssue("https://github.com/dotnet/runtime/issues/28957", typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] [OuterLoop("Uses external servers", typeof(PlatformDetection), nameof(PlatformDetection.LocalEchoServerIsNotAvailable))] - [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersWithSwitch))] + [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersAndBoolean))] public async Task CloseOutputAsync_ServerInitiated_CanReceive(Uri server, bool delayReceiving) { var expectedCloseStatus = WebSocketCloseStatus.NormalClosure; @@ -367,15 +367,8 @@ await cws.SendAsync( } } - public static IEnumerable EchoServersWithSwitch => - EchoServers.SelectMany(server => new List - { - new object[] { server[0], true }, - new object[] { server[0], false } - }); - [ActiveIssue("https://github.com/dotnet/runtime/issues/28957", typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] - [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersWithSwitch))] + [ConditionalTheory(nameof(WebSocketsSupported)), MemberData(nameof(EchoServersAndBoolean))] public async Task CloseOutputAsync_ServerInitiated_CanReceiveAfterClose(Uri server, bool syncState) { using (ClientWebSocket cws = await GetConnectedWebSocket(server, TimeOutMilliseconds, _output)) @@ -495,11 +488,11 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => try { using (var cws = new ClientWebSocket()) - using (var cts = new CancellationTokenSource(TimeOutMilliseconds)) + using (var testTimeoutCts = new CancellationTokenSource(TimeOutMilliseconds)) { - await ConnectAsync(cws, uri, cts.Token); + await ConnectAsync(cws, uri, testTimeoutCts.Token); - Task receiveTask = cws.ReceiveAsync(new byte[1], CancellationToken.None); + Task receiveTask = cws.ReceiveAsync(new byte[1], testTimeoutCts.Token); var cancelCloseCts = new CancellationTokenSource(); await Assert.ThrowsAnyAsync(async () => @@ -509,7 +502,12 @@ await Assert.ThrowsAnyAsync(async () => await t; }); + Assert.True(cancelCloseCts.Token.IsCancellationRequested); + Assert.False(testTimeoutCts.Token.IsCancellationRequested); + await Assert.ThrowsAnyAsync(() => receiveTask); + + Assert.False(testTimeoutCts.Token.IsCancellationRequested); } } finally diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs new file mode 100644 index 0000000000000..08306c0804ee4 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.Loopback.cs @@ -0,0 +1,144 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Channels; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] + public abstract class KeepAliveTest_Loopback : ClientWebSocketTestBase + { + public KeepAliveTest_Loopback(ITestOutputHelper output) : base(output) { } + + protected virtual Version HttpVersion => Net.HttpVersion.Version11; + + [OuterLoop("Uses Task.Delay")] + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public Task KeepAlive_LongDelayBetweenSendReceives_Succeeds(bool useSsl) + { + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var longDelayByServerTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + TimeSpan LongDelay = TimeSpan.FromSeconds(10); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + var options = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()) + { + DisposeServerWebSocket = true, + DisposeClientWebSocket = true, + ConfigureClientOptions = clientOptions => + { + clientOptions.KeepAliveInterval = TimeSpan.FromMilliseconds(100); + clientOptions.KeepAliveTimeout = TimeSpan.FromSeconds(1); + } + }; + + return LoopbackWebSocketServer.RunAsync( + async (cws, token) => + { + ReadAheadWebSocket clientWebSocket = new(cws); + + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); + + await longDelayByServerTcs.Task.WaitAsync(token).ConfigureAwait(false); + Assert.Equal(WebSocketState.Open, clientWebSocket.State); + + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); + Assert.Equal(WebSocketState.Open, clientWebSocket.State); + + await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); + Assert.Equal(WebSocketState.Closed, clientWebSocket.State); + }, + async (sws, token) => + { + ReadAheadWebSocket serverWebSocket = new(sws); + + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token).ConfigureAwait(false); + Assert.Equal(WebSocketState.Open, serverWebSocket.State); + + await Task.Delay(LongDelay); + Assert.Equal(WebSocketState.Open, serverWebSocket.State); + + // recreate already-completed TCS for another round + clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + longDelayByServerTcs.SetResult(); + + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token).ConfigureAwait(false); + + var closeFrame = await serverWebSocket.ReceiveAsync(Array.Empty(), token).ConfigureAwait(false); + Assert.Equal(WebSocketMessageType.Close, closeFrame.MessageType); + Assert.Equal(WebSocketState.CloseReceived, serverWebSocket.State); + + await serverWebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", token).ConfigureAwait(false); + Assert.Equal(WebSocketState.Closed, serverWebSocket.State); + }, + options, + timeoutCts.Token); + } + + private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, + TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken) + { + var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken); + + var recvBuf = new byte[remoteMsg.Length * 2]; + var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false); + + Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType); + Assert.Equal(remoteMsg.Length, recvResult.Count); + Assert.True(recvResult.EndOfMessage); + Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]); + + localAckTcs.SetResult(); + + await sendTask.ConfigureAwait(false); + await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false); + } + } + + // --- HTTP/1.1 WebSocket loopback tests --- + + public class KeepAliveTest_Invoker_Loopback : KeepAliveTest_Loopback + { + public KeepAliveTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseCustomInvoker => true; + } + + public class KeepAliveTest_HttpClient_Loopback : KeepAliveTest_Loopback + { + public KeepAliveTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseHttpClient => true; + } + + public class KeepAliveTest_SharedHandler_Loopback : KeepAliveTest_Loopback + { + public KeepAliveTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { } + } + + // --- HTTP/2 WebSocket loopback tests --- + + public class KeepAliveTest_Invoker_Http2 : KeepAliveTest_Invoker_Loopback + { + public KeepAliveTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + } + + public class KeepAliveTest_HttpClient_Http2 : KeepAliveTest_HttpClient_Loopback + { + public KeepAliveTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs index 56acccdc05590..5ff9c51e56a6a 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/KeepAliveTest.cs @@ -9,6 +9,8 @@ using Xunit; using Xunit.Abstractions; +using static System.Net.Test.Common.Configuration.WebSockets; + namespace System.Net.WebSockets.Client.Tests { [SkipOnPlatform(TestPlatforms.Browser, "KeepAlive not supported on browser")] @@ -17,10 +19,10 @@ public class KeepAliveTest : ClientWebSocketTestBase public KeepAliveTest(ITestOutputHelper output) : base(output) { } [ConditionalFact(nameof(WebSocketsSupported))] - [OuterLoop] // involves long delay + [OuterLoop("Uses Task.Delay")] public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() { - using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(System.Net.Test.Common.Configuration.WebSockets.RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1))) + using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket(RemoteEchoServer, TimeOutMilliseconds, _output, TimeSpan.FromSeconds(1))) { await cws.SendAsync(new ArraySegment(new byte[1] { 42 }), WebSocketMessageType.Binary, true, CancellationToken.None); @@ -33,5 +35,35 @@ public async Task KeepAlive_LongDelayBetweenSendReceives_Succeeds() await cws.CloseAsync(WebSocketCloseStatus.NormalClosure, "KeepAlive_LongDelayBetweenSendReceives_Succeeds", CancellationToken.None); } } + + [ConditionalTheory(nameof(WebSocketsSupported))] + [OuterLoop("Uses Task.Delay")] + [InlineData(1, 0)] // unsolicited pong + [InlineData(1, 2)] // ping/pong + public async Task KeepAlive_LongDelayBetweenReceiveSends_Succeeds(int keepAliveIntervalSec, int keepAliveTimeoutSec) + { + using (ClientWebSocket cws = await WebSocketHelper.GetConnectedWebSocket( + RemoteEchoServer, + TimeOutMilliseconds, + _output, + options => + { + options.KeepAliveInterval = TimeSpan.FromSeconds(keepAliveIntervalSec); + options.KeepAliveTimeout = TimeSpan.FromSeconds(keepAliveTimeoutSec); + })) + { + byte[] receiveBuffer = new byte[1]; + var receiveTask = cws.ReceiveAsync(new ArraySegment(receiveBuffer), CancellationToken.None); // this will wait until we trigger the echo server by sending a message + + await Task.Delay(TimeSpan.FromSeconds(10)); + + await cws.SendAsync(new ArraySegment(new byte[1] { 42 }), WebSocketMessageType.Binary, true, CancellationToken.None); + + Assert.Equal(1, (await receiveTask).Count); + Assert.Equal(42, receiveBuffer[0]); + + await cws.CloseAsync(WebSocketCloseStatus.NormalClosure, "KeepAlive_LongDelayBetweenSendReceives_Succeeds", CancellationToken.None); + } + } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs index 1b3b51840ec99..b841eead6ea24 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -15,6 +15,7 @@ public class Http2LoopbackStream : Stream private readonly int _streamId; private bool _readEnded; private ReadOnlyMemory _leftoverReadData; + private bool _sendResetOnDispose; public override bool CanRead => true; public override bool CanSeek => false; @@ -23,10 +24,11 @@ public class Http2LoopbackStream : Stream public Http2LoopbackConnection Connection => _connection; public int StreamId => _streamId; - public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId) + public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId, bool sendResetOnDispose = true) { _connection = connection; _streamId = streamId; + _sendResetOnDispose = sendResetOnDispose; } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) @@ -71,7 +73,7 @@ public override async ValueTask DisposeAsync() { await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); - if (!_readEnded) + if (!_readEnded && _sendResetOnDispose) { var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId); await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false); diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs index b24e2e20d40df..ec53020184802 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -73,7 +73,7 @@ await server.AcceptConnectionAsync(async connection => await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); - await http2Connection.DisposeAsync().ConfigureAwait(false); + await http2Connection.ShutdownIgnoringErrorsAsync(http2StreamId).ConfigureAwait(false); }, new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); } @@ -132,6 +132,8 @@ public static async Task GetConnectedClientAsync(Uri uri, Optio clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; }; } + options.ConfigureClientOptions?.Invoke(clientWebSocket.Options); + await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false); return clientWebSocket; @@ -143,6 +145,7 @@ public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker public bool DisposeClientWebSocket { get; set; } public bool DisposeHttpInvoker { get; set; } public bool ManualServerHandshakeResponse { get; set; } + public Action? ConfigureClientOptions { get; set; } } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs new file mode 100644 index 0000000000000..af98d76580cf2 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/ReadAheadWebSocket.cs @@ -0,0 +1,122 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Generic; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Channels; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests; + +internal class ReadAheadWebSocket : WebSocket +{ + private const int ReadAheadBufferSize = 64 * 1024 * 1024; + + private record struct DataFrame(ValueWebSocketReceiveResult Metadata, Memory Memory, byte[] _rented); + + private Channel _incomingFrames = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleReader = true, SingleWriter = true }); + private DataFrame? _currentFrame; + + private SemaphoreSlim receiveMutex = new SemaphoreSlim(1, 1); + private readonly WebSocket _innerWebSocket; + + public ReadAheadWebSocket(WebSocket innerWebSocket) + { + _innerWebSocket = innerWebSocket; + _ = ProcessIncomingFrames(); + } + + private async Task ProcessIncomingFrames() + { + var buffer = new byte[ReadAheadBufferSize]; + while (true) + { + try + { + ValueWebSocketReceiveResult result = await _innerWebSocket.ReceiveAsync((Memory)buffer, default).ConfigureAwait(false); + + byte[] rented = result.Count > 0 ? ArrayPool.Shared.Rent(result.Count) : Array.Empty(); + Memory message = rented.AsMemory(0, result.Count); + buffer.AsMemory(0, result.Count).CopyTo(message); + + await _incomingFrames.Writer.WriteAsync(new DataFrame(result, message, rented), default).ConfigureAwait(false); + + if (result.MessageType == WebSocketMessageType.Close) + { + _incomingFrames.Writer.Complete(); + break; + } + } + catch (Exception e) + { + _incomingFrames.Writer.Complete(e); + break; + } + } + } + + public override async ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + { + await receiveMutex.WaitAsync(cancellationToken).ConfigureAwait(false); + + try + { + _currentFrame ??= await _incomingFrames.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); + + var (result, message, rented) = _currentFrame.Value; + + if (buffer.Length < result.Count) + { + message.Slice(0, buffer.Length).CopyTo(buffer); + var remaining = message.Slice(buffer.Length); + _currentFrame = _currentFrame.Value with { Metadata = new (remaining.Length, result.MessageType, result.EndOfMessage), Memory = remaining }; + + return new (buffer.Length, result.MessageType, endOfMessage: false); + } + else + { + message.CopyTo(buffer); + if (rented.Length > 0) + { + ArrayPool.Shared.Return(rented); + } + _currentFrame = null; + return result; + } + } + finally + { + receiveMutex.Release(); + } + } + + public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + ValueWebSocketReceiveResult valueResult = await ReceiveAsync((Memory)buffer, cancellationToken).ConfigureAwait(false); + var result = new WebSocketReceiveResult( + valueResult.Count, + valueResult.MessageType, + valueResult.EndOfMessage, + valueResult.MessageType == WebSocketMessageType.Close ? CloseStatus : null, + valueResult.MessageType == WebSocketMessageType.Close ? CloseStatusDescription : null); + return result; + } + + public override WebSocketCloseStatus? CloseStatus => _innerWebSocket.CloseStatus; + public override string? CloseStatusDescription => _innerWebSocket.CloseStatusDescription; + public override string? SubProtocol => _innerWebSocket.SubProtocol; + public override WebSocketState State => _innerWebSocket.State; + public override void Abort() => _innerWebSocket.Abort(); + public override void Dispose() => _innerWebSocket.Dispose(); + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) => _innerWebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) => _innerWebSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => _innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) => _innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) => _innerWebSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken); +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs index 2a8c84e7de8ea..06e62d4a17e48 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs @@ -83,7 +83,7 @@ public static async Task ProcessHttp2RequestAsync(Http2Loo await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false); } - data.WebSocketStream = new Http2LoopbackStream(connection, streamId); + data.WebSocketStream = new Http2LoopbackStream(connection, streamId, sendResetOnDispose: false); return data; } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 0c07922eb10ec..af3355cd701c6 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -58,9 +58,11 @@ + + diff --git a/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs index d409007f9995d..df29e843590e9 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/WebSocketHelper.cs @@ -69,19 +69,35 @@ public static Task GetConnectedWebSocket( ITestOutputHelper output, TimeSpan keepAliveInterval = default, IWebProxy proxy = null, + HttpMessageInvoker? invoker = null) => + GetConnectedWebSocket( + server, + timeOutMilliseconds, + output, + options => + { + if (proxy != null) + { + options.Proxy = proxy; + } + if (keepAliveInterval.TotalSeconds > 0) + { + options.KeepAliveInterval = keepAliveInterval; + } + }, + invoker + ); + + public static Task GetConnectedWebSocket( + Uri server, + int timeOutMilliseconds, + ITestOutputHelper output, + Action configureOptions, HttpMessageInvoker? invoker = null) => Retry(output, async () => { var cws = new ClientWebSocket(); - if (proxy != null) - { - cws.Options.Proxy = proxy; - } - - if (keepAliveInterval.TotalSeconds > 0) - { - cws.Options.KeepAliveInterval = keepAliveInterval; - } + configureOptions(cws.Options); using (var cts = new CancellationTokenSource(timeOutMilliseconds)) { diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index e3d230708b3f2..ae5337ec05385 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -141,6 +141,7 @@ public sealed partial class WebSocketCreationOptions public bool IsServer { get { throw null; } set { } } public string? SubProtocol { get { throw null; } set { } } public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } + public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } } public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } } } public sealed partial class WebSocketDeflateOptions diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index 8e01fce49ad88..a57e81b239a92 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -138,6 +138,9 @@ The argument must be a value between {0} and {1}. + + The WebSocket didn't recieve a Pong frame in response to a Ping frame within the configured KeepAliveTimeout. + The WebSocket received a continuation frame with Per-Message Compressed flag set. diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 98ace5cfbf038..177e95dacee0a 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -17,6 +17,8 @@ + + @@ -29,6 +31,8 @@ + + + @@ -57,6 +65,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs index 4191466dd4efa..abf7e5e56a276 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Net; using System.Threading.Tasks; namespace System.Threading @@ -65,6 +66,8 @@ public Task EnterAsync(CancellationToken cancellationToken) Task Contended(CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate); + var w = new Waiter(this); // We need to register for cancellation before storing the waiter into the list. @@ -185,6 +188,8 @@ public void Exit() void Contended() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate); + Waiter? w; lock (SyncObj) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs new file mode 100644 index 0000000000000..c9ff393cb7180 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs @@ -0,0 +1,306 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal sealed partial class ManagedWebSocket : WebSocket + { + private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null; + private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1; + private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1; + + private void HeartBeat() + { + if (IsUnsolicitedPongKeepAlive) + { + UnsolicitedPongHeartBeat(); + } + else + { + KeepAlivePingHeartBeat(); + } + } + + private void UnsolicitedPongHeartBeat() + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + Observe( + TrySendKeepAliveFrameAsync(MessageOpcode.Pong)); + } + + private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory? payload = null) + { + Debug.Assert(opcode is MessageOpcode.Pong || !IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping); + + if (!IsValidSendState(_state)) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Cannot send keep-alive frame in {nameof(_state)}={_state}"); + + // we can't send any frames, but no need to throw as we are not observing errors anyway + return ValueTask.CompletedTask; + } + + payload ??= ReadOnlyMemory.Empty; + + return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload.Value, CancellationToken.None); + } + + private void KeepAlivePingHeartBeat() + { + Debug.Assert(_keepAlivePingState != null); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + bool shouldSendPing = false; + long pingPayload = -1; + + try + { + lock (StateUpdateLock) + { + if (_keepAlivePingState.Exception is not null) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"KeepAlive already faulted, skipping... (exception: {_keepAlivePingState.Exception.Message})"); + return; + } + + long now = Environment.TickCount64; + + if (_keepAlivePingState.PingSent) + { + if (Environment.TickCount64 > _keepAlivePingState.PingTimeoutTimestamp) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {_keepAlivePingState.PingPayload}"); + } + + Exception exc = ExceptionDispatchInfo.SetCurrentStackTrace( + new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout)); + + _keepAlivePingState.OnKeepAliveFaultedCore(exc); // we are holding the lock + return; + } + } + else + { + if (Environment.TickCount64 > _keepAlivePingState.NextPingRequestTimestamp) + { + _keepAlivePingState.OnNextPingRequestCore(); // we are holding the lock + shouldSendPing = true; + pingPayload = _keepAlivePingState.PingPayload; + } + } + } + + if (shouldSendPing) + { + Observe( + SendPingAsync(pingPayload)); + } + } + catch (Exception e) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, e); + + _keepAlivePingState.OnKeepAliveFaulted(e); + } + } + + private async ValueTask SendPingAsync(long pingPayload) + { + Debug.Assert(_keepAlivePingState != null); + + byte[] pingPayloadBuffer = ArrayPool.Shared.Rent(sizeof(long)); + BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload); + try + { + await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer.AsMemory(0, sizeof(long))).ConfigureAwait(false); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload); + } + finally + { + ArrayPool.Shared.Return(pingPayloadBuffer); + } + } + + // "Observe" either a ValueTask result, or any exception, ignoring it + // to prevent the unobserved exception event from being raised. + private void Observe(ValueTask t) + { + if (t.IsCompletedSuccessfully) + { + t.GetAwaiter().GetResult(); + } + else + { + Observe(t.AsTask()); + } + } + + // "Observe" any exception, ignoring it to prevent the unobserved task + // exception event from being raised. + private void Observe(Task t) + { + if (t.IsCompleted) + { + if (t.IsFaulted) + { + LogFaulted(t, this); + } + } + else + { + t.ContinueWith( + LogFaulted, + this, + CancellationToken.None, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + } + + static void LogFaulted(Task task, object? thisObj) + { + Debug.Assert(task.IsFaulted); + + // accessing exception to observe it regardless of whether the tracing is enabled + Exception e = task.Exception!.InnerException!; + + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e); + } + } + + private sealed class KeepAlivePingState + { + internal const int PingPayloadSize = sizeof(long); + private const int MinIntervalMs = 1; + + private readonly ManagedWebSocket _parent; + private object StateUpdateLock => _parent.StateUpdateLock; + + internal int DelayMs { get; } + internal int TimeoutMs { get; } + internal int HeartBeatIntervalMs => Math.Max(Math.Min(DelayMs, TimeoutMs) / 4, MinIntervalMs); + + internal long PingPayload { get; private set; } + internal bool PingSent { get; private set; } + internal long PingTimeoutTimestamp { get; private set; } + internal long NextPingRequestTimestamp { get; private set; } + internal Exception? Exception { get; private set; } + + public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout, ManagedWebSocket parent) + { + DelayMs = TimeSpanToMs(keepAliveInterval); + TimeoutMs = TimeSpanToMs(keepAliveTimeout); + NextPingRequestTimestamp = Environment.TickCount64 + DelayMs; + PingTimeoutTimestamp = Timeout.Infinite; + _parent = parent; + + static int TimeSpanToMs(TimeSpan value) => (int)Math.Clamp((long)value.TotalMilliseconds, MinIntervalMs, int.MaxValue); + } + + internal void OnDataReceived() + { + lock (StateUpdateLock) + { + NextPingRequestTimestamp = Environment.TickCount64 + DelayMs; + } + } + + internal void OnPongResponseReceived(long pongPayload) + { + lock (StateUpdateLock) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"pongPayload={pongPayload}"); + + if (!PingSent) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Not waiting for Pong. Skipping."); + return; + } + + if (pongPayload == PingPayload) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayload); + + PingTimeoutTimestamp = long.MaxValue; + PingSent = false; + } + else + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Expected payload {PingPayload}. Skipping."); + } + } + } + + internal void OnNextPingRequestCore() + { + Debug.Assert(Monitor.IsEntered(StateUpdateLock)); + + PingSent = true; + PingTimeoutTimestamp = Environment.TickCount64 + TimeoutMs; + ++PingPayload; + } + + internal void OnKeepAliveFaulted(Exception exc) + { + lock (StateUpdateLock) + { + OnKeepAliveFaultedCore(exc); + } + } + + internal void OnKeepAliveFaultedCore(Exception exc) + { + Debug.Assert(Monitor.IsEntered(StateUpdateLock)); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc); + + if (_parent._disposed) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket already disposed, skipping..."); + return; + } + + if (_parent.State is WebSocketState.Closed) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already closed, skipping..."); + // We've transferred into the Closed state, but didn't dispose yet + // This can happen in e.g. HandleReceivedCloseAsync where we first change the state + // but then still do some operations with the stream. + // No need to do anything as we've already completed the Closing Handshake + return; + } + + if (_parent.State is WebSocketState.Aborted) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already aborted, skipping..."); + // Something else already aborted the websocket, but didn't dispose it (yet?)? + // This can happen either + // (1) in the Abort() method, e.g. on cancellation, if we interjected between the state + // change and the Dispose() call; or + // (2) in the catch block of ReceiveAsyncPrivate (which doesn't do the dispose after??). + // This most possibly happens if we've hit a premature EOF from the server. + // Websocket is not usable in the Aborted state anyway, so let's free the resources while we're at it? + _parent.Dispose(); + return; + } + + // we were the ones who triggered the abort, let's save the exception + Exception = exc; + + _parent.OnAbortedCore(); + _parent.DisposeCore(); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 3ee864e71cc8c..8a26a4c29e2eb 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -8,6 +8,7 @@ using System.Net.WebSockets.Compression; using System.Numerics; using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; @@ -137,23 +138,34 @@ internal sealed partial class ManagedWebSocket : WebSocket private readonly WebSocketInflater? _inflater; private readonly WebSocketDeflater? _deflater; + private readonly KeepAlivePingState? _keepAlivePingState; + /// Initializes the websocket. /// The connected Stream. /// true if this is the server-side of the connection; false if this is the client-side of the connection. /// The agreed upon subprotocol for the connection. /// The interval to use for keep-alive pings. - internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + /// The timeout to use when waiting for keep-alive pong response. + internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(stream != null, $"Expected non-null {nameof(stream)}"); Debug.Assert(stream.CanRead, $"Expected readable {nameof(stream)}"); Debug.Assert(stream.CanWrite, $"Expected writeable {nameof(stream)}"); Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid {nameof(keepAliveInterval)}: {keepAliveInterval}"); + Debug.Assert(keepAliveTimeout == Timeout.InfiniteTimeSpan || keepAliveTimeout >= TimeSpan.Zero, $"Invalid {nameof(keepAliveTimeout)}: {keepAliveTimeout}"); _stream = stream; _isServer = isServer; _subprotocol = subprotocol; + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Associate(this, stream); + NetEventSource.Associate(this, _sendMutex); + NetEventSource.Associate(this, _receiveMutex); + } + // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer // supplied to ReceiveAsync. @@ -165,14 +177,33 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim // that could keep the web socket rooted in erroneous cases. if (keepAliveInterval > TimeSpan.Zero) { + long heartBeatIntervalMs = (long)keepAliveInterval.TotalMilliseconds; + if (keepAliveTimeout > TimeSpan.Zero) + { + _keepAlivePingState = new KeepAlivePingState(keepAliveInterval, keepAliveTimeout, this); + heartBeatIntervalMs = _keepAlivePingState.HeartBeatIntervalMs; + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Associate(this, _keepAlivePingState); + + NetEventSource.Trace(this, + $"Enabling Ping/Pong Keep-Alive strategy: ping delay={_keepAlivePingState.DelayMs}ms, timeout={_keepAlivePingState.TimeoutMs}ms, heartbeat={heartBeatIntervalMs}ms"); + } + } + else if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Enabling Unsolicited Pong Keep-Alive strategy: heartbeat={heartBeatIntervalMs}ms"); + } + _keepAliveTimer = new Timer(static s => { var wr = (WeakReference)s!; if (wr.TryGetTarget(out ManagedWebSocket? thisRef)) { - thisRef.SendKeepAliveFrameAsync(); + thisRef.HeartBeat(); } - }, new WeakReference(this), keepAliveInterval, keepAliveInterval); + }, new WeakReference(this), heartBeatIntervalMs, heartBeatIntervalMs); } } @@ -180,7 +211,7 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim /// The connected Stream. /// The options with which the websocket must be created. internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) - : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval) + : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval, options.KeepAliveTimeout) { var deflateOptions = options.DangerousDeflateOptions; @@ -201,6 +232,8 @@ internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options) public override void Dispose() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + lock (StateUpdateLock) { DisposeCore(); @@ -210,17 +243,23 @@ public override void Dispose() private void DisposeCore() { Debug.Assert(Monitor.IsEntered(StateUpdateLock), $"Expected {nameof(StateUpdateLock)} to be held"); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"{nameof(_disposed)}={_disposed}"); + if (!_disposed) { _disposed = true; _keepAliveTimer?.Dispose(); _stream.Dispose(); - if (_state < WebSocketState.Aborted) + WebSocketState state = _state; + if (state < WebSocketState.Aborted) { _state = WebSocketState.Closed; } + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); + DisposeSafe(_inflater, _receiveMutex); DisposeSafe(_deflater, _sendMutex); } @@ -234,15 +273,23 @@ private static void DisposeSafe(IDisposable? resource, AsyncMutex mutex) if (lockTask.IsCompleted) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(mutex); + resource.Dispose(); mutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(mutex); } else { lockTask.GetAwaiter().UnsafeOnCompleted(() => { + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(mutex); + resource.Dispose(); mutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(mutex); }); } } @@ -258,6 +305,8 @@ private static void DisposeSafe(IDisposable? resource, AsyncMutex mutex) public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { throw new ArgumentException(SR.Format( @@ -276,6 +325,8 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) { throw new ArgumentException(SR.Format( @@ -286,10 +337,11 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates); + ThrowIfInvalidState(s_validSendStates); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return ValueTask.FromException(exc); } @@ -319,44 +371,53 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates); + ThrowIfInvalidState(s_validReceiveStates); return ReceiveAsyncPrivate(buffer, cancellationToken).AsTask(); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return Task.FromException(exc); } } public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates); + ThrowIfInvalidState(s_validReceiveStates); return ReceiveAsyncPrivate(buffer, cancellationToken); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return ValueTask.FromException(exc); } } public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription); try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseStates); + ThrowIfInvalidState(s_validCloseStates); } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); return Task.FromException(exc); } @@ -365,13 +426,17 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription); return CloseOutputAsyncCore(closeStatus, statusDescription, cancellationToken); } private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseOutputStates); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + + ThrowIfInvalidState(s_validCloseOutputStates); await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); @@ -388,22 +453,35 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string public override void Abort() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + OnAborted(); Dispose(); // forcibly tear down connection } private void OnAborted() { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + lock (StateUpdateLock) { - WebSocketState state = _state; - if (state != WebSocketState.Closed && state != WebSocketState.Aborted) - { - _state = state != WebSocketState.None && state != WebSocketState.Connecting ? - WebSocketState.Aborted : - WebSocketState.Closed; - } + OnAbortedCore(); + } + } + + private void OnAbortedCore() + { + Debug.Assert(Monitor.IsEntered(StateUpdateLock), $"Expected {nameof(StateUpdateLock)} to be held"); + + WebSocketState state = _state; + if (state is not WebSocketState.Closed and not WebSocketState.Aborted) + { + _state = state is not WebSocketState.None and not WebSocketState.Connecting ? + WebSocketState.Aborted : + WebSocketState.Closed; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } /// Sends a websocket frame to the network. @@ -414,6 +492,8 @@ private void OnAborted() /// The CancellationToken to use to cancel the websocket. private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.SendFrameAsyncStarted(this, opcode.ToString(), payloadBuffer.Length); + // If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to // pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path. // Similarly, it should be rare that there are multiple outstanding calls to SendFrameAsync, but if there are, again @@ -433,6 +513,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, { Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_sendMutex); + // If we get here, the cancellation token is not cancelable so we don't have to worry about it, // and we own the semaphore, so we don't need to asynchronously wait for it. ValueTask writeTask = default; @@ -468,6 +550,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, } catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); + return ValueTask.FromException( exc is OperationCanceledException ? exc : _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc) : @@ -479,6 +563,12 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, { ReleaseSendBuffer(); _sendMutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.MutexExited(_sendMutex); + NetEventSource.SendFrameAsyncCompleted(this); + } } } @@ -495,8 +585,15 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl await _stream.FlushAsync().ConfigureAwait(false); } } - catch (Exception exc) when (exc is not OperationCanceledException) + catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); + + if (exc is OperationCanceledException) + { + throw; + } + throw _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc) : new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc); @@ -505,12 +602,20 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl { ReleaseSendBuffer(); _sendMutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.MutexExited(_sendMutex); + NetEventSource.SendFrameAsyncCompleted(this); + } } } private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, Task lockTask, CancellationToken cancellationToken) { await lockTask.ConfigureAwait(false); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_sendMutex); + try { int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span); @@ -520,8 +625,15 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); } } - catch (Exception exc) when (exc is not OperationCanceledException) + catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); + + if (exc is OperationCanceledException) + { + throw; + } + throw _state == WebSocketState.Aborted ? CreateOperationCanceledException(exc, cancellationToken) : new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc); @@ -530,13 +642,19 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM { ReleaseSendBuffer(); _sendMutex.Exit(); + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.MutexExited(_sendMutex); + NetEventSource.SendFrameAsyncCompleted(this); + } } } /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer) { - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + ThrowIfDisposed(); if (_deflater is not null && !disableCompression) { @@ -585,26 +703,6 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool return headerLength + payloadLength; } - private void SendKeepAliveFrameAsync() - { - // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures. - // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by - // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response. - ValueTask t = SendFrameAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory.Empty, CancellationToken.None); - if (t.IsCompletedSuccessfully) - { - t.GetAwaiter().GetResult(); - } - else - { - // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised. - t.AsTask().ContinueWith(static p => { _ = p.Exception; }, - CancellationToken.None, - TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Default); - } - } - private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask, bool compressed) { // Client header format: @@ -697,13 +795,19 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo // those to be much less frequent (e.g. we should only get one close per websocket), and thus we can afford to pay // a bit more for readability and maintainability. - CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this); + if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateStarted(this, payloadBuffer.Length); + + CancellationTokenRegistration registration = default; try { + registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this); + await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex); + try { - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + ThrowIfDisposed(); while (true) // in case we get control frames that should be ignored from the user's perspective { @@ -715,6 +819,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo MessageHeader header = _lastReceiveHeader; if (header.Processed) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Reading the next frame header"); + if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength))) { // Make sure we have the first two bytes, which includes the start of the payload length. @@ -758,6 +864,11 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo } _receivedMaskOffsetOffset = 0; + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Trace(this, $"Next frame opcode={header.Opcode}, fin={header.Fin}, compressed={header.Compressed}, payloadLength={header.PayloadLength}"); + } + if (header.PayloadLength == 0 && header.Compressed) { // In the rare case where we receive a compressed message with no payload @@ -841,10 +952,14 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo int numBytesRead = await _stream.ReadAtLeastAsync( readBuffer, bytesToRead, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={numBytesRead}"); + if (numBytesRead < bytesToRead) { ThrowEOFUnexpected(); } + _keepAlivePingState?.OnDataReceived(); totalBytesReceived += numBytesRead; } @@ -882,6 +997,11 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } + if (header.Processed) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Data frame fully processed"); + } + _lastReceiveHeader = header; return GetReceiveResult( totalBytesReceived, @@ -892,13 +1012,30 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo finally { _receiveMutex.Exit(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(_receiveMutex); } } - catch (Exception exc) when (exc is not OperationCanceledException) + catch (Exception exc) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); + + if (exc is OperationCanceledException) + { + throw; + } + if (_state == WebSocketState.Aborted) { - throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc); + Exception inner = exc; + if (_keepAlivePingState?.Exception is not null) + { + // exception was most likely caused by us aborting the connection due to + // keep-alive timeout; but let's surface both just in case + inner = ExceptionDispatchInfo.SetCurrentStackTrace( + new AggregateException(_keepAlivePingState.Exception, exc)); + } + + throw new OperationCanceledException(nameof(WebSocketState.Aborted), inner); } OnAborted(); @@ -912,6 +1049,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo finally { registration.Dispose(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateCompleted(this); } } @@ -941,14 +1079,17 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat lock (StateUpdateLock) { _receivedCloseFrame = true; - if (_sentCloseFrame && _state < WebSocketState.Closed) + WebSocketState state = _state; + if (_sentCloseFrame && state < WebSocketState.Closed) { _state = WebSocketState.Closed; } - else if (_state < WebSocketState.CloseReceived) + else if (state < WebSocketState.CloseReceived) { _state = WebSocketState.CloseReceived; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } WebSocketCloseStatus closeStatus = WebSocketCloseStatus.NormalClosure; @@ -1005,6 +1146,8 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat /// Issues a read on the stream to wait for EOF. private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + // Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second. // We simply issue a read and don't care what we get back; we could validate that we don't get // additional data, but at this point we're about to close the connection and we're just stalling @@ -1036,19 +1179,30 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca /// The CancellationToken used to cancel the websocket operation. private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, CancellationToken cancellationToken) { + Debug.Assert(_receiveMutex.IsHeld, $"Caller should hold the {nameof(_receiveMutex)}"); + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this); + // Consume any (optional) payload associated with the ping/pong. if (header.PayloadLength > 0 && _receiveBufferCount < header.PayloadLength) { await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false); } + bool processPing = header.Opcode == MessageOpcode.Ping; + + bool processPong = header.Opcode == MessageOpcode.Pong && _keepAlivePingState is not null + && header.PayloadLength == KeepAlivePingState.PingPayloadSize; + + if ((processPing || processPong) && _isServer) + { + ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); + } + // If this was a ping, send back a pong response. - if (header.Opcode == MessageOpcode.Ping) + if (processPing) { - if (_isServer) - { - ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); - } + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Ping"); await SendFrameAsync( MessageOpcode.Pong, @@ -1057,6 +1211,20 @@ await SendFrameAsync( _receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), cancellationToken).ConfigureAwait(false); } + else if (processPong) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Pong"); + + long pongPayload = BinaryPrimitives.ReadInt64BigEndian(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength)); + lock (StateUpdateLock) + { + _keepAlivePingState!.OnPongResponseReceived(pongPayload); + } + } + else + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Received Unsolicited Pong. Skipping."); + } // Regardless of whether it was a ping or pong, we no longer need the payload. if (header.PayloadLength > 0) @@ -1115,6 +1283,8 @@ private static bool IsValidCloseStatus(WebSocketCloseStatus closeStatus) private async ValueTask CloseWithReceiveErrorAndThrowAsync( WebSocketCloseStatus closeStatus, WebSocketError error, string? errorMessage = null, Exception? innerException = null) { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, errorMessage); + // Close the connection if it hasn't already been closed if (!_sentCloseFrame) { @@ -1258,79 +1428,98 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( /// The CancellationToken to use to cancel the websocket. private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) { - // Send the close message. Skip sending a close frame if we're currently in a CloseSent state, - // for example having just done a CloseOutputAsync. - if (!_sentCloseFrame) + if (NetEventSource.Log.IsEnabled()) NetEventSource.CloseAsyncPrivateStarted(this); + try { - await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); - } - - // We should now either be in a CloseSent case (because we just sent one), or in a Closed state, in case - // there was a concurrent receive that ended up handling an immediate close frame response from the server. - // Of course it could also be Aborted if something happened concurrently to cause things to blow up. - Debug.Assert( - State == WebSocketState.CloseSent || - State == WebSocketState.Closed || - State == WebSocketState.Aborted, - $"Unexpected state {State}."); + // Send the close message. Skip sending a close frame if we're currently in a CloseSent state, + // for example having just done a CloseOutputAsync. + if (!_sentCloseFrame) + { + await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); + } - // We only need to wait for a received close frame if we are in the CloseSent State. If we are in the Closed - // State then it means we already received a close frame. If we are in the Aborted State, then we should not - // wait for a close frame as per RFC 6455 Section 7.1.7 "Fail the WebSocket Connection". - if (State == WebSocketState.CloseSent) - { - // Wait until we've received a close response - byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); - try + // We should now either be in a CloseSent case (because we just sent one), or in a Closed state, in case + // there was a concurrent receive that ended up handling an immediate close frame response from the server. + // Of course it could also be Aborted if something happened concurrently to cause things to blow up. + Debug.Assert( + State == WebSocketState.CloseSent || + State == WebSocketState.Closed || + State == WebSocketState.Aborted, + $"Unexpected state {State}."); + + // We only need to wait for a received close frame if we are in the CloseSent State. If we are in the Closed + // State then it means we already received a close frame. If we are in the Aborted State, then we should not + // wait for a close frame as per RFC 6455 Section 7.1.7 "Fail the WebSocket Connection". + if (State == WebSocketState.CloseSent) { - // Loop until we've received a close frame. - while (!_receivedCloseFrame) + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Waiting for a close frame"); + + // Wait until we've received a close response + byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); + try { - // Enter the receive lock in order to get a consistent view of whether we've received a close - // frame. If we haven't, issue a receive. Since that receive will try to take the same - // non-entrant receive lock, we then exit the lock before waiting for the receive to complete, - // as it will always complete asynchronously and only after we've exited the lock. - ValueTask receiveTask = default; - try + // Loop until we've received a close frame. + while (!_receivedCloseFrame) { - await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); + // Enter the receive lock in order to get a consistent view of whether we've received a close + // frame. If we haven't, issue a receive. Since that receive will try to take the same + // non-entrant receive lock, we then exit the lock before waiting for the receive to complete, + // as it will always complete asynchronously and only after we've exited the lock. + ValueTask receiveTask = default; try { - if (!_receivedCloseFrame) + await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex); + + try { - receiveTask = ReceiveAsyncPrivate(closeBuffer, cancellationToken); + if (!_receivedCloseFrame) + { + receiveTask = ReceiveAsyncPrivate(closeBuffer, cancellationToken); + } + } + finally + { + _receiveMutex.Exit(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(_receiveMutex); } } - finally + catch (OperationCanceledException) { - _receiveMutex.Exit(); + // If waiting on the receive lock was canceled, abort the connection, as we would do + // as part of the receive itself. + Abort(); + throw; } - } - catch (OperationCanceledException) - { - // If waiting on the receive lock was canceled, abort the connection, as we would do - // as part of the receive itself. - Abort(); - throw; - } - // Wait for the receive to complete if we issued one. - await receiveTask.ConfigureAwait(false); + // Wait for the receive to complete if we issued one. + await receiveTask.ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(closeBuffer); } } - finally + + // We're closed. Close the connection and update the status. + lock (StateUpdateLock) { - ArrayPool.Shared.Return(closeBuffer); + DisposeCore(); } } - - // We're closed. Close the connection and update the status. - lock (StateUpdateLock) + catch (Exception exc) { - DisposeCore(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc); + throw; + } + finally + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.CloseAsyncPrivateCompleted(this); } } + /// Sends a close message to the server. /// The close status to send. /// The close status description to send. @@ -1370,14 +1559,17 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st lock (StateUpdateLock) { _sentCloseFrame = true; - if (_receivedCloseFrame && _state < WebSocketState.Closed) + WebSocketState state = _state; + if (_receivedCloseFrame && state < WebSocketState.Closed) { _state = WebSocketState.Closed; } - else if (_state < WebSocketState.CloseSent) + else if (state < WebSocketState.CloseSent) { _state = WebSocketState.CloseSent; } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}"); } if (!_isServer && _receivedCloseFrame) @@ -1417,10 +1609,13 @@ private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, Canc _receiveBuffer.Slice(_receiveBufferCount), bytesToRead, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false); _receiveBufferCount += numRead; + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={numRead}"); + if (numRead < bytesToRead) { ThrowEOFUnexpected(); } + _keepAlivePingState?.OnDataReceived(); } } } @@ -1430,7 +1625,7 @@ private void ThrowEOFUnexpected() // The connection closed before we were able to read everything we needed. // If it was due to us being disposed, fail with the correct exception. // Otherwise, it was due to the connection being closed and it wasn't expected. - ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket)); + ThrowIfDisposed(); throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); } @@ -1542,6 +1737,30 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa cancellationToken); } + private void ThrowIfDisposed() => ThrowIfInvalidState(); + + private void ThrowIfInvalidState(WebSocketState[]? validStates = null) + { + bool disposed = _disposed; + WebSocketState state = _state; + Exception? keepAliveException = null; + + if (_keepAlivePingState is not null) + { + // we need to take a lock to maintain consistency + lock (StateUpdateLock) + { + disposed = _disposed; + state = _state; + keepAliveException = _keepAlivePingState.Exception; + } + } + + if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_state={state}, _disposed={disposed}, _keepAlivePingState.Exception={keepAliveException}"); + + WebSocketValidate.ThrowIfInvalidState(state, disposed, keepAliveException, validStates); + } + // From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75 // Performs a stateful validation of UTF-8 bytes. // It checks for valid formatting, overlong encodings, surrogates, and value ranges. diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs new file mode 100644 index 0000000000000..d0977fb767060 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs @@ -0,0 +1,286 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; + +namespace System.Net +{ + [EventSource(Name = "Private.InternalDiagnostics.System.Net.WebSockets")] + internal sealed partial class NetEventSource + { + // NOTE + // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They + // enable creating 'activities'. + // For more information, take a look at the following blog post: + // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/ + // - A stop event's event id must be next one after its start event. + + private const int KeepAliveSentId = NextAvailableEventId; + private const int KeepAliveAckedId = KeepAliveSentId + 1; + + private const int WsTraceId = KeepAliveAckedId + 1; + + private const int CloseStartId = WsTraceId + 1; + private const int CloseStopId = CloseStartId + 1; + + private const int ReceiveStartId = CloseStopId + 1; + private const int ReceiveStopId = ReceiveStartId + 1; + + private const int SendStartId = ReceiveStopId + 1; + private const int SendStopId = SendStartId + 1; + + private const int MutexEnterId = SendStopId + 1; + private const int MutexExitId = MutexEnterId + 1; + private const int MutexContendedId = MutexExitId + 1; + + // + // Keep-Alive + // + + private const string Ping = "Ping"; + private const string Pong = "Pong"; + + [Event(KeepAliveSentId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void KeepAliveSent(string objName, string opcode, long payload) => + WriteEvent(KeepAliveSentId, objName, opcode, payload); + + [Event(KeepAliveAckedId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void KeepAliveAcked(string objName, long payload) => + WriteEvent(KeepAliveAckedId, objName, payload); + + [NonEvent] + public static void KeepAlivePingSent(object? obj, long payload) + { + Debug.Assert(Log.IsEnabled()); + Log.KeepAliveSent(IdOf(obj), Ping, payload); + } + + [NonEvent] + public static void UnsolicitedPongSent(object? obj) + { + Debug.Assert(Log.IsEnabled()); + Log.KeepAliveSent(IdOf(obj), Pong, 0); + } + + [NonEvent] + public static void PongResponseReceived(object? obj, long payload) + { + Debug.Assert(Log.IsEnabled()); + Log.KeepAliveAcked(IdOf(obj), payload); + } + + // + // Debug Messages + // + + [Event(WsTraceId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void WsTrace(string objName, string memberName, string message) => + WriteEvent(WsTraceId, objName, memberName, message); + + [NonEvent] + public static void TraceErrorMsg(object? obj, Exception exception, [CallerMemberName] string? memberName = null) + => Trace(obj, $"{exception.GetType().Name}: {exception.Message}", memberName); + + [NonEvent] + public static void TraceException(object? obj, Exception exception, [CallerMemberName] string? memberName = null) + => Trace(obj, exception.ToString(), memberName); + + [NonEvent] + public static void Trace(object? obj, string? message = null, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.WsTrace(IdOf(obj), memberName ?? MissingMember, message ?? memberName ?? string.Empty); + } + + // + // Close + // + + [Event(CloseStartId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void CloseStart(string objName, string memberName) => + WriteEvent(CloseStartId, objName, memberName); + + [Event(CloseStopId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void CloseStop(string objName, string memberName) => + WriteEvent(CloseStopId, objName, memberName); + + [NonEvent] + public static void CloseAsyncPrivateStarted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.CloseStart(IdOf(obj), memberName ?? MissingMember); + } + + [NonEvent] + public static void CloseAsyncPrivateCompleted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.CloseStop(IdOf(obj), memberName ?? MissingMember); + } + + // + // ReceiveAsyncPrivate + // + + [Event(ReceiveStartId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void ReceiveStart(string objName, string memberName, int bufferLength) => + WriteEvent(ReceiveStartId, objName, memberName, bufferLength); + + [Event(ReceiveStopId, Keywords = Keywords.Debug, Level = EventLevel.Informational)] + private void ReceiveStop(string objName, string memberName) => + WriteEvent(ReceiveStopId, objName, memberName); + + [NonEvent] + public static void ReceiveAsyncPrivateStarted(object? obj, int bufferLength, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.ReceiveStart(IdOf(obj), memberName ?? MissingMember, bufferLength); + } + + [NonEvent] + public static void ReceiveAsyncPrivateCompleted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.ReceiveStop(IdOf(obj), memberName ?? MissingMember); + } + + // + // SendFrameAsync + // + + [Event(SendStartId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void SendStart(string objName, string memberName, string opcode, int bufferLength) => + WriteEvent(SendStartId, objName, memberName, opcode, bufferLength); + + [Event(SendStopId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void SendStop(string objName, string memberName) => + WriteEvent(SendStopId, objName, memberName); + + [NonEvent] + public static void SendFrameAsyncStarted(object? obj, string opcode, int bufferLength, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.SendStart(IdOf(obj), memberName ?? MissingMember, opcode, bufferLength); + } + + [NonEvent] + public static void SendFrameAsyncCompleted(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.SendStop(IdOf(obj), memberName ?? MissingMember); + } + + // + // AsyncMutex + // + + [Event(MutexEnterId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void MutexEnter(string objName, string memberName) => + WriteEvent(MutexEnterId, objName, memberName); + + [Event(MutexExitId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void MutexExit(string objName, string memberName) => + WriteEvent(MutexExitId, objName, memberName); + + [Event(MutexContendedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)] + private void MutexContended(string objName, string memberName, int queueLength) => + WriteEvent(MutexContendedId, objName, memberName, queueLength); + + [NonEvent] + public static void MutexEntered(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.MutexEnter(IdOf(obj), memberName ?? MissingMember); + } + + [NonEvent] + public static void MutexExited(object? obj, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.MutexExit(IdOf(obj), memberName ?? MissingMember); + } + + [NonEvent] + public static void MutexContended(object? obj, int gateValue, [CallerMemberName] string? memberName = null) + { + Debug.Assert(Log.IsEnabled()); + Log.MutexContended(IdOf(obj), memberName ?? MissingMember, -gateValue); + } + + // + // WriteEvent overloads + // + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern", + Justification = EventSourceSuppressMessage)] + [NonEvent] + private unsafe void WriteEvent(int eventId, string arg1, string arg2, long arg3) + { + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + { + const int NumEventDatas = 3; + EventData* descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(arg2Ptr), + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(&arg3), + Size = sizeof(long) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern", + Justification = EventSourceSuppressMessage)] + [NonEvent] + private unsafe void WriteEvent(int eventId, string arg1, string arg2, string arg3, int arg4) + { + fixed (char* arg1Ptr = arg1) + fixed (char* arg2Ptr = arg2) + fixed (char* arg3Ptr = arg3) + { + const int NumEventDatas = 4; + EventData* descrs = stackalloc EventData[NumEventDatas]; + + descrs[0] = new EventData + { + DataPointer = (IntPtr)(arg1Ptr), + Size = (arg1.Length + 1) * sizeof(char) + }; + descrs[1] = new EventData + { + DataPointer = (IntPtr)(arg2Ptr), + Size = (arg2.Length + 1) * sizeof(char) + }; + descrs[2] = new EventData + { + DataPointer = (IntPtr)(arg3Ptr), + Size = (arg3.Length + 1) * sizeof(char) + }; + descrs[3] = new EventData + { + DataPointer = (IntPtr)(&arg4), + Size = sizeof(int) + }; + + WriteEventCore(eventId, NumEventDatas, descrs); + } + } + + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index fc4436926a6f0..ce47b894d32c7 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -84,7 +84,7 @@ private async ValueTask SendWithArrayPoolAsync( public static TimeSpan DefaultKeepAliveInterval { // In the .NET Framework, this pulls the value from a P/Invoke. Here we just hardcode it to a reasonable default. - get { return TimeSpan.FromSeconds(30); } + get { return WebSocketDefaults.DefaultClientKeepAliveInterval; } } protected static void ThrowOnInvalidState(WebSocketState state, params WebSocketState[] validStates) @@ -150,7 +150,7 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s 0)); } - return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval); + return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval, WebSocketDefaults.DefaultKeepAliveTimeout); } /// Creates a that operates on a representing a web socket connection. @@ -209,7 +209,7 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval); + return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval, WebSocketDefaults.DefaultKeepAliveTimeout); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index d042583da5444..dfc74241379f8 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -11,7 +11,8 @@ namespace System.Net.WebSockets public sealed class WebSocketCreationOptions { private string? _subProtocol; - private TimeSpan _keepAliveInterval; + private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultKeepAliveInterval; + private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout; /// /// Defines if this websocket is the server-side of the connection. The default value is false. @@ -36,6 +37,8 @@ public string? SubProtocol /// /// The keep-alive interval to use, or or to disable keep-alives. + /// If is set, then PING messages are sent and peer's PONG responses are expected, otherwise, + /// unsolicited PONG messages are used as a keep-alive heartbeat. /// The default is . /// public TimeSpan KeepAliveInterval @@ -52,6 +55,25 @@ public TimeSpan KeepAliveInterval } } + /// + /// The timeout to use when waiting for the peer's PONG in response to us sending a PING; or or + /// to disable waiting for peer's response, and use an unsolicited PONG as a Keep-Alive heartbeat instead. + /// The default is . + /// + public TimeSpan KeepAliveTimeout + { + get => _keepAliveTimeout; + set + { + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(KeepAliveTimeout), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + } + _keepAliveTimeout = value; + } + } + /// /// The agreed upon options for per message deflate. /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks. diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 807da709ea755..a7f09ff31db29 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,6 +1,7 @@ $(NetCoreAppCurrent) + ../src/Resources/Strings.resx @@ -9,6 +10,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs index 86d1dfb2cd530..423c8d40ed5d5 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs @@ -85,5 +85,30 @@ public async Task ReceiveAsync_ValidCloseStatus_Success(WebSocketCloseStatus clo Assert.Equal(closeStatusDescription, closing.CloseStatusDescription); } } + + [Fact] + public async Task CloseAsync_CancelableEvenWhenPendingReceive_Throws() + { + using var stream = new WebSocketTestStream(); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + + Task receiveTask = websocket.ReceiveAsync(new byte[1], CancellationToken); + await Task.Delay(100); // give the receive task a chance to aquire the lock + var cancelCloseCts = new CancellationTokenSource(); + await Assert.ThrowsAnyAsync(async () => + { + Task t = websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancelCloseCts.Token); + await Task.Delay(100); // give the close task time to get in the queue waiting for the lock + cancelCloseCts.Cancel(); + await t; + }); + + Assert.True(cancelCloseCts.Token.IsCancellationRequested); + Assert.False(CancellationToken.IsCancellationRequested); + + await Assert.ThrowsAnyAsync(() => receiveTask); + + Assert.False(CancellationToken.IsCancellationRequested); + } } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs new file mode 100644 index 0000000000000..11dd28117ac5d --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs @@ -0,0 +1,280 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Binary; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + public class WebSocketKeepAliveTests + { + public static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10); + public static readonly TimeSpan KeepAliveInterval = TimeSpan.FromMilliseconds(100); + public static readonly TimeSpan KeepAliveTimeout = TimeSpan.FromSeconds(1); + public const int FramesToTestCount = 5; + +#region Frame format helper constants + + public const int MinHeaderLength = 2; + public const int MaskLength = 4; + public const int SingleInt64PayloadLength = sizeof(long); + public const int PingPayloadLength = SingleInt64PayloadLength; + + // 0b_1_***_**** -- fin=true + public const byte FirstByteBits_FinFlag = 0b_1_000_0000; + + // 0b_*_***_0010 -- opcode=BINARY (0x02) + public const byte FirstByteBits_OpcodeBinary = 0b_0_000_0010; + + // 0b_*_***_1001 -- opcode=PING (0x09) + public const byte FirstByteBits_OpcodePing = 0b_0_000_1001; + + // 0b_*_***_1010 -- opcode=PONG (0x10) + public const byte FirstByteBits_OpcodePong = 0b_0_000_1010; + + // 0b_1_******* -- mask=true + public const byte SecondByteBits_MaskFlag = 0b_1_0000000; + + // 0b_*_0001000 -- length=8 + public const byte SecondByteBits_PayloadLength8 = SingleInt64PayloadLength; + + public const byte FirstByte_PingFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodePing; + public const byte FirstByte_PongFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodePong; + public const byte FirstByte_DataFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodeBinary; + + public const byte SecondByte_Server_NoPayload = 0; + public const byte SecondByte_Client_NoPayload = SecondByteBits_MaskFlag; + + public const byte SecondByte_Server_8bPayload = SecondByteBits_PayloadLength8; + public const byte SecondByte_Client_8bPayload = SecondByteBits_MaskFlag | SecondByteBits_PayloadLength8; + + public const int Server_FrameHeaderLength = MinHeaderLength; + public const int Client_FrameHeaderLength = MinHeaderLength + MaskLength; + +#endregion + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WebSocket_NoUserReadOrWrite_SendsUnsolicitedPong(bool isServer) + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream localEndpointStream = testStream; + Stream remoteEndpointStream = testStream.Remote; + + using WebSocket webSocket = WebSocket.CreateFromStream(localEndpointStream, new WebSocketCreationOptions + { + IsServer = isServer, + KeepAliveInterval = KeepAliveInterval + }); + + // --- "remote endpoint" side --- + + int pongFrameLength = isServer ? Server_FrameHeaderLength : Client_FrameHeaderLength; + var pongBuffer = new byte[pongFrameLength]; + for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pongs "indefinitely", let's check a few + { + await remoteEndpointStream.ReadExactlyAsync(pongBuffer, cancellationToken); + + Assert.Equal(FirstByte_PongFrame, pongBuffer[0]); + Assert.Equal( + isServer ? SecondByte_Server_NoPayload : SecondByte_Client_NoPayload, + pongBuffer[1]); + } + } + + [Fact] + public async Task WebSocketServer_NoUserReadOrWrite_SendsPingAndReadsPongResponse() + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream serverStream = testStream; + Stream clientStream = testStream.Remote; + + using WebSocket webSocketServer = WebSocket.CreateFromStream(serverStream, new WebSocketCreationOptions + { + IsServer = true, + KeepAliveInterval = KeepAliveInterval, + KeepAliveTimeout = TestTimeout // we don't care about the actual timeout here + }); + + // we need an outstanding read to ensure the client receives pongs + var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var serverReadTask = webSocketServer.ReceiveAsync(Memory.Empty, readCts.Token); + + // --- "client" side --- + + var buffer = new byte[Client_FrameHeaderLength + PingPayloadLength]; // client frame is bigger because of masking + + for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pings "indefinitely", let's check a few + { + Assert.Equal(WebSocketState.Open, webSocketServer.State); + Assert.False(serverReadTask.IsCompleted); + + buffer.AsSpan().Clear(); + await clientStream.ReadExactlyAsync( + buffer.AsMemory(0, Server_FrameHeaderLength + PingPayloadLength), + cancellationToken); + + Assert.Equal(FirstByte_PingFrame, buffer[0]); + + // implementation detail: payload is a long counter starting from 1 + Assert.Equal(SecondByte_Server_8bPayload, buffer[1]); + + var payloadBytes = buffer.AsSpan().Slice(Server_FrameHeaderLength, PingPayloadLength); + long pingCounter = BinaryPrimitives.ReadInt64BigEndian(payloadBytes); + + Assert.Equal(i+1, pingCounter); + + // --- sending pong back --- + + buffer[0] = FirstByte_PongFrame; + buffer[1] = SecondByte_Client_8bPayload; + + // using zeroes as a "mask" -- applying such a mask is a no-op + Array.Clear(buffer, MinHeaderLength, MaskLength); + + // sending the same payload back + BinaryPrimitives.WriteInt64BigEndian(buffer.AsSpan().Slice(Client_FrameHeaderLength), pingCounter); + + await clientStream.WriteAsync(buffer, cancellationToken); + } + + Assert.Equal(WebSocketState.Open, webSocketServer.State); + Assert.False(serverReadTask.IsCompleted); + + readCts.Cancel(); + + await Assert.ThrowsAsync(() => serverReadTask.AsTask()); + Assert.Equal(WebSocketState.Aborted, webSocketServer.State); + } + + [Fact] + public async Task WebSocketClient_NoServerDataSent_SendsPingAndReadsPongResponse() + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream clientStream = testStream; + Stream serverStream = testStream.Remote; + + using WebSocket webSocketClient = WebSocket.CreateFromStream(clientStream, new WebSocketCreationOptions + { + IsServer = false, + KeepAliveInterval = KeepAliveInterval, + KeepAliveTimeout = TestTimeout // we don't care about the actual timeout here + }); + + // we need an outstanding read to ensure the client receives pongs + var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var clientReadTask = webSocketClient.ReceiveAsync(Memory.Empty, readCts.Token); + + + // --- "server" side --- + + var buffer = new byte[Client_FrameHeaderLength + PingPayloadLength]; // client frame is bigger because of masking + + for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pings "indefinitely", let's check a few + { + Assert.Equal(WebSocketState.Open, webSocketClient.State); + Assert.False(clientReadTask.IsCompleted); + + buffer.AsSpan().Clear(); + await serverStream.ReadExactlyAsync(buffer, cancellationToken); + + Assert.Equal(FirstByte_PingFrame, buffer[0]); + + // implementation detail: payload is a long counter starting from 1 + Assert.Equal(SecondByte_Client_8bPayload, buffer[1]); + + var payloadBytes = buffer.AsSpan().Slice(Client_FrameHeaderLength, PingPayloadLength); + ApplyMask(payloadBytes, buffer.AsSpan().Slice(Client_FrameHeaderLength - MaskLength, MaskLength)); + long pingCounter = BinaryPrimitives.ReadInt64BigEndian(payloadBytes); + Assert.Equal(i+1, pingCounter); + + // --- sending pong back --- + + buffer[0] = FirstByte_PongFrame; + buffer[1] = SecondByte_Server_8bPayload; + + // sending the same payload back + BinaryPrimitives.WriteInt64BigEndian(buffer.AsSpan().Slice(Server_FrameHeaderLength), pingCounter); + + await serverStream.WriteAsync( + buffer.AsMemory(0, Server_FrameHeaderLength + PingPayloadLength), + cancellationToken); + } + + Assert.Equal(WebSocketState.Open, webSocketClient.State); + Assert.False(clientReadTask.IsCompleted); + + readCts.Cancel(); + + await Assert.ThrowsAsync(() => clientReadTask.AsTask()); + Assert.Equal(WebSocketState.Aborted, webSocketClient.State); + + // Octet i of the transformed data ("transformed-octet-i") is the XOR of + // octet i of the original data ("original-octet-i") with octet at index + // i modulo 4 of the masking key ("masking-key-octet-j"): + // + // j = i MOD 4 + // transformed-octet-i = original-octet-i XOR masking-key-octet-j + // + static void ApplyMask(Span buffer, Span mask) + { + + for (int i = 0; i < buffer.Length; i++) + { + buffer[i] ^= mask[i % MaskLength]; + } + } + } + + [OuterLoop("Uses Task.Delay")] + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WebSocket_NoPongResponseWithinTimeout_Aborted(bool outstandingUserRead) + { + var cancellationToken = new CancellationTokenSource(TestTimeout).Token; + + using WebSocketTestStream testStream = new(); + Stream localEndpointStream = testStream; + Stream remoteEndpointStream = testStream.Remote; + + using WebSocket webSocket = WebSocket.CreateFromStream(localEndpointStream, new WebSocketCreationOptions + { + IsServer = true, + KeepAliveInterval = KeepAliveInterval, + KeepAliveTimeout = KeepAliveTimeout + }); + + Debug.Assert(webSocket.State == WebSocketState.Open); + + ValueTask userReadTask = default; + if (outstandingUserRead) + { + userReadTask = webSocket.ReceiveAsync(Memory.Empty, cancellationToken); + } + + await Task.Delay(2 * (KeepAliveTimeout + KeepAliveInterval), cancellationToken); + + Assert.Equal(WebSocketState.Aborted, webSocket.State); + + Exception readException = outstandingUserRead + ? await Assert.ThrowsAsync(() => userReadTask.AsTask()) + : await Assert.ThrowsAsync(() => webSocket.ReceiveAsync(Memory.Empty, cancellationToken).AsTask()); + + var wse = Assert.IsType(readException.InnerException); + Assert.Equal(WebSocketError.Faulted, wse.WebSocketErrorCode); + Assert.Equal(SR.net_Websockets_KeepAlivePingTimeout, wse.Message); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index 4cf9c279ba5f3..73e84998a9419 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -200,6 +201,26 @@ public async Task ReceiveAsync_WhenDisposedInParallel_DoesNotGetStuck() await Assert.ThrowsAsync(() => r3.WaitAsync(TimeSpan.FromSeconds(1))); } + [Fact] + public async Task ReceiveAsync_AfterCancellationDoReceiveAsync_ThrowsWebSocketException() + { + using var stream = new WebSocketTestStream(); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + var recvBuffer = new byte[100]; + var segment = new ArraySegment(recvBuffer); + var cts = new CancellationTokenSource(); + + Task receive = websocket.ReceiveAsync(segment, cts.Token); + cts.Cancel(); + await Assert.ThrowsAnyAsync(() => receive); + + WebSocketException ex = await Assert.ThrowsAsync(() => + websocket.ReceiveAsync(segment, CancellationToken.None)); + Assert.Equal( + SR.Format(SR.net_WebSockets_InvalidState, "Aborted", "Open, CloseSent"), + ex.Message); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) =>