diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs
new file mode 100644
index 0000000000000..8f2c768fd4bc9
--- /dev/null
+++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketDefaults.cs
@@ -0,0 +1,19 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading;
+
+namespace System.Net.WebSockets
+{
+ ///
+ /// Central repository for default values used in WebSocket settings. Not all settings are relevant
+ /// to or configurable by all WebSocket implementations.
+ ///
+ internal static partial class WebSocketDefaults
+ {
+ public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.Zero;
+ public static readonly TimeSpan DefaultClientKeepAliveInterval = TimeSpan.FromSeconds(30);
+
+ public static readonly TimeSpan DefaultKeepAliveTimeout = Timeout.InfiniteTimeSpan;
+ }
+}
diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs
index e087677be4608..c11524bdeef9b 100644
--- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs
+++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs
@@ -38,27 +38,26 @@ internal static partial class WebSocketValidate
SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~");
internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
+ => ThrowIfInvalidState(currentState, isDisposed, innerException: null, validStates ?? []);
+
+ internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, WebSocketState[]? validStates = null)
{
- string validStatesText = string.Empty;
+ if (validStates is not null && Array.IndexOf(validStates, currentState) == -1)
+ {
+ string invalidStateMessage = SR.Format(
+ SR.net_WebSockets_InvalidState, currentState, string.Join(", ", validStates));
+
+ throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
+ }
- if (validStates != null && validStates.Length > 0)
+ if (innerException is not null)
{
- foreach (WebSocketState validState in validStates)
- {
- if (currentState == validState)
- {
- // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
- ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
- return;
- }
- }
-
- validStatesText = string.Join(", ", validStates);
+ Debug.Assert(currentState == WebSocketState.Aborted);
+ throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}
- throw new WebSocketException(
- WebSocketError.InvalidState,
- SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText));
+ // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
+ ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}
internal static void ValidateSubprotocol(string subProtocol)
diff --git a/src/libraries/Common/tests/TestUtilities/TestEventListener.cs b/src/libraries/Common/tests/TestUtilities/TestEventListener.cs
index 8cb70ee3cbd8c..b83a165906164 100644
--- a/src/libraries/Common/tests/TestUtilities/TestEventListener.cs
+++ b/src/libraries/Common/tests/TestUtilities/TestEventListener.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.Diagnostics.Tracing;
using System.IO;
using System.Text;
@@ -31,6 +32,7 @@ public sealed class TestEventListener : EventListener
"Private.InternalDiagnostics.System.Net.Sockets",
"Private.InternalDiagnostics.System.Net.Security",
"Private.InternalDiagnostics.System.Net.Quic",
+ "Private.InternalDiagnostics.System.Net.WebSockets",
"Private.InternalDiagnostics.System.Net.Http.WinHttpHandler",
"Private.InternalDiagnostics.System.Net.HttpListener",
"Private.InternalDiagnostics.System.Net.Mail",
@@ -41,19 +43,24 @@ public sealed class TestEventListener : EventListener
private readonly Action _writeFunc;
private readonly HashSet _sourceNames;
+ private readonly bool _enableActivityId;
// Until https://github.com/dotnet/runtime/issues/63979 is solved.
private List _eventSources = new List();
public TestEventListener(TextWriter output, params string[] sourceNames)
- : this(str => output.WriteLine(str), sourceNames)
+ : this(output.WriteLine, sourceNames)
{ }
public TestEventListener(ITestOutputHelper output, params string[] sourceNames)
- : this(str => output.WriteLine(str), sourceNames)
+ : this(output.WriteLine, sourceNames)
{ }
public TestEventListener(Action writeFunc, params string[] sourceNames)
+ : this(writeFunc, enableActivityId: false, sourceNames)
+ { }
+
+ public TestEventListener(Action writeFunc, bool enableActivityId, params string[] sourceNames)
{
List eventSources = _eventSources;
@@ -61,16 +68,14 @@ public TestEventListener(Action writeFunc, params string[] sourceNames)
{
_writeFunc = writeFunc;
_sourceNames = new HashSet(sourceNames);
+ _enableActivityId = enableActivityId;
_eventSources = null;
}
// eventSources were populated in the base ctor and are now owned by this thread, enable them now.
foreach (EventSource eventSource in eventSources)
{
- if (_sourceNames.Contains(eventSource.Name))
- {
- EnableEvents(eventSource, EventLevel.LogAlways);
- }
+ EnableEventSource(eventSource);
}
}
@@ -90,20 +95,42 @@ protected override void OnEventSourceCreated(EventSource eventSource)
}
// Second pass called after our ctor, allow logging for specified source names.
+ EnableEventSource(eventSource);
+ }
+
+ private void EnableEventSource(EventSource eventSource)
+ {
if (_sourceNames.Contains(eventSource.Name))
{
EnableEvents(eventSource, EventLevel.LogAlways);
}
+ else if (_enableActivityId && eventSource.Name == "System.Threading.Tasks.TplEventSource")
+ {
+ EnableEvents(eventSource, EventLevel.LogAlways, (EventKeywords)0x80 /* TasksFlowActivityIds */);
+ }
}
protected override void OnEventWritten(EventWrittenEventArgs eventData)
{
- StringBuilder sb = new StringBuilder().
+ StringBuilder sb = new StringBuilder();
+
#if NET || NETSTANDARD2_1_OR_GREATER
- Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}[{eventData.EventName}] ");
-#else
- Append($"[{eventData.EventName}] ");
+ sb.Append($"{eventData.TimeStamp:HH:mm:ss.fffffff}");
+ if (_enableActivityId)
+ {
+ if (eventData.ActivityId != Guid.Empty)
+ {
+ string activityId = ActivityHelpers.ActivityPathString(eventData.ActivityId);
+ sb.Append($" {activityId} {new string('-', activityId.Length / 2 - 1 )} ");
+ }
+ else
+ {
+ sb.Append(" / ");
+ }
+ }
#endif
+ sb.Append($"[{eventData.EventName}] ");
+
for (int i = 0; i < eventData.Payload?.Count; i++)
{
if (i > 0)
@@ -116,4 +143,103 @@ protected override void OnEventWritten(EventWrittenEventArgs eventData)
}
catch { }
}
+
+ // From https://gist.github.com/MihaZupan/cc63ee68b4146892f2e5b640ed57bc09
+ private static class ActivityHelpers
+ {
+ private enum NumberListCodes : byte
+ {
+ End = 0x0,
+ LastImmediateValue = 0xA,
+ PrefixCode = 0xB,
+ MultiByte1 = 0xC,
+ }
+
+ public static unsafe bool IsActivityPath(Guid guid)
+ {
+ uint* uintPtr = (uint*)&guid;
+ uint sum = uintPtr[0] + uintPtr[1] + uintPtr[2] + 0x599D99AD;
+ return ((sum & 0xFFF00000) == (uintPtr[3] & 0xFFF00000));
+ }
+
+ public static unsafe string ActivityPathString(Guid guid)
+ => IsActivityPath(guid) ? CreateActivityPathString(guid) : guid.ToString();
+
+ internal static unsafe string CreateActivityPathString(Guid guid)
+ {
+ Debug.Assert(IsActivityPath(guid));
+
+ StringBuilder sb = new StringBuilder();
+
+ byte* bytePtr = (byte*)&guid;
+ byte* endPtr = bytePtr + 12;
+ char separator = '/';
+ while (bytePtr < endPtr)
+ {
+ uint nibble = (uint)(*bytePtr >> 4);
+ bool secondNibble = false;
+ NextNibble:
+ if (nibble == (uint)NumberListCodes.End)
+ {
+ break;
+ }
+ if (nibble <= (uint)NumberListCodes.LastImmediateValue)
+ {
+ sb.Append('/').Append(nibble);
+ if (!secondNibble)
+ {
+ nibble = (uint)(*bytePtr & 0xF);
+ secondNibble = true;
+ goto NextNibble;
+ }
+ bytePtr++;
+ continue;
+ }
+ else if (nibble == (uint)NumberListCodes.PrefixCode)
+ {
+ if (!secondNibble)
+ {
+ nibble = (uint)(*bytePtr & 0xF);
+ }
+ else
+ {
+ bytePtr++;
+ if (endPtr <= bytePtr)
+ {
+ break;
+ }
+ nibble = (uint)(*bytePtr >> 4);
+ }
+ if (nibble < (uint)NumberListCodes.MultiByte1)
+ {
+ return guid.ToString();
+ }
+ separator = '$';
+ }
+ Debug.Assert((uint)NumberListCodes.MultiByte1 <= nibble);
+ uint numBytes = nibble - (uint)NumberListCodes.MultiByte1;
+ uint value = 0;
+ if (!secondNibble)
+ {
+ value = (uint)(*bytePtr & 0xF);
+ }
+ bytePtr++;
+ numBytes++;
+ if (endPtr < bytePtr + numBytes)
+ {
+ break;
+ }
+ for (int i = (int)numBytes - 1; 0 <= i; --i)
+ {
+ value = (value << 8) + bytePtr[i];
+ }
+ sb.Append(separator).Append(value);
+
+ bytePtr += numBytes;
+ }
+
+ sb.Append('/');
+ return sb.ToString();
+ }
+ }
}
diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs
index edb4eb043bcb0..88bf41f23fbce 100644
--- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs
+++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs
@@ -42,6 +42,8 @@ internal ClientWebSocketOptions() { }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.TimeSpan KeepAliveInterval { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
+ public System.TimeSpan KeepAliveTimeout { get { throw null; } set { } }
+ [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.IWebProxy? Proxy { get { throw null; } set { } }
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj
index a309737d6917d..8265edd7e9369 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj
+++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj
@@ -27,6 +27,7 @@
+
@@ -47,6 +48,7 @@
+
@@ -57,7 +59,6 @@
-
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs
index 59096fc864d3a..aa8164d1099c2 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs
@@ -122,6 +122,13 @@ public TimeSpan KeepAliveInterval
set => throw new PlatformNotSupportedException();
}
+ [UnsupportedOSPlatform("browser")]
+ public TimeSpan KeepAliveTimeout
+ {
+ get => throw new PlatformNotSupportedException();
+ set => throw new PlatformNotSupportedException();
+ }
+
[UnsupportedOSPlatform("browser")]
public WebSocketDeflateOptions? DangerousDeflateOptions
{
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs
index f78882f8b005d..dc7155e11cf58 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs
@@ -14,7 +14,8 @@ namespace System.Net.WebSockets
public sealed class ClientWebSocketOptions
{
private bool _isReadOnly; // After ConnectAsync is called the options cannot be modified.
- private TimeSpan _keepAliveInterval = WebSocket.DefaultKeepAliveInterval;
+ private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultClientKeepAliveInterval;
+ private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout;
private bool _useDefaultCredentials;
private ICredentials? _credentials;
private IWebProxy? _proxy;
@@ -171,6 +172,12 @@ public void AddSubProtocol(string subProtocol)
subprotocols.Add(subProtocol);
}
+ ///
+ /// The keep-alive interval to use, or or to disable keep-alives.
+ /// If is set, then PING messages are sent and peer's PONG responses are expected, otherwise,
+ /// unsolicited PONG messages are used as a keep-alive heartbeat.
+ /// The default is (typically 30 seconds).
+ ///
[UnsupportedOSPlatform("browser")]
public TimeSpan KeepAliveInterval
{
@@ -188,6 +195,28 @@ public TimeSpan KeepAliveInterval
}
}
+ ///
+ /// The timeout to use when waiting for the peer's PONG in response to us sending a PING; or or
+ /// to disable waiting for peer's response, and use an unsolicited PONG as a Keep-Alive heartbeat instead.
+ /// The default is .
+ ///
+ [UnsupportedOSPlatform("browser")]
+ public TimeSpan KeepAliveTimeout
+ {
+ get => _keepAliveTimeout;
+ set
+ {
+ ThrowIfReadOnly();
+ if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero)
+ {
+ throw new ArgumentOutOfRangeException(nameof(value), value,
+ SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall,
+ Timeout.InfiniteTimeSpan.ToString()));
+ }
+ _keepAliveTimeout = value;
+ }
+ }
+
///
/// Gets or sets the options for the per-message-deflate extension.
/// When present, the options are sent to the server during the handshake phase. If the server
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs
index 3301bfead64c7..96e091663ff79 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs
@@ -198,6 +198,7 @@ public async Task ConnectAsync(Uri uri, HttpMessageInvoker? invoker, Cancellatio
IsServer = false,
SubProtocol = subprotocol,
KeepAliveInterval = options.KeepAliveInterval,
+ KeepAliveTimeout = options.KeepAliveTimeout,
DangerousDeflateOptions = negotiatedDeflateOptions
});
_negotiatedDeflateOptions = negotiatedDeflateOptions;
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs
index 0aa83697a9de7..8d0a89b320d61 100644
--- a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs
+++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs
@@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
+using System.IO;
+using System.Net.Sockets;
+using System.Net.Test.Common;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
@@ -17,6 +20,8 @@ public AbortTest_Loopback(ITestOutputHelper output) : base(output) { }
protected virtual Version HttpVersion => Net.HttpVersion.Version11;
+ public static object[][] AbortClient_MemberData = ToMemberData(Enum.GetValues(), UseSsl_Values, /* verifySendReceive */ Bool_Values);
+
[Theory]
[MemberData(nameof(AbortClient_MemberData))]
public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive)
@@ -64,6 +69,8 @@ public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool use
timeoutCts.Token);
}
+ public static object[][] ServerPrematureEos_MemberData = ToMemberData(Enum.GetValues(), UseSsl_Values);
+
[Theory]
[MemberData(nameof(ServerPrematureEos_MemberData))]
public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl)
@@ -146,34 +153,6 @@ await SendServerResponseAndEosAsync(
protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func serverFunc, CancellationToken cancellationToken)
=> WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2
- private static readonly bool[] Bool_Values = new[] { false, true };
- private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false };
-
- public static IEnumerable
@@ -57,6 +65,7 @@
+
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs
index 4191466dd4efa..abf7e5e56a276 100644
--- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/AsyncMutex.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
+using System.Net;
using System.Threading.Tasks;
namespace System.Threading
@@ -65,6 +66,8 @@ public Task EnterAsync(CancellationToken cancellationToken)
Task Contended(CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);
+
var w = new Waiter(this);
// We need to register for cancellation before storing the waiter into the list.
@@ -185,6 +188,8 @@ public void Exit()
void Contended()
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);
+
Waiter? w;
lock (SyncObj)
{
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs
new file mode 100644
index 0000000000000..c9ff393cb7180
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs
@@ -0,0 +1,306 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Buffers.Binary;
+using System.Diagnostics;
+using System.Runtime.ExceptionServices;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Net.WebSockets
+{
+ internal sealed partial class ManagedWebSocket : WebSocket
+ {
+ private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null;
+ private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1;
+ private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1;
+
+ private void HeartBeat()
+ {
+ if (IsUnsolicitedPongKeepAlive)
+ {
+ UnsolicitedPongHeartBeat();
+ }
+ else
+ {
+ KeepAlivePingHeartBeat();
+ }
+ }
+
+ private void UnsolicitedPongHeartBeat()
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
+ Observe(
+ TrySendKeepAliveFrameAsync(MessageOpcode.Pong));
+ }
+
+ private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory? payload = null)
+ {
+ Debug.Assert(opcode is MessageOpcode.Pong || !IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping);
+
+ if (!IsValidSendState(_state))
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Cannot send keep-alive frame in {nameof(_state)}={_state}");
+
+ // we can't send any frames, but no need to throw as we are not observing errors anyway
+ return ValueTask.CompletedTask;
+ }
+
+ payload ??= ReadOnlyMemory.Empty;
+
+ return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload.Value, CancellationToken.None);
+ }
+
+ private void KeepAlivePingHeartBeat()
+ {
+ Debug.Assert(_keepAlivePingState != null);
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
+ bool shouldSendPing = false;
+ long pingPayload = -1;
+
+ try
+ {
+ lock (StateUpdateLock)
+ {
+ if (_keepAlivePingState.Exception is not null)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"KeepAlive already faulted, skipping... (exception: {_keepAlivePingState.Exception.Message})");
+ return;
+ }
+
+ long now = Environment.TickCount64;
+
+ if (_keepAlivePingState.PingSent)
+ {
+ if (Environment.TickCount64 > _keepAlivePingState.PingTimeoutTimestamp)
+ {
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.Trace(this, $"Keep-alive ping timed out after {_keepAlivePingState.TimeoutMs}ms. Expected pong with payload {_keepAlivePingState.PingPayload}");
+ }
+
+ Exception exc = ExceptionDispatchInfo.SetCurrentStackTrace(
+ new WebSocketException(WebSocketError.Faulted, SR.net_Websockets_KeepAlivePingTimeout));
+
+ _keepAlivePingState.OnKeepAliveFaultedCore(exc); // we are holding the lock
+ return;
+ }
+ }
+ else
+ {
+ if (Environment.TickCount64 > _keepAlivePingState.NextPingRequestTimestamp)
+ {
+ _keepAlivePingState.OnNextPingRequestCore(); // we are holding the lock
+ shouldSendPing = true;
+ pingPayload = _keepAlivePingState.PingPayload;
+ }
+ }
+ }
+
+ if (shouldSendPing)
+ {
+ Observe(
+ SendPingAsync(pingPayload));
+ }
+ }
+ catch (Exception e)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, e);
+
+ _keepAlivePingState.OnKeepAliveFaulted(e);
+ }
+ }
+
+ private async ValueTask SendPingAsync(long pingPayload)
+ {
+ Debug.Assert(_keepAlivePingState != null);
+
+ byte[] pingPayloadBuffer = ArrayPool.Shared.Rent(sizeof(long));
+ BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload);
+ try
+ {
+ await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer.AsMemory(0, sizeof(long))).ConfigureAwait(false);
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(pingPayloadBuffer);
+ }
+ }
+
+ // "Observe" either a ValueTask result, or any exception, ignoring it
+ // to prevent the unobserved exception event from being raised.
+ private void Observe(ValueTask t)
+ {
+ if (t.IsCompletedSuccessfully)
+ {
+ t.GetAwaiter().GetResult();
+ }
+ else
+ {
+ Observe(t.AsTask());
+ }
+ }
+
+ // "Observe" any exception, ignoring it to prevent the unobserved task
+ // exception event from being raised.
+ private void Observe(Task t)
+ {
+ if (t.IsCompleted)
+ {
+ if (t.IsFaulted)
+ {
+ LogFaulted(t, this);
+ }
+ }
+ else
+ {
+ t.ContinueWith(
+ LogFaulted,
+ this,
+ CancellationToken.None,
+ TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
+ TaskScheduler.Default);
+ }
+
+ static void LogFaulted(Task task, object? thisObj)
+ {
+ Debug.Assert(task.IsFaulted);
+
+ // accessing exception to observe it regardless of whether the tracing is enabled
+ Exception e = task.Exception!.InnerException!;
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e);
+ }
+ }
+
+ private sealed class KeepAlivePingState
+ {
+ internal const int PingPayloadSize = sizeof(long);
+ private const int MinIntervalMs = 1;
+
+ private readonly ManagedWebSocket _parent;
+ private object StateUpdateLock => _parent.StateUpdateLock;
+
+ internal int DelayMs { get; }
+ internal int TimeoutMs { get; }
+ internal int HeartBeatIntervalMs => Math.Max(Math.Min(DelayMs, TimeoutMs) / 4, MinIntervalMs);
+
+ internal long PingPayload { get; private set; }
+ internal bool PingSent { get; private set; }
+ internal long PingTimeoutTimestamp { get; private set; }
+ internal long NextPingRequestTimestamp { get; private set; }
+ internal Exception? Exception { get; private set; }
+
+ public KeepAlivePingState(TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout, ManagedWebSocket parent)
+ {
+ DelayMs = TimeSpanToMs(keepAliveInterval);
+ TimeoutMs = TimeSpanToMs(keepAliveTimeout);
+ NextPingRequestTimestamp = Environment.TickCount64 + DelayMs;
+ PingTimeoutTimestamp = Timeout.Infinite;
+ _parent = parent;
+
+ static int TimeSpanToMs(TimeSpan value) => (int)Math.Clamp((long)value.TotalMilliseconds, MinIntervalMs, int.MaxValue);
+ }
+
+ internal void OnDataReceived()
+ {
+ lock (StateUpdateLock)
+ {
+ NextPingRequestTimestamp = Environment.TickCount64 + DelayMs;
+ }
+ }
+
+ internal void OnPongResponseReceived(long pongPayload)
+ {
+ lock (StateUpdateLock)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"pongPayload={pongPayload}");
+
+ if (!PingSent)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Not waiting for Pong. Skipping.");
+ return;
+ }
+
+ if (pongPayload == PingPayload)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.PongResponseReceived(this, pongPayload);
+
+ PingTimeoutTimestamp = long.MaxValue;
+ PingSent = false;
+ }
+ else
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Expected payload {PingPayload}. Skipping.");
+ }
+ }
+ }
+
+ internal void OnNextPingRequestCore()
+ {
+ Debug.Assert(Monitor.IsEntered(StateUpdateLock));
+
+ PingSent = true;
+ PingTimeoutTimestamp = Environment.TickCount64 + TimeoutMs;
+ ++PingPayload;
+ }
+
+ internal void OnKeepAliveFaulted(Exception exc)
+ {
+ lock (StateUpdateLock)
+ {
+ OnKeepAliveFaultedCore(exc);
+ }
+ }
+
+ internal void OnKeepAliveFaultedCore(Exception exc)
+ {
+ Debug.Assert(Monitor.IsEntered(StateUpdateLock));
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceErrorMsg(this, exc);
+
+ if (_parent._disposed)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket already disposed, skipping...");
+ return;
+ }
+
+ if (_parent.State is WebSocketState.Closed)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already closed, skipping...");
+ // We've transferred into the Closed state, but didn't dispose yet
+ // This can happen in e.g. HandleReceivedCloseAsync where we first change the state
+ // but then still do some operations with the stream.
+ // No need to do anything as we've already completed the Closing Handshake
+ return;
+ }
+
+ if (_parent.State is WebSocketState.Aborted)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"WebSocket is already aborted, skipping...");
+ // Something else already aborted the websocket, but didn't dispose it (yet?)?
+ // This can happen either
+ // (1) in the Abort() method, e.g. on cancellation, if we interjected between the state
+ // change and the Dispose() call; or
+ // (2) in the catch block of ReceiveAsyncPrivate (which doesn't do the dispose after??).
+ // This most possibly happens if we've hit a premature EOF from the server.
+ // Websocket is not usable in the Aborted state anyway, so let's free the resources while we're at it?
+ _parent.Dispose();
+ return;
+ }
+
+ // we were the ones who triggered the abort, let's save the exception
+ Exception = exc;
+
+ _parent.OnAbortedCore();
+ _parent.DisposeCore();
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs
index 3ee864e71cc8c..8a26a4c29e2eb 100644
--- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs
@@ -8,6 +8,7 @@
using System.Net.WebSockets.Compression;
using System.Numerics;
using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text;
@@ -137,23 +138,34 @@ internal sealed partial class ManagedWebSocket : WebSocket
private readonly WebSocketInflater? _inflater;
private readonly WebSocketDeflater? _deflater;
+ private readonly KeepAlivePingState? _keepAlivePingState;
+
/// Initializes the websocket.
/// The connected Stream.
/// true if this is the server-side of the connection; false if this is the client-side of the connection.
/// The agreed upon subprotocol for the connection.
/// The interval to use for keep-alive pings.
- internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval)
+ /// The timeout to use when waiting for keep-alive pong response.
+ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval, TimeSpan keepAliveTimeout)
{
Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null");
Debug.Assert(stream != null, $"Expected non-null {nameof(stream)}");
Debug.Assert(stream.CanRead, $"Expected readable {nameof(stream)}");
Debug.Assert(stream.CanWrite, $"Expected writeable {nameof(stream)}");
Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid {nameof(keepAliveInterval)}: {keepAliveInterval}");
+ Debug.Assert(keepAliveTimeout == Timeout.InfiniteTimeSpan || keepAliveTimeout >= TimeSpan.Zero, $"Invalid {nameof(keepAliveTimeout)}: {keepAliveTimeout}");
_stream = stream;
_isServer = isServer;
_subprotocol = subprotocol;
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.Associate(this, stream);
+ NetEventSource.Associate(this, _sendMutex);
+ NetEventSource.Associate(this, _receiveMutex);
+ }
+
// Create a buffer just large enough to handle received packet headers (at most 14 bytes) and
// control payloads (at most 125 bytes). Message payloads are read directly into the buffer
// supplied to ReceiveAsync.
@@ -165,14 +177,33 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim
// that could keep the web socket rooted in erroneous cases.
if (keepAliveInterval > TimeSpan.Zero)
{
+ long heartBeatIntervalMs = (long)keepAliveInterval.TotalMilliseconds;
+ if (keepAliveTimeout > TimeSpan.Zero)
+ {
+ _keepAlivePingState = new KeepAlivePingState(keepAliveInterval, keepAliveTimeout, this);
+ heartBeatIntervalMs = _keepAlivePingState.HeartBeatIntervalMs;
+
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.Associate(this, _keepAlivePingState);
+
+ NetEventSource.Trace(this,
+ $"Enabling Ping/Pong Keep-Alive strategy: ping delay={_keepAlivePingState.DelayMs}ms, timeout={_keepAlivePingState.TimeoutMs}ms, heartbeat={heartBeatIntervalMs}ms");
+ }
+ }
+ else if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.Trace(this, $"Enabling Unsolicited Pong Keep-Alive strategy: heartbeat={heartBeatIntervalMs}ms");
+ }
+
_keepAliveTimer = new Timer(static s =>
{
var wr = (WeakReference)s!;
if (wr.TryGetTarget(out ManagedWebSocket? thisRef))
{
- thisRef.SendKeepAliveFrameAsync();
+ thisRef.HeartBeat();
}
- }, new WeakReference(this), keepAliveInterval, keepAliveInterval);
+ }, new WeakReference(this), heartBeatIntervalMs, heartBeatIntervalMs);
}
}
@@ -180,7 +211,7 @@ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Tim
/// The connected Stream.
/// The options with which the websocket must be created.
internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options)
- : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval)
+ : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval, options.KeepAliveTimeout)
{
var deflateOptions = options.DangerousDeflateOptions;
@@ -201,6 +232,8 @@ internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options)
public override void Dispose()
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
lock (StateUpdateLock)
{
DisposeCore();
@@ -210,17 +243,23 @@ public override void Dispose()
private void DisposeCore()
{
Debug.Assert(Monitor.IsEntered(StateUpdateLock), $"Expected {nameof(StateUpdateLock)} to be held");
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"{nameof(_disposed)}={_disposed}");
+
if (!_disposed)
{
_disposed = true;
_keepAliveTimer?.Dispose();
_stream.Dispose();
- if (_state < WebSocketState.Aborted)
+ WebSocketState state = _state;
+ if (state < WebSocketState.Aborted)
{
_state = WebSocketState.Closed;
}
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}");
+
DisposeSafe(_inflater, _receiveMutex);
DisposeSafe(_deflater, _sendMutex);
}
@@ -234,15 +273,23 @@ private static void DisposeSafe(IDisposable? resource, AsyncMutex mutex)
if (lockTask.IsCompleted)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(mutex);
+
resource.Dispose();
mutex.Exit();
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(mutex);
}
else
{
lockTask.GetAwaiter().UnsafeOnCompleted(() =>
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(mutex);
+
resource.Dispose();
mutex.Exit();
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(mutex);
});
}
}
@@ -258,6 +305,8 @@ private static void DisposeSafe(IDisposable? resource, AsyncMutex mutex)
public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary)
{
throw new ArgumentException(SR.Format(
@@ -276,6 +325,8 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag
public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary)
{
throw new ArgumentException(SR.Format(
@@ -286,10 +337,11 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag
try
{
- WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates);
+ ThrowIfInvalidState(s_validSendStates);
}
catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
return ValueTask.FromException(exc);
}
@@ -319,44 +371,53 @@ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessag
public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer));
try
{
- WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
+ ThrowIfInvalidState(s_validReceiveStates);
return ReceiveAsyncPrivate(buffer, cancellationToken).AsTask();
}
catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
return Task.FromException(exc);
}
}
public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
try
{
- WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validReceiveStates);
+ ThrowIfInvalidState(s_validReceiveStates);
return ReceiveAsyncPrivate(buffer, cancellationToken);
}
catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
return ValueTask.FromException(exc);
}
}
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription);
try
{
- WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseStates);
+ ThrowIfInvalidState(s_validCloseStates);
}
catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
return Task.FromException(exc);
}
@@ -365,13 +426,17 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
WebSocketValidate.ValidateCloseStatus(closeStatus, statusDescription);
return CloseOutputAsyncCore(closeStatus, statusDescription, cancellationToken);
}
private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
- WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validCloseOutputStates);
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
+ ThrowIfInvalidState(s_validCloseOutputStates);
await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);
@@ -388,22 +453,35 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string
public override void Abort()
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
OnAborted();
Dispose(); // forcibly tear down connection
}
private void OnAborted()
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
lock (StateUpdateLock)
{
- WebSocketState state = _state;
- if (state != WebSocketState.Closed && state != WebSocketState.Aborted)
- {
- _state = state != WebSocketState.None && state != WebSocketState.Connecting ?
- WebSocketState.Aborted :
- WebSocketState.Closed;
- }
+ OnAbortedCore();
+ }
+ }
+
+ private void OnAbortedCore()
+ {
+ Debug.Assert(Monitor.IsEntered(StateUpdateLock), $"Expected {nameof(StateUpdateLock)} to be held");
+
+ WebSocketState state = _state;
+ if (state is not WebSocketState.Closed and not WebSocketState.Aborted)
+ {
+ _state = state is not WebSocketState.None and not WebSocketState.Connecting ?
+ WebSocketState.Aborted :
+ WebSocketState.Closed;
}
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}");
}
/// Sends a websocket frame to the network.
@@ -414,6 +492,8 @@ private void OnAborted()
/// The CancellationToken to use to cancel the websocket.
private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.SendFrameAsyncStarted(this, opcode.ToString(), payloadBuffer.Length);
+
// If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to
// pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path.
// Similarly, it should be rare that there are multiple outstanding calls to SendFrameAsync, but if there are, again
@@ -433,6 +513,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode,
{
Debug.Assert(_sendMutex.IsHeld, $"Caller should hold the {nameof(_sendMutex)}");
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_sendMutex);
+
// If we get here, the cancellation token is not cancelable so we don't have to worry about it,
// and we own the semaphore, so we don't need to asynchronously wait for it.
ValueTask writeTask = default;
@@ -468,6 +550,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode,
}
catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
+
return ValueTask.FromException(
exc is OperationCanceledException ? exc :
_state == WebSocketState.Aborted ? CreateOperationCanceledException(exc) :
@@ -479,6 +563,12 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode,
{
ReleaseSendBuffer();
_sendMutex.Exit();
+
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.MutexExited(_sendMutex);
+ NetEventSource.SendFrameAsyncCompleted(this);
+ }
}
}
@@ -495,8 +585,15 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl
await _stream.FlushAsync().ConfigureAwait(false);
}
}
- catch (Exception exc) when (exc is not OperationCanceledException)
+ catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
+
+ if (exc is OperationCanceledException)
+ {
+ throw;
+ }
+
throw _state == WebSocketState.Aborted ?
CreateOperationCanceledException(exc) :
new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc);
@@ -505,12 +602,20 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask, bool shouldFl
{
ReleaseSendBuffer();
_sendMutex.Exit();
+
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.MutexExited(_sendMutex);
+ NetEventSource.SendFrameAsyncCompleted(this);
+ }
}
}
private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, Task lockTask, CancellationToken cancellationToken)
{
await lockTask.ConfigureAwait(false);
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_sendMutex);
+
try
{
int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span);
@@ -520,8 +625,15 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM
await _stream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
}
- catch (Exception exc) when (exc is not OperationCanceledException)
+ catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
+
+ if (exc is OperationCanceledException)
+ {
+ throw;
+ }
+
throw _state == WebSocketState.Aborted ?
CreateOperationCanceledException(exc, cancellationToken) :
new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc);
@@ -530,13 +642,19 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM
{
ReleaseSendBuffer();
_sendMutex.Exit();
+
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.MutexExited(_sendMutex);
+ NetEventSource.SendFrameAsyncCompleted(this);
+ }
}
}
/// Writes a frame into the send buffer, which can then be sent over the network.
private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer)
{
- ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket));
+ ThrowIfDisposed();
if (_deflater is not null && !disableCompression)
{
@@ -585,26 +703,6 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool
return headerLength + payloadLength;
}
- private void SendKeepAliveFrameAsync()
- {
- // This exists purely to keep the connection alive; don't wait for the result, and ignore any failures.
- // The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by
- // the RFC as a unidirectional heartbeat and we're not interested in waiting for a response.
- ValueTask t = SendFrameAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory.Empty, CancellationToken.None);
- if (t.IsCompletedSuccessfully)
- {
- t.GetAwaiter().GetResult();
- }
- else
- {
- // "Observe" any exception, ignoring it to prevent the unobserved exception event from being raised.
- t.AsTask().ContinueWith(static p => { _ = p.Exception; },
- CancellationToken.None,
- TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
- TaskScheduler.Default);
- }
- }
-
private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask, bool compressed)
{
// Client header format:
@@ -697,13 +795,19 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
// those to be much less frequent (e.g. we should only get one close per websocket), and thus we can afford to pay
// a bit more for readability and maintainability.
- CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateStarted(this, payloadBuffer.Length);
+
+ CancellationTokenRegistration registration = default;
try
{
+ registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);
+
await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex);
+
try
{
- ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket));
+ ThrowIfDisposed();
while (true) // in case we get control frames that should be ignored from the user's perspective
{
@@ -715,6 +819,8 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
MessageHeader header = _lastReceiveHeader;
if (header.Processed)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Reading the next frame header");
+
if (_receiveBufferCount < (_isServer ? MaxMessageHeaderLength : (MaxMessageHeaderLength - MaskLength)))
{
// Make sure we have the first two bytes, which includes the start of the payload length.
@@ -758,6 +864,11 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
}
_receivedMaskOffsetOffset = 0;
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.Trace(this, $"Next frame opcode={header.Opcode}, fin={header.Fin}, compressed={header.Compressed}, payloadLength={header.PayloadLength}");
+ }
+
if (header.PayloadLength == 0 && header.Compressed)
{
// In the rare case where we receive a compressed message with no payload
@@ -841,10 +952,14 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
int numBytesRead = await _stream.ReadAtLeastAsync(
readBuffer, bytesToRead, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={numBytesRead}");
+
if (numBytesRead < bytesToRead)
{
ThrowEOFUnexpected();
}
+ _keepAlivePingState?.OnDataReceived();
totalBytesReceived += numBytesRead;
}
@@ -882,6 +997,11 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
}
+ if (header.Processed)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Data frame fully processed");
+ }
+
_lastReceiveHeader = header;
return GetReceiveResult(
totalBytesReceived,
@@ -892,13 +1012,30 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
finally
{
_receiveMutex.Exit();
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(_receiveMutex);
}
}
- catch (Exception exc) when (exc is not OperationCanceledException)
+ catch (Exception exc)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
+
+ if (exc is OperationCanceledException)
+ {
+ throw;
+ }
+
if (_state == WebSocketState.Aborted)
{
- throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc);
+ Exception inner = exc;
+ if (_keepAlivePingState?.Exception is not null)
+ {
+ // exception was most likely caused by us aborting the connection due to
+ // keep-alive timeout; but let's surface both just in case
+ inner = ExceptionDispatchInfo.SetCurrentStackTrace(
+ new AggregateException(_keepAlivePingState.Exception, exc));
+ }
+
+ throw new OperationCanceledException(nameof(WebSocketState.Aborted), inner);
}
OnAborted();
@@ -912,6 +1049,7 @@ private async ValueTask ReceiveAsyncPrivate(Memory paylo
finally
{
registration.Dispose();
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateCompleted(this);
}
}
@@ -941,14 +1079,17 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat
lock (StateUpdateLock)
{
_receivedCloseFrame = true;
- if (_sentCloseFrame && _state < WebSocketState.Closed)
+ WebSocketState state = _state;
+ if (_sentCloseFrame && state < WebSocketState.Closed)
{
_state = WebSocketState.Closed;
}
- else if (_state < WebSocketState.CloseReceived)
+ else if (state < WebSocketState.CloseReceived)
{
_state = WebSocketState.CloseReceived;
}
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}");
}
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.NormalClosure;
@@ -1005,6 +1146,8 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat
/// Issues a read on the stream to wait for EOF.
private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
// Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second.
// We simply issue a read and don't care what we get back; we could validate that we don't get
// additional data, but at this point we're about to close the connection and we're just stalling
@@ -1036,19 +1179,30 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca
/// The CancellationToken used to cancel the websocket operation.
private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, CancellationToken cancellationToken)
{
+ Debug.Assert(_receiveMutex.IsHeld, $"Caller should hold the {nameof(_receiveMutex)}");
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
+
// Consume any (optional) payload associated with the ping/pong.
if (header.PayloadLength > 0 && _receiveBufferCount < header.PayloadLength)
{
await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false);
}
+ bool processPing = header.Opcode == MessageOpcode.Ping;
+
+ bool processPong = header.Opcode == MessageOpcode.Pong && _keepAlivePingState is not null
+ && header.PayloadLength == KeepAlivePingState.PingPayloadSize;
+
+ if ((processPing || processPong) && _isServer)
+ {
+ ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0);
+ }
+
// If this was a ping, send back a pong response.
- if (header.Opcode == MessageOpcode.Ping)
+ if (processPing)
{
- if (_isServer)
- {
- ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0);
- }
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Ping");
await SendFrameAsync(
MessageOpcode.Pong,
@@ -1057,6 +1211,20 @@ await SendFrameAsync(
_receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength),
cancellationToken).ConfigureAwait(false);
}
+ else if (processPong)
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Processing incoming Pong");
+
+ long pongPayload = BinaryPrimitives.ReadInt64BigEndian(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength));
+ lock (StateUpdateLock)
+ {
+ _keepAlivePingState!.OnPongResponseReceived(pongPayload);
+ }
+ }
+ else
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Received Unsolicited Pong. Skipping.");
+ }
// Regardless of whether it was a ping or pong, we no longer need the payload.
if (header.PayloadLength > 0)
@@ -1115,6 +1283,8 @@ private static bool IsValidCloseStatus(WebSocketCloseStatus closeStatus)
private async ValueTask CloseWithReceiveErrorAndThrowAsync(
WebSocketCloseStatus closeStatus, WebSocketError error, string? errorMessage = null, Exception? innerException = null)
{
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, errorMessage);
+
// Close the connection if it hasn't already been closed
if (!_sentCloseFrame)
{
@@ -1258,79 +1428,98 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync(
/// The CancellationToken to use to cancel the websocket.
private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
{
- // Send the close message. Skip sending a close frame if we're currently in a CloseSent state,
- // for example having just done a CloseOutputAsync.
- if (!_sentCloseFrame)
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.CloseAsyncPrivateStarted(this);
+ try
{
- await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);
- }
-
- // We should now either be in a CloseSent case (because we just sent one), or in a Closed state, in case
- // there was a concurrent receive that ended up handling an immediate close frame response from the server.
- // Of course it could also be Aborted if something happened concurrently to cause things to blow up.
- Debug.Assert(
- State == WebSocketState.CloseSent ||
- State == WebSocketState.Closed ||
- State == WebSocketState.Aborted,
- $"Unexpected state {State}.");
+ // Send the close message. Skip sending a close frame if we're currently in a CloseSent state,
+ // for example having just done a CloseOutputAsync.
+ if (!_sentCloseFrame)
+ {
+ await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);
+ }
- // We only need to wait for a received close frame if we are in the CloseSent State. If we are in the Closed
- // State then it means we already received a close frame. If we are in the Aborted State, then we should not
- // wait for a close frame as per RFC 6455 Section 7.1.7 "Fail the WebSocket Connection".
- if (State == WebSocketState.CloseSent)
- {
- // Wait until we've received a close response
- byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength);
- try
+ // We should now either be in a CloseSent case (because we just sent one), or in a Closed state, in case
+ // there was a concurrent receive that ended up handling an immediate close frame response from the server.
+ // Of course it could also be Aborted if something happened concurrently to cause things to blow up.
+ Debug.Assert(
+ State == WebSocketState.CloseSent ||
+ State == WebSocketState.Closed ||
+ State == WebSocketState.Aborted,
+ $"Unexpected state {State}.");
+
+ // We only need to wait for a received close frame if we are in the CloseSent State. If we are in the Closed
+ // State then it means we already received a close frame. If we are in the Aborted State, then we should not
+ // wait for a close frame as per RFC 6455 Section 7.1.7 "Fail the WebSocket Connection".
+ if (State == WebSocketState.CloseSent)
{
- // Loop until we've received a close frame.
- while (!_receivedCloseFrame)
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, "Waiting for a close frame");
+
+ // Wait until we've received a close response
+ byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength);
+ try
{
- // Enter the receive lock in order to get a consistent view of whether we've received a close
- // frame. If we haven't, issue a receive. Since that receive will try to take the same
- // non-entrant receive lock, we then exit the lock before waiting for the receive to complete,
- // as it will always complete asynchronously and only after we've exited the lock.
- ValueTask receiveTask = default;
- try
+ // Loop until we've received a close frame.
+ while (!_receivedCloseFrame)
{
- await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+ // Enter the receive lock in order to get a consistent view of whether we've received a close
+ // frame. If we haven't, issue a receive. Since that receive will try to take the same
+ // non-entrant receive lock, we then exit the lock before waiting for the receive to complete,
+ // as it will always complete asynchronously and only after we've exited the lock.
+ ValueTask receiveTask = default;
try
{
- if (!_receivedCloseFrame)
+ await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex);
+
+ try
{
- receiveTask = ReceiveAsyncPrivate(closeBuffer, cancellationToken);
+ if (!_receivedCloseFrame)
+ {
+ receiveTask = ReceiveAsyncPrivate(closeBuffer, cancellationToken);
+ }
+ }
+ finally
+ {
+ _receiveMutex.Exit();
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexExited(_receiveMutex);
}
}
- finally
+ catch (OperationCanceledException)
{
- _receiveMutex.Exit();
+ // If waiting on the receive lock was canceled, abort the connection, as we would do
+ // as part of the receive itself.
+ Abort();
+ throw;
}
- }
- catch (OperationCanceledException)
- {
- // If waiting on the receive lock was canceled, abort the connection, as we would do
- // as part of the receive itself.
- Abort();
- throw;
- }
- // Wait for the receive to complete if we issued one.
- await receiveTask.ConfigureAwait(false);
+ // Wait for the receive to complete if we issued one.
+ await receiveTask.ConfigureAwait(false);
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(closeBuffer);
}
}
- finally
+
+ // We're closed. Close the connection and update the status.
+ lock (StateUpdateLock)
{
- ArrayPool.Shared.Return(closeBuffer);
+ DisposeCore();
}
}
-
- // We're closed. Close the connection and update the status.
- lock (StateUpdateLock)
+ catch (Exception exc)
{
- DisposeCore();
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(this, exc);
+ throw;
+ }
+ finally
+ {
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.CloseAsyncPrivateCompleted(this);
}
}
+
/// Sends a close message to the server.
/// The close status to send.
/// The close status description to send.
@@ -1370,14 +1559,17 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st
lock (StateUpdateLock)
{
_sentCloseFrame = true;
- if (_receivedCloseFrame && _state < WebSocketState.Closed)
+ WebSocketState state = _state;
+ if (_receivedCloseFrame && state < WebSocketState.Closed)
{
_state = WebSocketState.Closed;
}
- else if (_state < WebSocketState.CloseSent)
+ else if (state < WebSocketState.CloseSent)
{
_state = WebSocketState.CloseSent;
}
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"State transition from {state} to {_state}");
}
if (!_isServer && _receivedCloseFrame)
@@ -1417,10 +1609,13 @@ private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, Canc
_receiveBuffer.Slice(_receiveBufferCount), bytesToRead, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
_receiveBufferCount += numRead;
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"bytesRead={numRead}");
+
if (numRead < bytesToRead)
{
ThrowEOFUnexpected();
}
+ _keepAlivePingState?.OnDataReceived();
}
}
}
@@ -1430,7 +1625,7 @@ private void ThrowEOFUnexpected()
// The connection closed before we were able to read everything we needed.
// If it was due to us being disposed, fail with the correct exception.
// Otherwise, it was due to the connection being closed and it wasn't expected.
- ObjectDisposedException.ThrowIf(_disposed, typeof(WebSocket));
+ ThrowIfDisposed();
throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely);
}
@@ -1542,6 +1737,30 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa
cancellationToken);
}
+ private void ThrowIfDisposed() => ThrowIfInvalidState();
+
+ private void ThrowIfInvalidState(WebSocketState[]? validStates = null)
+ {
+ bool disposed = _disposed;
+ WebSocketState state = _state;
+ Exception? keepAliveException = null;
+
+ if (_keepAlivePingState is not null)
+ {
+ // we need to take a lock to maintain consistency
+ lock (StateUpdateLock)
+ {
+ disposed = _disposed;
+ state = _state;
+ keepAliveException = _keepAlivePingState.Exception;
+ }
+ }
+
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_state={state}, _disposed={disposed}, _keepAlivePingState.Exception={keepAliveException}");
+
+ WebSocketValidate.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
+ }
+
// From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75
// Performs a stateful validation of UTF-8 bytes.
// It checks for valid formatting, overlong encodings, surrogates, and value ranges.
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs
new file mode 100644
index 0000000000000..d0977fb767060
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/NetEventSource.WebSockets.cs
@@ -0,0 +1,286 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Diagnostics.Tracing;
+using System.Runtime.CompilerServices;
+
+namespace System.Net
+{
+ [EventSource(Name = "Private.InternalDiagnostics.System.Net.WebSockets")]
+ internal sealed partial class NetEventSource
+ {
+ // NOTE
+ // - The 'Start' and 'Stop' suffixes on the following event names have special meaning in EventSource. They
+ // enable creating 'activities'.
+ // For more information, take a look at the following blog post:
+ // https://blogs.msdn.microsoft.com/vancem/2015/09/14/exploring-eventsource-activity-correlation-and-causation-features/
+ // - A stop event's event id must be next one after its start event.
+
+ private const int KeepAliveSentId = NextAvailableEventId;
+ private const int KeepAliveAckedId = KeepAliveSentId + 1;
+
+ private const int WsTraceId = KeepAliveAckedId + 1;
+
+ private const int CloseStartId = WsTraceId + 1;
+ private const int CloseStopId = CloseStartId + 1;
+
+ private const int ReceiveStartId = CloseStopId + 1;
+ private const int ReceiveStopId = ReceiveStartId + 1;
+
+ private const int SendStartId = ReceiveStopId + 1;
+ private const int SendStopId = SendStartId + 1;
+
+ private const int MutexEnterId = SendStopId + 1;
+ private const int MutexExitId = MutexEnterId + 1;
+ private const int MutexContendedId = MutexExitId + 1;
+
+ //
+ // Keep-Alive
+ //
+
+ private const string Ping = "Ping";
+ private const string Pong = "Pong";
+
+ [Event(KeepAliveSentId, Keywords = Keywords.Debug, Level = EventLevel.Informational)]
+ private void KeepAliveSent(string objName, string opcode, long payload) =>
+ WriteEvent(KeepAliveSentId, objName, opcode, payload);
+
+ [Event(KeepAliveAckedId, Keywords = Keywords.Debug, Level = EventLevel.Informational)]
+ private void KeepAliveAcked(string objName, long payload) =>
+ WriteEvent(KeepAliveAckedId, objName, payload);
+
+ [NonEvent]
+ public static void KeepAlivePingSent(object? obj, long payload)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.KeepAliveSent(IdOf(obj), Ping, payload);
+ }
+
+ [NonEvent]
+ public static void UnsolicitedPongSent(object? obj)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.KeepAliveSent(IdOf(obj), Pong, 0);
+ }
+
+ [NonEvent]
+ public static void PongResponseReceived(object? obj, long payload)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.KeepAliveAcked(IdOf(obj), payload);
+ }
+
+ //
+ // Debug Messages
+ //
+
+ [Event(WsTraceId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void WsTrace(string objName, string memberName, string message) =>
+ WriteEvent(WsTraceId, objName, memberName, message);
+
+ [NonEvent]
+ public static void TraceErrorMsg(object? obj, Exception exception, [CallerMemberName] string? memberName = null)
+ => Trace(obj, $"{exception.GetType().Name}: {exception.Message}", memberName);
+
+ [NonEvent]
+ public static void TraceException(object? obj, Exception exception, [CallerMemberName] string? memberName = null)
+ => Trace(obj, exception.ToString(), memberName);
+
+ [NonEvent]
+ public static void Trace(object? obj, string? message = null, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.WsTrace(IdOf(obj), memberName ?? MissingMember, message ?? memberName ?? string.Empty);
+ }
+
+ //
+ // Close
+ //
+
+ [Event(CloseStartId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void CloseStart(string objName, string memberName) =>
+ WriteEvent(CloseStartId, objName, memberName);
+
+ [Event(CloseStopId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void CloseStop(string objName, string memberName) =>
+ WriteEvent(CloseStopId, objName, memberName);
+
+ [NonEvent]
+ public static void CloseAsyncPrivateStarted(object? obj, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.CloseStart(IdOf(obj), memberName ?? MissingMember);
+ }
+
+ [NonEvent]
+ public static void CloseAsyncPrivateCompleted(object? obj, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.CloseStop(IdOf(obj), memberName ?? MissingMember);
+ }
+
+ //
+ // ReceiveAsyncPrivate
+ //
+
+ [Event(ReceiveStartId, Keywords = Keywords.Debug, Level = EventLevel.Informational)]
+ private void ReceiveStart(string objName, string memberName, int bufferLength) =>
+ WriteEvent(ReceiveStartId, objName, memberName, bufferLength);
+
+ [Event(ReceiveStopId, Keywords = Keywords.Debug, Level = EventLevel.Informational)]
+ private void ReceiveStop(string objName, string memberName) =>
+ WriteEvent(ReceiveStopId, objName, memberName);
+
+ [NonEvent]
+ public static void ReceiveAsyncPrivateStarted(object? obj, int bufferLength, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.ReceiveStart(IdOf(obj), memberName ?? MissingMember, bufferLength);
+ }
+
+ [NonEvent]
+ public static void ReceiveAsyncPrivateCompleted(object? obj, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.ReceiveStop(IdOf(obj), memberName ?? MissingMember);
+ }
+
+ //
+ // SendFrameAsync
+ //
+
+ [Event(SendStartId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void SendStart(string objName, string memberName, string opcode, int bufferLength) =>
+ WriteEvent(SendStartId, objName, memberName, opcode, bufferLength);
+
+ [Event(SendStopId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void SendStop(string objName, string memberName) =>
+ WriteEvent(SendStopId, objName, memberName);
+
+ [NonEvent]
+ public static void SendFrameAsyncStarted(object? obj, string opcode, int bufferLength, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.SendStart(IdOf(obj), memberName ?? MissingMember, opcode, bufferLength);
+ }
+
+ [NonEvent]
+ public static void SendFrameAsyncCompleted(object? obj, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.SendStop(IdOf(obj), memberName ?? MissingMember);
+ }
+
+ //
+ // AsyncMutex
+ //
+
+ [Event(MutexEnterId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void MutexEnter(string objName, string memberName) =>
+ WriteEvent(MutexEnterId, objName, memberName);
+
+ [Event(MutexExitId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void MutexExit(string objName, string memberName) =>
+ WriteEvent(MutexExitId, objName, memberName);
+
+ [Event(MutexContendedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
+ private void MutexContended(string objName, string memberName, int queueLength) =>
+ WriteEvent(MutexContendedId, objName, memberName, queueLength);
+
+ [NonEvent]
+ public static void MutexEntered(object? obj, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.MutexEnter(IdOf(obj), memberName ?? MissingMember);
+ }
+
+ [NonEvent]
+ public static void MutexExited(object? obj, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.MutexExit(IdOf(obj), memberName ?? MissingMember);
+ }
+
+ [NonEvent]
+ public static void MutexContended(object? obj, int gateValue, [CallerMemberName] string? memberName = null)
+ {
+ Debug.Assert(Log.IsEnabled());
+ Log.MutexContended(IdOf(obj), memberName ?? MissingMember, -gateValue);
+ }
+
+ //
+ // WriteEvent overloads
+ //
+
+ [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern",
+ Justification = EventSourceSuppressMessage)]
+ [NonEvent]
+ private unsafe void WriteEvent(int eventId, string arg1, string arg2, long arg3)
+ {
+ fixed (char* arg1Ptr = arg1)
+ fixed (char* arg2Ptr = arg2)
+ {
+ const int NumEventDatas = 3;
+ EventData* descrs = stackalloc EventData[NumEventDatas];
+
+ descrs[0] = new EventData
+ {
+ DataPointer = (IntPtr)(arg1Ptr),
+ Size = (arg1.Length + 1) * sizeof(char)
+ };
+ descrs[1] = new EventData
+ {
+ DataPointer = (IntPtr)(arg2Ptr),
+ Size = (arg2.Length + 1) * sizeof(char)
+ };
+ descrs[2] = new EventData
+ {
+ DataPointer = (IntPtr)(&arg3),
+ Size = sizeof(long)
+ };
+
+ WriteEventCore(eventId, NumEventDatas, descrs);
+ }
+ }
+
+ [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:UnrecognizedReflectionPattern",
+ Justification = EventSourceSuppressMessage)]
+ [NonEvent]
+ private unsafe void WriteEvent(int eventId, string arg1, string arg2, string arg3, int arg4)
+ {
+ fixed (char* arg1Ptr = arg1)
+ fixed (char* arg2Ptr = arg2)
+ fixed (char* arg3Ptr = arg3)
+ {
+ const int NumEventDatas = 4;
+ EventData* descrs = stackalloc EventData[NumEventDatas];
+
+ descrs[0] = new EventData
+ {
+ DataPointer = (IntPtr)(arg1Ptr),
+ Size = (arg1.Length + 1) * sizeof(char)
+ };
+ descrs[1] = new EventData
+ {
+ DataPointer = (IntPtr)(arg2Ptr),
+ Size = (arg2.Length + 1) * sizeof(char)
+ };
+ descrs[2] = new EventData
+ {
+ DataPointer = (IntPtr)(arg3Ptr),
+ Size = (arg3.Length + 1) * sizeof(char)
+ };
+ descrs[3] = new EventData
+ {
+ DataPointer = (IntPtr)(&arg4),
+ Size = sizeof(int)
+ };
+
+ WriteEventCore(eventId, NumEventDatas, descrs);
+ }
+ }
+
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs
index fc4436926a6f0..ce47b894d32c7 100644
--- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs
@@ -84,7 +84,7 @@ private async ValueTask SendWithArrayPoolAsync(
public static TimeSpan DefaultKeepAliveInterval
{
// In the .NET Framework, this pulls the value from a P/Invoke. Here we just hardcode it to a reasonable default.
- get { return TimeSpan.FromSeconds(30); }
+ get { return WebSocketDefaults.DefaultClientKeepAliveInterval; }
}
protected static void ThrowOnInvalidState(WebSocketState state, params WebSocketState[] validStates)
@@ -150,7 +150,7 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s
0));
}
- return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval);
+ return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval, WebSocketDefaults.DefaultKeepAliveTimeout);
}
/// Creates a that operates on a representing a web socket connection.
@@ -209,7 +209,7 @@ public static WebSocket CreateClientWebSocket(Stream innerStream,
// Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option.
// Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages.
- return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval);
+ return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval, WebSocketDefaults.DefaultKeepAliveTimeout);
}
}
}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs
index d042583da5444..dfc74241379f8 100644
--- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs
@@ -11,7 +11,8 @@ namespace System.Net.WebSockets
public sealed class WebSocketCreationOptions
{
private string? _subProtocol;
- private TimeSpan _keepAliveInterval;
+ private TimeSpan _keepAliveInterval = WebSocketDefaults.DefaultKeepAliveInterval;
+ private TimeSpan _keepAliveTimeout = WebSocketDefaults.DefaultKeepAliveTimeout;
///
/// Defines if this websocket is the server-side of the connection. The default value is false.
@@ -36,6 +37,8 @@ public string? SubProtocol
///
/// The keep-alive interval to use, or or to disable keep-alives.
+ /// If is set, then PING messages are sent and peer's PONG responses are expected, otherwise,
+ /// unsolicited PONG messages are used as a keep-alive heartbeat.
/// The default is .
///
public TimeSpan KeepAliveInterval
@@ -52,6 +55,25 @@ public TimeSpan KeepAliveInterval
}
}
+ ///
+ /// The timeout to use when waiting for the peer's PONG in response to us sending a PING; or or
+ /// to disable waiting for peer's response, and use an unsolicited PONG as a Keep-Alive heartbeat instead.
+ /// The default is .
+ ///
+ public TimeSpan KeepAliveTimeout
+ {
+ get => _keepAliveTimeout;
+ set
+ {
+ if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero)
+ {
+ throw new ArgumentOutOfRangeException(nameof(KeepAliveTimeout), value,
+ SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0));
+ }
+ _keepAliveTimeout = value;
+ }
+ }
+
///
/// The agreed upon options for per message deflate.
/// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks.
diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj
index 807da709ea755..a7f09ff31db29 100644
--- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj
+++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj
@@ -1,6 +1,7 @@
$(NetCoreAppCurrent)
+ ../src/Resources/Strings.resx
@@ -9,6 +10,7 @@
+
diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs
index 86d1dfb2cd530..423c8d40ed5d5 100644
--- a/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs
+++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCloseTests.cs
@@ -85,5 +85,30 @@ public async Task ReceiveAsync_ValidCloseStatus_Success(WebSocketCloseStatus clo
Assert.Equal(closeStatusDescription, closing.CloseStatusDescription);
}
}
+
+ [Fact]
+ public async Task CloseAsync_CancelableEvenWhenPendingReceive_Throws()
+ {
+ using var stream = new WebSocketTestStream();
+ using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions());
+
+ Task receiveTask = websocket.ReceiveAsync(new byte[1], CancellationToken);
+ await Task.Delay(100); // give the receive task a chance to aquire the lock
+ var cancelCloseCts = new CancellationTokenSource();
+ await Assert.ThrowsAnyAsync(async () =>
+ {
+ Task t = websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancelCloseCts.Token);
+ await Task.Delay(100); // give the close task time to get in the queue waiting for the lock
+ cancelCloseCts.Cancel();
+ await t;
+ });
+
+ Assert.True(cancelCloseCts.Token.IsCancellationRequested);
+ Assert.False(CancellationToken.IsCancellationRequested);
+
+ await Assert.ThrowsAnyAsync(() => receiveTask);
+
+ Assert.False(CancellationToken.IsCancellationRequested);
+ }
}
}
diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs
new file mode 100644
index 0000000000000..11dd28117ac5d
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/tests/WebSocketKeepAliveTests.cs
@@ -0,0 +1,280 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers.Binary;
+using System.Diagnostics;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.WebSockets.Tests
+{
+ public class WebSocketKeepAliveTests
+ {
+ public static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);
+ public static readonly TimeSpan KeepAliveInterval = TimeSpan.FromMilliseconds(100);
+ public static readonly TimeSpan KeepAliveTimeout = TimeSpan.FromSeconds(1);
+ public const int FramesToTestCount = 5;
+
+#region Frame format helper constants
+
+ public const int MinHeaderLength = 2;
+ public const int MaskLength = 4;
+ public const int SingleInt64PayloadLength = sizeof(long);
+ public const int PingPayloadLength = SingleInt64PayloadLength;
+
+ // 0b_1_***_**** -- fin=true
+ public const byte FirstByteBits_FinFlag = 0b_1_000_0000;
+
+ // 0b_*_***_0010 -- opcode=BINARY (0x02)
+ public const byte FirstByteBits_OpcodeBinary = 0b_0_000_0010;
+
+ // 0b_*_***_1001 -- opcode=PING (0x09)
+ public const byte FirstByteBits_OpcodePing = 0b_0_000_1001;
+
+ // 0b_*_***_1010 -- opcode=PONG (0x10)
+ public const byte FirstByteBits_OpcodePong = 0b_0_000_1010;
+
+ // 0b_1_******* -- mask=true
+ public const byte SecondByteBits_MaskFlag = 0b_1_0000000;
+
+ // 0b_*_0001000 -- length=8
+ public const byte SecondByteBits_PayloadLength8 = SingleInt64PayloadLength;
+
+ public const byte FirstByte_PingFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodePing;
+ public const byte FirstByte_PongFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodePong;
+ public const byte FirstByte_DataFrame = FirstByteBits_FinFlag | FirstByteBits_OpcodeBinary;
+
+ public const byte SecondByte_Server_NoPayload = 0;
+ public const byte SecondByte_Client_NoPayload = SecondByteBits_MaskFlag;
+
+ public const byte SecondByte_Server_8bPayload = SecondByteBits_PayloadLength8;
+ public const byte SecondByte_Client_8bPayload = SecondByteBits_MaskFlag | SecondByteBits_PayloadLength8;
+
+ public const int Server_FrameHeaderLength = MinHeaderLength;
+ public const int Client_FrameHeaderLength = MinHeaderLength + MaskLength;
+
+#endregion
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task WebSocket_NoUserReadOrWrite_SendsUnsolicitedPong(bool isServer)
+ {
+ var cancellationToken = new CancellationTokenSource(TestTimeout).Token;
+
+ using WebSocketTestStream testStream = new();
+ Stream localEndpointStream = testStream;
+ Stream remoteEndpointStream = testStream.Remote;
+
+ using WebSocket webSocket = WebSocket.CreateFromStream(localEndpointStream, new WebSocketCreationOptions
+ {
+ IsServer = isServer,
+ KeepAliveInterval = KeepAliveInterval
+ });
+
+ // --- "remote endpoint" side ---
+
+ int pongFrameLength = isServer ? Server_FrameHeaderLength : Client_FrameHeaderLength;
+ var pongBuffer = new byte[pongFrameLength];
+ for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pongs "indefinitely", let's check a few
+ {
+ await remoteEndpointStream.ReadExactlyAsync(pongBuffer, cancellationToken);
+
+ Assert.Equal(FirstByte_PongFrame, pongBuffer[0]);
+ Assert.Equal(
+ isServer ? SecondByte_Server_NoPayload : SecondByte_Client_NoPayload,
+ pongBuffer[1]);
+ }
+ }
+
+ [Fact]
+ public async Task WebSocketServer_NoUserReadOrWrite_SendsPingAndReadsPongResponse()
+ {
+ var cancellationToken = new CancellationTokenSource(TestTimeout).Token;
+
+ using WebSocketTestStream testStream = new();
+ Stream serverStream = testStream;
+ Stream clientStream = testStream.Remote;
+
+ using WebSocket webSocketServer = WebSocket.CreateFromStream(serverStream, new WebSocketCreationOptions
+ {
+ IsServer = true,
+ KeepAliveInterval = KeepAliveInterval,
+ KeepAliveTimeout = TestTimeout // we don't care about the actual timeout here
+ });
+
+ // we need an outstanding read to ensure the client receives pongs
+ var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ var serverReadTask = webSocketServer.ReceiveAsync(Memory.Empty, readCts.Token);
+
+ // --- "client" side ---
+
+ var buffer = new byte[Client_FrameHeaderLength + PingPayloadLength]; // client frame is bigger because of masking
+
+ for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pings "indefinitely", let's check a few
+ {
+ Assert.Equal(WebSocketState.Open, webSocketServer.State);
+ Assert.False(serverReadTask.IsCompleted);
+
+ buffer.AsSpan().Clear();
+ await clientStream.ReadExactlyAsync(
+ buffer.AsMemory(0, Server_FrameHeaderLength + PingPayloadLength),
+ cancellationToken);
+
+ Assert.Equal(FirstByte_PingFrame, buffer[0]);
+
+ // implementation detail: payload is a long counter starting from 1
+ Assert.Equal(SecondByte_Server_8bPayload, buffer[1]);
+
+ var payloadBytes = buffer.AsSpan().Slice(Server_FrameHeaderLength, PingPayloadLength);
+ long pingCounter = BinaryPrimitives.ReadInt64BigEndian(payloadBytes);
+
+ Assert.Equal(i+1, pingCounter);
+
+ // --- sending pong back ---
+
+ buffer[0] = FirstByte_PongFrame;
+ buffer[1] = SecondByte_Client_8bPayload;
+
+ // using zeroes as a "mask" -- applying such a mask is a no-op
+ Array.Clear(buffer, MinHeaderLength, MaskLength);
+
+ // sending the same payload back
+ BinaryPrimitives.WriteInt64BigEndian(buffer.AsSpan().Slice(Client_FrameHeaderLength), pingCounter);
+
+ await clientStream.WriteAsync(buffer, cancellationToken);
+ }
+
+ Assert.Equal(WebSocketState.Open, webSocketServer.State);
+ Assert.False(serverReadTask.IsCompleted);
+
+ readCts.Cancel();
+
+ await Assert.ThrowsAsync(() => serverReadTask.AsTask());
+ Assert.Equal(WebSocketState.Aborted, webSocketServer.State);
+ }
+
+ [Fact]
+ public async Task WebSocketClient_NoServerDataSent_SendsPingAndReadsPongResponse()
+ {
+ var cancellationToken = new CancellationTokenSource(TestTimeout).Token;
+
+ using WebSocketTestStream testStream = new();
+ Stream clientStream = testStream;
+ Stream serverStream = testStream.Remote;
+
+ using WebSocket webSocketClient = WebSocket.CreateFromStream(clientStream, new WebSocketCreationOptions
+ {
+ IsServer = false,
+ KeepAliveInterval = KeepAliveInterval,
+ KeepAliveTimeout = TestTimeout // we don't care about the actual timeout here
+ });
+
+ // we need an outstanding read to ensure the client receives pongs
+ var readCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ var clientReadTask = webSocketClient.ReceiveAsync(Memory.Empty, readCts.Token);
+
+
+ // --- "server" side ---
+
+ var buffer = new byte[Client_FrameHeaderLength + PingPayloadLength]; // client frame is bigger because of masking
+
+ for (int i = 0; i < FramesToTestCount; i++) // WS should be sending pings "indefinitely", let's check a few
+ {
+ Assert.Equal(WebSocketState.Open, webSocketClient.State);
+ Assert.False(clientReadTask.IsCompleted);
+
+ buffer.AsSpan().Clear();
+ await serverStream.ReadExactlyAsync(buffer, cancellationToken);
+
+ Assert.Equal(FirstByte_PingFrame, buffer[0]);
+
+ // implementation detail: payload is a long counter starting from 1
+ Assert.Equal(SecondByte_Client_8bPayload, buffer[1]);
+
+ var payloadBytes = buffer.AsSpan().Slice(Client_FrameHeaderLength, PingPayloadLength);
+ ApplyMask(payloadBytes, buffer.AsSpan().Slice(Client_FrameHeaderLength - MaskLength, MaskLength));
+ long pingCounter = BinaryPrimitives.ReadInt64BigEndian(payloadBytes);
+ Assert.Equal(i+1, pingCounter);
+
+ // --- sending pong back ---
+
+ buffer[0] = FirstByte_PongFrame;
+ buffer[1] = SecondByte_Server_8bPayload;
+
+ // sending the same payload back
+ BinaryPrimitives.WriteInt64BigEndian(buffer.AsSpan().Slice(Server_FrameHeaderLength), pingCounter);
+
+ await serverStream.WriteAsync(
+ buffer.AsMemory(0, Server_FrameHeaderLength + PingPayloadLength),
+ cancellationToken);
+ }
+
+ Assert.Equal(WebSocketState.Open, webSocketClient.State);
+ Assert.False(clientReadTask.IsCompleted);
+
+ readCts.Cancel();
+
+ await Assert.ThrowsAsync(() => clientReadTask.AsTask());
+ Assert.Equal(WebSocketState.Aborted, webSocketClient.State);
+
+ // Octet i of the transformed data ("transformed-octet-i") is the XOR of
+ // octet i of the original data ("original-octet-i") with octet at index
+ // i modulo 4 of the masking key ("masking-key-octet-j"):
+ //
+ // j = i MOD 4
+ // transformed-octet-i = original-octet-i XOR masking-key-octet-j
+ //
+ static void ApplyMask(Span buffer, Span mask)
+ {
+
+ for (int i = 0; i < buffer.Length; i++)
+ {
+ buffer[i] ^= mask[i % MaskLength];
+ }
+ }
+ }
+
+ [OuterLoop("Uses Task.Delay")]
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task WebSocket_NoPongResponseWithinTimeout_Aborted(bool outstandingUserRead)
+ {
+ var cancellationToken = new CancellationTokenSource(TestTimeout).Token;
+
+ using WebSocketTestStream testStream = new();
+ Stream localEndpointStream = testStream;
+ Stream remoteEndpointStream = testStream.Remote;
+
+ using WebSocket webSocket = WebSocket.CreateFromStream(localEndpointStream, new WebSocketCreationOptions
+ {
+ IsServer = true,
+ KeepAliveInterval = KeepAliveInterval,
+ KeepAliveTimeout = KeepAliveTimeout
+ });
+
+ Debug.Assert(webSocket.State == WebSocketState.Open);
+
+ ValueTask userReadTask = default;
+ if (outstandingUserRead)
+ {
+ userReadTask = webSocket.ReceiveAsync(Memory.Empty, cancellationToken);
+ }
+
+ await Task.Delay(2 * (KeepAliveTimeout + KeepAliveInterval), cancellationToken);
+
+ Assert.Equal(WebSocketState.Aborted, webSocket.State);
+
+ Exception readException = outstandingUserRead
+ ? await Assert.ThrowsAsync(() => userReadTask.AsTask())
+ : await Assert.ThrowsAsync(() => webSocket.ReceiveAsync(Memory.Empty, cancellationToken).AsTask());
+
+ var wse = Assert.IsType(readException.InnerException);
+ Assert.Equal(WebSocketError.Faulted, wse.WebSocketErrorCode);
+ Assert.Equal(SR.net_Websockets_KeepAlivePingTimeout, wse.Message);
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs
index 4cf9c279ba5f3..73e84998a9419 100644
--- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs
+++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.IO;
+using System.Threading;
using System.Threading.Tasks;
using Xunit;
@@ -200,6 +201,26 @@ public async Task ReceiveAsync_WhenDisposedInParallel_DoesNotGetStuck()
await Assert.ThrowsAsync(() => r3.WaitAsync(TimeSpan.FromSeconds(1)));
}
+ [Fact]
+ public async Task ReceiveAsync_AfterCancellationDoReceiveAsync_ThrowsWebSocketException()
+ {
+ using var stream = new WebSocketTestStream();
+ using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions());
+ var recvBuffer = new byte[100];
+ var segment = new ArraySegment(recvBuffer);
+ var cts = new CancellationTokenSource();
+
+ Task receive = websocket.ReceiveAsync(segment, cts.Token);
+ cts.Cancel();
+ await Assert.ThrowsAnyAsync(() => receive);
+
+ WebSocketException ex = await Assert.ThrowsAsync(() =>
+ websocket.ReceiveAsync(segment, CancellationToken.None));
+ Assert.Equal(
+ SR.Format(SR.net_WebSockets_InvalidState, "Aborted", "Open, CloseSent"),
+ ex.Message);
+ }
+
public abstract class ExposeProtectedWebSocket : WebSocket
{
public static new bool IsStateTerminal(WebSocketState state) =>