diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx
index a29352a0578f5..51c31a9db1ca5 100644
--- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx
+++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx
@@ -150,5 +150,14 @@
Writing is not allowed on stream.
+
+ '{0}' is not supported by System.Net.Quic.
+
+
+ The remote certificate was rejected by the provided RemoteCertificateValidationCallback.
+
+
+ The remote certificate is invalid because of errors in the certificate chain: {0}
+
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs
index df48e0db377ae..aa4c58929f92c 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs
@@ -17,7 +17,7 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal
internal sealed class SafeMsQuicConfigurationHandle : SafeHandle
{
private static readonly FieldInfo _contextCertificate = typeof(SslStreamCertificateContext).GetField("Certificate", BindingFlags.NonPublic | BindingFlags.Instance)!;
- private static readonly FieldInfo _contextChain= typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!;
+ private static readonly FieldInfo _contextChain = typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!;
public override bool IsInvalid => handle == IntPtr.Zero;
@@ -33,7 +33,7 @@ protected override bool ReleaseHandle()
}
// TODO: consider moving the static code from here to keep all the handle classes small and simple.
- public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
+ public static SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
{
X509Certificate? certificate = null;
if (options.ClientAuthenticationOptions?.ClientCertificates != null)
@@ -56,15 +56,35 @@ public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOp
return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: certificate, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols);
}
- public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions options)
+ public static SafeMsQuicConfigurationHandle Create(QuicOptions options, SslServerAuthenticationOptions? serverAuthenticationOptions, string? targetHost = null)
{
QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE;
- if (options.ServerAuthenticationOptions != null && options.ServerAuthenticationOptions.ClientCertificateRequired)
+ X509Certificate? certificate = serverAuthenticationOptions?.ServerCertificate;
+
+ if (serverAuthenticationOptions != null)
{
- flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION;
+ if (serverAuthenticationOptions.CipherSuitesPolicy != null)
+ {
+ throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.CipherSuitesPolicy)));
+ }
+
+ if (serverAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption)
+ {
+ throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.EncryptionPolicy)));
+ }
+
+ if (serverAuthenticationOptions.ClientCertificateRequired)
+ {
+ flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION;
+ }
+
+ if (certificate == null && serverAuthenticationOptions?.ServerCertificateSelectionCallback != null && targetHost != null)
+ {
+ certificate = serverAuthenticationOptions.ServerCertificateSelectionCallback(options, targetHost);
+ }
}
- return Create(options, flags, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols);
+ return Create(options, flags, certificate, serverAuthenticationOptions?.ServerCertificateContext, serverAuthenticationOptions?.ApplicationProtocols);
}
// TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer.
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
index 6fb6588958513..f3dcbf2ca706f 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs
@@ -7,6 +7,7 @@
using System.Net.Sockets;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
+using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
@@ -35,10 +36,6 @@ internal sealed class MsQuicConnection : QuicConnectionProvider
private IPEndPoint? _localEndPoint;
private readonly EndPoint _remoteEndPoint;
private SslApplicationProtocol _negotiatedAlpnProtocol;
- private bool _isServer;
- private bool _remoteCertificateRequired;
- private X509RevocationMode _revocationMode = X509RevocationMode.Offline;
- private RemoteCertificateValidationCallback? _remoteCertificateValidationCallback;
internal sealed class State
{
@@ -50,8 +47,8 @@ internal sealed class State
// These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
public MsQuicConnection? Connection;
- // TODO: only allocate these when there is an outstanding connect/shutdown.
- public readonly TaskCompletionSource ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ public TaskCompletionSource? ConnectTcs;
+ // TODO: only allocate these when there is an outstanding shutdown.
public readonly TaskCompletionSource ShutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
// Note that there's no such thing as resetable TCS, so we cannot reuse the same instance after we've set the result.
@@ -65,6 +62,13 @@ internal sealed class State
public int StreamCount;
private bool _closing;
+ // Certificate validation properties
+ public bool RemoteCertificateRequired;
+ public X509RevocationMode RevocationMode = X509RevocationMode.Offline;
+ public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback;
+ public bool IsServer;
+ public string? TargetHost;
+
// Queue for accepted streams.
// Backlog limit is managed by MsQuic so it can be unbounded here.
public readonly Channel AcceptQueue = Channel.CreateUnbounded(new UnboundedChannelOptions()
@@ -131,26 +135,17 @@ public void SetClosing()
internal string TraceId() => _state.TraceId;
// constructor for inbound connections
- public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null)
+ public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
_state.Connected = true;
- _isServer = true;
+ _state.RemoteCertificateRequired = remoteCertificateRequired;
+ _state.RevocationMode = revocationMode;
+ _state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback;
+ _state.IsServer = true;
_localEndPoint = localEndPoint;
_remoteEndPoint = remoteEndPoint;
- _remoteCertificateRequired = remoteCertificateRequired;
- _revocationMode = revocationMode;
- _remoteCertificateValidationCallback = remoteCertificateValidationCallback;
-
- if (_remoteCertificateRequired)
- {
- // We need to link connection for the validation callback.
- // We need to be able to find the connection in HandleEventPeerCertificateReceived
- // and dispatch it as sender to validation callback.
- // After that Connection will be set back to null.
- _state.Connection = this;
- }
try
{
@@ -177,12 +172,12 @@ public MsQuicConnection(QuicClientConnectionOptions options)
{
_remoteEndPoint = options.RemoteEndPoint!;
_configuration = SafeMsQuicConfigurationHandle.Create(options);
- _isServer = false;
- _remoteCertificateRequired = true;
+ _state.RemoteCertificateRequired = true;
if (options.ClientAuthenticationOptions != null)
{
- _revocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode;
- _remoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback;
+ _state.RevocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode;
+ _state.RemoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback;
+ _state.TargetHost = options.ClientAuthenticationOptions.TargetHost;
}
_state.StateGCHandle = GCHandle.Alloc(_state);
@@ -231,7 +226,7 @@ private static uint HandleEventConnected(State state, ref ConnectionEvent connec
state.Connection = null;
state.Connected = true;
- state.ConnectTcs.SetResult(MsQuicStatusCodes.Success);
+ state.ConnectTcs!.SetResult(MsQuicStatusCodes.Success);
}
return MsQuicStatusCodes.Success;
@@ -239,14 +234,15 @@ private static uint HandleEventConnected(State state, ref ConnectionEvent connec
private static uint HandleEventShutdownInitiatedByTransport(State state, ref ConnectionEvent connectionEvent)
{
- if (!state.Connected)
+ if (!state.Connected && state.ConnectTcs != null)
{
Debug.Assert(state.Connection != null);
state.Connection = null;
uint hresult = connectionEvent.Data.ShutdownInitiatedByTransport.Status;
Exception ex = QuicExceptionHelpers.CreateExceptionForHResult(hresult, "Connection has been shutdown by transport.");
- state.ConnectTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex));
+ state.ConnectTcs!.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex));
+ state.ConnectTcs = null;
}
state.AcceptQueue.Writer.TryComplete();
@@ -345,17 +341,6 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
X509Certificate2? certificate = null;
X509Certificate2Collection? additionalCertificates = null;
- MsQuicConnection? connection = state.Connection;
- if (connection == null)
- {
- return MsQuicStatusCodes.InvalidState;
- }
-
- if (connection._isServer)
- {
- state.Connection = null;
- }
-
try
{
if (connectionEvent.Data.PeerCertificateReceived.PlatformCertificateHandle != IntPtr.Zero)
@@ -386,15 +371,15 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
if (certificate == null)
{
- if (NetEventSource.Log.IsEnabled() && connection._remoteCertificateRequired) NetEventSource.Error(state, $"{state.TraceId} Remote certificate required, but no remote certificate received");
+ if (NetEventSource.Log.IsEnabled() && state.RemoteCertificateRequired) NetEventSource.Error(state, $"{state.TraceId} Remote certificate required, but no remote certificate received");
sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable;
}
else
{
chain = new X509Chain();
- chain.ChainPolicy.RevocationMode = connection._revocationMode;
+ chain.ChainPolicy.RevocationMode = state.RevocationMode;
chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;
- chain.ChainPolicy.ApplicationPolicy.Add(connection._isServer ? s_clientAuthOid : s_serverAuthOid);
+ chain.ChainPolicy.ApplicationPolicy.Add(state.IsServer ? s_clientAuthOid : s_serverAuthOid);
if (additionalCertificates != null && additionalCertificates.Count > 1)
{
@@ -407,34 +392,46 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
}
}
- if (!connection._remoteCertificateRequired)
+ if (!state.RemoteCertificateRequired)
{
sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNotAvailable;
}
- if (connection._remoteCertificateValidationCallback != null)
+ if (state.RemoteCertificateValidationCallback != null)
{
- bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors);
+ bool success = state.RemoteCertificateValidationCallback(state, certificate, chain, sslPolicyErrors);
// Unset the callback to prevent multiple invocations of the callback per a single connection.
// Return the same value as the custom callback just did.
- connection._remoteCertificateValidationCallback = (_, _, _, _) => success;
+ state.RemoteCertificateValidationCallback = (_, _, _, _) => success;
if (!success && NetEventSource.Log.IsEnabled())
NetEventSource.Error(state, $"{state.TraceId} Remote certificate rejected by verification callback");
- return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
+
+ if (!success)
+ {
+ throw new AuthenticationException(SR.net_quic_cert_custom_validation);
+ }
+
+ return MsQuicStatusCodes.Success;
}
if (NetEventSource.Log.IsEnabled())
NetEventSource.Info(state, $"{state.TraceId} Certificate validation for '${certificate?.Subject}' finished with ${sslPolicyErrors}");
- return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
+// return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;
+
+ if (sslPolicyErrors != SslPolicyErrors.None)
+ {
+ throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
+ }
+
+ return MsQuicStatusCodes.Success;
}
catch (Exception ex)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.TraceId} Certificate validation failed ${ex.Message}");
+ throw;
}
-
- return MsQuicStatusCodes.InternalError;
}
internal override async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default)
@@ -544,13 +541,6 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
throw new Exception($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener.");
}
- (string address, int port) = _remoteEndPoint switch
- {
- DnsEndPoint dnsEp => (dnsEp.Host, dnsEp.Port),
- IPEndPoint ipEp => (ipEp.Address.ToString(), ipEp.Port),
- _ => throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.")
- };
-
QUIC_ADDRESS_FAMILY af = _remoteEndPoint.AddressFamily switch
{
AddressFamily.Unspecified => QUIC_ADDRESS_FAMILY.UNSPEC,
@@ -562,13 +552,43 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
Debug.Assert(_state.StateGCHandle.IsAllocated);
_state.Connection = this;
+ uint status;
+ string targetHost;
+ int port;
+
+ if (_remoteEndPoint is IPEndPoint)
+ {
+ SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet((IPEndPoint)_remoteEndPoint);
+ unsafe
+ {
+ status = MsQuicApi.Api.SetParamDelegate(_state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.REMOTE_ADDRESS, (uint)sizeof(SOCKADDR_INET), (byte*)&address);
+ QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
+ }
+
+ targetHost = _state.TargetHost ?? ((IPEndPoint)_remoteEndPoint).Address.ToString();
+ port = ((IPEndPoint)_remoteEndPoint).Port;
+
+ }
+ else if (_remoteEndPoint is DnsEndPoint)
+ {
+ // We don't have way how to set separate SNI and name for connection at this moment.
+ targetHost = ((DnsEndPoint)_remoteEndPoint).Host;
+ port = ((DnsEndPoint)_remoteEndPoint).Port;
+ }
+ else
+ {
+ throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.");
+ }
+
+ _state.ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
try
{
- uint status = MsQuicApi.Api.ConnectionStartDelegate(
+ status = MsQuicApi.Api.ConnectionStartDelegate(
_state.Handle,
_configuration,
af,
- address,
+ targetHost,
(ushort)port);
QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
@@ -665,10 +685,18 @@ private static uint NativeCallbackHandler(
NetEventSource.Error(state, $"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}");
}
- Debug.Fail($"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}");
+ if (state.ConnectTcs != null)
+ {
+ state.ConnectTcs.SetException(ex);
+ state.ConnectTcs = null;
+ state.Connection = null;
+ }
+ else
+ {
+ Debug.Fail($"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}");
+ }
// TODO: trigger an exception on any outstanding async calls.
-
return MsQuicStatusCodes.InternalError;
}
}
@@ -709,7 +737,7 @@ private void Dispose(bool disposing)
return;
}
- if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Stream disposing {disposing}");
+ if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Connection disposing {disposing}");
// If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown.
if (_state.Handle != null)
diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
index 391d26a93ddd3..b5daea6e0d0ed 100644
--- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
+++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs
@@ -12,6 +12,7 @@
using System.Threading.Channels;
using System.Threading.Tasks;
using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods;
+using System.Security.Authentication;
namespace System.Net.Quic.Implementations.MsQuic
{
@@ -31,21 +32,39 @@ private sealed class State
public SafeMsQuicListenerHandle Handle = null!;
public string TraceId = null!; // set in ctor.
- public readonly SafeMsQuicConfigurationHandle ConnectionConfiguration;
+ public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
public readonly Channel AcceptConnectionQueue;
- public bool RemoteCertificateRequired;
- public X509RevocationMode RevocationMode = X509RevocationMode.Offline;
- public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback;
+ public QuicOptions ConnectionOptions = new QuicOptions();
+ public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();
public State(QuicListenerOptions options)
{
- ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options);
+ ConnectionOptions.IdleTimeout = options.IdleTimeout;
+ ConnectionOptions.MaxBidirectionalStreams = options.MaxBidirectionalStreams;
+ ConnectionOptions.MaxUnidirectionalStreams = options.MaxUnidirectionalStreams;
+
+ bool delayConfiguration = false;
+
if (options.ServerAuthenticationOptions != null)
{
- RemoteCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
- RevocationMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
- RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
+ AuthenticationOptions.ClientCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
+ AuthenticationOptions.CertificateRevocationCheckMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
+ AuthenticationOptions.RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
+ AuthenticationOptions.ServerCertificateSelectionCallback = options.ServerAuthenticationOptions.ServerCertificateSelectionCallback;
+ AuthenticationOptions.ApplicationProtocols = options.ServerAuthenticationOptions.ApplicationProtocols;
+
+ if (options.ServerAuthenticationOptions.ServerCertificate == null && options.ServerAuthenticationOptions.ServerCertificateContext == null &&
+ options.ServerAuthenticationOptions.ServerCertificateSelectionCallback != null)
+ {
+ // We don't have any certificate but we have selection callback so we need to wait for SNI.
+ delayConfiguration = true;
+ }
+ }
+
+ if (!delayConfiguration)
+ {
+ ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions);
}
AcceptConnectionQueue = Channel.CreateBounded(new BoundedChannelOptions(options.ListenBacklog)
@@ -211,6 +230,7 @@ private static unsafe uint NativeCallbackHandler(
var state = (State)gcHandle.Target;
SafeMsQuicConnectionHandle? connectionHandle = null;
+ MsQuicConnection? msQuicConnection = null;
try
{
@@ -218,24 +238,53 @@ private static unsafe uint NativeCallbackHandler(
IPEndPoint localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.LocalAddress);
IPEndPoint remoteEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.RemoteAddress);
+ string targetHost = string.Empty; // compat with SslStream
+ if (connectionInfo.ServerNameLength > 0 && connectionInfo.ServerName != IntPtr.Zero)
+ {
+ // TBD We should figure out what to do with international names.
+ targetHost = Marshal.PtrToStringAnsi(connectionInfo.ServerName, connectionInfo.ServerNameLength);
+ }
- connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection);
+ SafeMsQuicConfigurationHandle? connectionConfiguration = state.ConnectionConfiguration;
- uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, state.ConnectionConfiguration);
- QuicExceptionHelpers.ThrowIfFailed(status, "ConnectionSetConfiguration failed.");
+ if (connectionConfiguration == null)
+ {
+ Debug.Assert(state.AuthenticationOptions.ServerCertificateSelectionCallback != null);
+ try
+ {
+ // ServerCertificateSelectionCallback is synchronous. We will call it as needed when building configuration
+ connectionConfiguration = SafeMsQuicConfigurationHandle.Create(state.ConnectionOptions, state.AuthenticationOptions, targetHost);
+ }
+ catch (Exception ex)
+ {
+ if (NetEventSource.Log.IsEnabled())
+ {
+ NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during creating configuration in connection callback: {ex}");
+ }
+ }
+
+ if (connectionConfiguration == null)
+ {
+ // We don't have safe handle yet so MsQuic will cleanup new connection.
+ return MsQuicStatusCodes.InternalError;
+ }
+ }
- var msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.RemoteCertificateRequired, state.RevocationMode, state.RemoteCertificateValidationCallback);
- msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
+ connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection);
- if (!state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
+ uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
+ if (MsQuicStatusHelper.SuccessfulStatusCode(status))
{
- // This handle will be cleaned up by MsQuic.
- connectionHandle.SetHandleAsInvalid();
- msQuicConnection.Dispose();
- return MsQuicStatusCodes.InternalError;
+ msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
+ msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
+
+ if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
+ {
+ return MsQuicStatusCodes.Success;
+ }
}
- return MsQuicStatusCodes.Success;
+ // If we fall-through here something wrong happened.
}
catch (Exception ex)
{
@@ -243,14 +292,12 @@ private static unsafe uint NativeCallbackHandler(
{
NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}");
}
-
- Debug.Fail($"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}");
-
- // This handle will be cleaned up by MsQuic by returning InternalError.
- connectionHandle?.SetHandleAsInvalid();
- state.AcceptConnectionQueue.Writer.TryComplete(ex);
- return MsQuicStatusCodes.InternalError;
}
+
+ // This handle will be cleaned up by MsQuic by returning InternalError.
+ connectionHandle?.SetHandleAsInvalid();
+ msQuicConnection?.Dispose();
+ return MsQuicStatusCodes.InternalError;
}
private void ThrowIfDisposed()
diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
index 83846baf6cd6e..e4486c92b8a26 100644
--- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
+++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
@@ -6,8 +6,11 @@
using System.Diagnostics;
using System.Linq;
using System.Net.Security;
+using System.Net.Sockets;
+using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
+using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
@@ -115,6 +118,126 @@ public async Task ConnectWithCertificateChain()
await clientTask;
}
+ [Fact]
+ public async Task CertificateCallbackThrowPropagates()
+ {
+ using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout);
+ X509Certificate? receivedCertificate = null;
+
+ var quicOptions = new QuicListenerOptions();
+ quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
+ quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+
+ using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
+
+ QuicClientConnectionOptions options = new QuicClientConnectionOptions()
+ {
+ RemoteEndPoint = listener.ListenEndPoint,
+ ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
+ };
+
+ options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ {
+ receivedCertificate = cert;
+ throw new ArithmeticException("foobar");
+ };
+
+ options.ClientAuthenticationOptions.TargetHost = "foobar1";
+
+ QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+
+ Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+ await Assert.ThrowsAsync(() => clientConnection.ConnectAsync(cts.Token).AsTask());
+ QuicConnection serverConnection = await serverTask;
+
+ Assert.Equal(quicOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate);
+
+ clientConnection.Dispose();
+ serverConnection.Dispose();
+ }
+
+ [Fact]
+ public async Task ConnectWithCertificateCallback()
+ {
+ X509Certificate2 c1 = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
+ X509Certificate2 c2 = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate(); // This 'wrong' certificate but should be sufficient
+ X509Certificate2 expectedCertificate = c1;
+
+ using CancellationTokenSource cts = new CancellationTokenSource();
+ cts.CancelAfter(PassingTestTimeout);
+ string? receivedHostName = null;
+ X509Certificate? receivedCertificate = null;
+
+ var quicOptions = new QuicListenerOptions();
+ quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0);
+ quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
+ quicOptions.ServerAuthenticationOptions.ServerCertificate = null;
+ quicOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) =>
+ {
+ receivedHostName = hostName;
+ if (hostName == "foobar1")
+ {
+ return c1;
+ }
+ else if (hostName == "foobar2")
+ {
+ return c2;
+ }
+
+ return null;
+ };
+
+ using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions);
+
+ QuicClientConnectionOptions options = new QuicClientConnectionOptions()
+ {
+ RemoteEndPoint = listener.ListenEndPoint,
+ ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
+ };
+
+ options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
+ {
+ receivedCertificate = cert;
+ return true;
+ };
+
+ options.ClientAuthenticationOptions.TargetHost = "foobar1";
+
+ QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+
+ Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+ await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
+ QuicConnection serverConnection = serverTask.Result;
+
+ Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+ Assert.Equal(c1, receivedCertificate);
+ clientConnection.Dispose();
+ serverConnection.Dispose();
+
+ // This should fail when callback return null.
+ options.ClientAuthenticationOptions.TargetHost = "foobar3";
+ clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+ Task clientTask = clientConnection.ConnectAsync(cts.Token).AsTask();
+
+ await Assert.ThrowsAsync(() => clientTask);
+ Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+ clientConnection.Dispose();
+
+ // Do this last to make sure Listener is still functional.
+ options.ClientAuthenticationOptions.TargetHost = "foobar2";
+ expectedCertificate = c2;
+
+ clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
+ serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask();
+ await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds);
+ serverConnection = serverTask.Result;
+
+ Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName);
+ Assert.Equal(c2, receivedCertificate);
+ clientConnection.Dispose();
+ serverConnection.Dispose();
+ }
+
[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
[ActiveIssue("https://github.com/microsoft/msquic/pull/1728")]
diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs
index e76591f4e0436..9a6e43a863e5b 100644
--- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs
+++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs
@@ -28,6 +28,9 @@ public abstract class QuicTestBase
public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
public X509Certificate2 ClientCertificate = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate();
+ public const int PassingTestTimeoutMilliseconds = 4 * 60 * 1000;
+ public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds);
+
public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
{
Assert.Equal(ServerCertificate.GetCertHash(), certificate?.GetCertHash());
@@ -48,7 +51,8 @@ public SslClientAuthenticationOptions GetSslClientAuthenticationOptions()
return new SslClientAuthenticationOptions()
{
ApplicationProtocols = new List() { ApplicationProtocol },
- RemoteCertificateValidationCallback = RemoteCertificateValidationCallback
+ RemoteCertificateValidationCallback = RemoteCertificateValidationCallback,
+ TargetHost = "localhost"
};
}