diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx index a29352a0578f5..51c31a9db1ca5 100644 --- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx @@ -150,5 +150,14 @@ Writing is not allowed on stream. + + '{0}' is not supported by System.Net.Quic. + + + The remote certificate was rejected by the provided RemoteCertificateValidationCallback. + + + The remote certificate is invalid because of errors in the certificate chain: {0} + diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs index df48e0db377ae..aa4c58929f92c 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs @@ -17,7 +17,7 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal internal sealed class SafeMsQuicConfigurationHandle : SafeHandle { private static readonly FieldInfo _contextCertificate = typeof(SslStreamCertificateContext).GetField("Certificate", BindingFlags.NonPublic | BindingFlags.Instance)!; - private static readonly FieldInfo _contextChain= typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!; + private static readonly FieldInfo _contextChain = typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!; public override bool IsInvalid => handle == IntPtr.Zero; @@ -33,7 +33,7 @@ protected override bool ReleaseHandle() } // TODO: consider moving the static code from here to keep all the handle classes small and simple. - public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options) + public static SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options) { X509Certificate? certificate = null; if (options.ClientAuthenticationOptions?.ClientCertificates != null) @@ -56,15 +56,35 @@ public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOp return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: certificate, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols); } - public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions options) + public static SafeMsQuicConfigurationHandle Create(QuicOptions options, SslServerAuthenticationOptions? serverAuthenticationOptions, string? targetHost = null) { QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE; - if (options.ServerAuthenticationOptions != null && options.ServerAuthenticationOptions.ClientCertificateRequired) + X509Certificate? certificate = serverAuthenticationOptions?.ServerCertificate; + + if (serverAuthenticationOptions != null) { - flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION; + if (serverAuthenticationOptions.CipherSuitesPolicy != null) + { + throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.CipherSuitesPolicy))); + } + + if (serverAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption) + { + throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.EncryptionPolicy))); + } + + if (serverAuthenticationOptions.ClientCertificateRequired) + { + flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION; + } + + if (certificate == null && serverAuthenticationOptions?.ServerCertificateSelectionCallback != null && targetHost != null) + { + certificate = serverAuthenticationOptions.ServerCertificateSelectionCallback(options, targetHost); + } } - return Create(options, flags, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols); + return Create(options, flags, certificate, serverAuthenticationOptions?.ServerCertificateContext, serverAuthenticationOptions?.ApplicationProtocols); } // TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer. diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 6fb6588958513..f3dcbf2ca706f 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -7,6 +7,7 @@ using System.Net.Sockets; using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; +using System.Security.Authentication; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -35,10 +36,6 @@ internal sealed class MsQuicConnection : QuicConnectionProvider private IPEndPoint? _localEndPoint; private readonly EndPoint _remoteEndPoint; private SslApplicationProtocol _negotiatedAlpnProtocol; - private bool _isServer; - private bool _remoteCertificateRequired; - private X509RevocationMode _revocationMode = X509RevocationMode.Offline; - private RemoteCertificateValidationCallback? _remoteCertificateValidationCallback; internal sealed class State { @@ -50,8 +47,8 @@ internal sealed class State // These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown). public MsQuicConnection? Connection; - // TODO: only allocate these when there is an outstanding connect/shutdown. - public readonly TaskCompletionSource ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public TaskCompletionSource? ConnectTcs; + // TODO: only allocate these when there is an outstanding shutdown. public readonly TaskCompletionSource ShutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); // Note that there's no such thing as resetable TCS, so we cannot reuse the same instance after we've set the result. @@ -65,6 +62,13 @@ internal sealed class State public int StreamCount; private bool _closing; + // Certificate validation properties + public bool RemoteCertificateRequired; + public X509RevocationMode RevocationMode = X509RevocationMode.Offline; + public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback; + public bool IsServer; + public string? TargetHost; + // Queue for accepted streams. // Backlog limit is managed by MsQuic so it can be unbounded here. public readonly Channel AcceptQueue = Channel.CreateUnbounded(new UnboundedChannelOptions() @@ -131,26 +135,17 @@ public void SetClosing() internal string TraceId() => _state.TraceId; // constructor for inbound connections - public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null) + public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null) { _state.Handle = handle; _state.StateGCHandle = GCHandle.Alloc(_state); _state.Connected = true; - _isServer = true; + _state.RemoteCertificateRequired = remoteCertificateRequired; + _state.RevocationMode = revocationMode; + _state.RemoteCertificateValidationCallback = remoteCertificateValidationCallback; + _state.IsServer = true; _localEndPoint = localEndPoint; _remoteEndPoint = remoteEndPoint; - _remoteCertificateRequired = remoteCertificateRequired; - _revocationMode = revocationMode; - _remoteCertificateValidationCallback = remoteCertificateValidationCallback; - - if (_remoteCertificateRequired) - { - // We need to link connection for the validation callback. - // We need to be able to find the connection in HandleEventPeerCertificateReceived - // and dispatch it as sender to validation callback. - // After that Connection will be set back to null. - _state.Connection = this; - } try { @@ -177,12 +172,12 @@ public MsQuicConnection(QuicClientConnectionOptions options) { _remoteEndPoint = options.RemoteEndPoint!; _configuration = SafeMsQuicConfigurationHandle.Create(options); - _isServer = false; - _remoteCertificateRequired = true; + _state.RemoteCertificateRequired = true; if (options.ClientAuthenticationOptions != null) { - _revocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode; - _remoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback; + _state.RevocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode; + _state.RemoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback; + _state.TargetHost = options.ClientAuthenticationOptions.TargetHost; } _state.StateGCHandle = GCHandle.Alloc(_state); @@ -231,7 +226,7 @@ private static uint HandleEventConnected(State state, ref ConnectionEvent connec state.Connection = null; state.Connected = true; - state.ConnectTcs.SetResult(MsQuicStatusCodes.Success); + state.ConnectTcs!.SetResult(MsQuicStatusCodes.Success); } return MsQuicStatusCodes.Success; @@ -239,14 +234,15 @@ private static uint HandleEventConnected(State state, ref ConnectionEvent connec private static uint HandleEventShutdownInitiatedByTransport(State state, ref ConnectionEvent connectionEvent) { - if (!state.Connected) + if (!state.Connected && state.ConnectTcs != null) { Debug.Assert(state.Connection != null); state.Connection = null; uint hresult = connectionEvent.Data.ShutdownInitiatedByTransport.Status; Exception ex = QuicExceptionHelpers.CreateExceptionForHResult(hresult, "Connection has been shutdown by transport."); - state.ConnectTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex)); + state.ConnectTcs!.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(ex)); + state.ConnectTcs = null; } state.AcceptQueue.Writer.TryComplete(); @@ -345,17 +341,6 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti X509Certificate2? certificate = null; X509Certificate2Collection? additionalCertificates = null; - MsQuicConnection? connection = state.Connection; - if (connection == null) - { - return MsQuicStatusCodes.InvalidState; - } - - if (connection._isServer) - { - state.Connection = null; - } - try { if (connectionEvent.Data.PeerCertificateReceived.PlatformCertificateHandle != IntPtr.Zero) @@ -386,15 +371,15 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti if (certificate == null) { - if (NetEventSource.Log.IsEnabled() && connection._remoteCertificateRequired) NetEventSource.Error(state, $"{state.TraceId} Remote certificate required, but no remote certificate received"); + if (NetEventSource.Log.IsEnabled() && state.RemoteCertificateRequired) NetEventSource.Error(state, $"{state.TraceId} Remote certificate required, but no remote certificate received"); sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable; } else { chain = new X509Chain(); - chain.ChainPolicy.RevocationMode = connection._revocationMode; + chain.ChainPolicy.RevocationMode = state.RevocationMode; chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot; - chain.ChainPolicy.ApplicationPolicy.Add(connection._isServer ? s_clientAuthOid : s_serverAuthOid); + chain.ChainPolicy.ApplicationPolicy.Add(state.IsServer ? s_clientAuthOid : s_serverAuthOid); if (additionalCertificates != null && additionalCertificates.Count > 1) { @@ -407,34 +392,46 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti } } - if (!connection._remoteCertificateRequired) + if (!state.RemoteCertificateRequired) { sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNotAvailable; } - if (connection._remoteCertificateValidationCallback != null) + if (state.RemoteCertificateValidationCallback != null) { - bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors); + bool success = state.RemoteCertificateValidationCallback(state, certificate, chain, sslPolicyErrors); // Unset the callback to prevent multiple invocations of the callback per a single connection. // Return the same value as the custom callback just did. - connection._remoteCertificateValidationCallback = (_, _, _, _) => success; + state.RemoteCertificateValidationCallback = (_, _, _, _) => success; if (!success && NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.TraceId} Remote certificate rejected by verification callback"); - return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure; + + if (!success) + { + throw new AuthenticationException(SR.net_quic_cert_custom_validation); + } + + return MsQuicStatusCodes.Success; } if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.TraceId} Certificate validation for '${certificate?.Subject}' finished with ${sslPolicyErrors}"); - return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure; +// return (sslPolicyErrors == SslPolicyErrors.None) ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure; + + if (sslPolicyErrors != SslPolicyErrors.None) + { + throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors)); + } + + return MsQuicStatusCodes.Success; } catch (Exception ex) { if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.TraceId} Certificate validation failed ${ex.Message}"); + throw; } - - return MsQuicStatusCodes.InternalError; } internal override async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) @@ -544,13 +541,6 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d throw new Exception($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener."); } - (string address, int port) = _remoteEndPoint switch - { - DnsEndPoint dnsEp => (dnsEp.Host, dnsEp.Port), - IPEndPoint ipEp => (ipEp.Address.ToString(), ipEp.Port), - _ => throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.") - }; - QUIC_ADDRESS_FAMILY af = _remoteEndPoint.AddressFamily switch { AddressFamily.Unspecified => QUIC_ADDRESS_FAMILY.UNSPEC, @@ -562,13 +552,43 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d Debug.Assert(_state.StateGCHandle.IsAllocated); _state.Connection = this; + uint status; + string targetHost; + int port; + + if (_remoteEndPoint is IPEndPoint) + { + SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet((IPEndPoint)_remoteEndPoint); + unsafe + { + status = MsQuicApi.Api.SetParamDelegate(_state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.REMOTE_ADDRESS, (uint)sizeof(SOCKADDR_INET), (byte*)&address); + QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer."); + } + + targetHost = _state.TargetHost ?? ((IPEndPoint)_remoteEndPoint).Address.ToString(); + port = ((IPEndPoint)_remoteEndPoint).Port; + + } + else if (_remoteEndPoint is DnsEndPoint) + { + // We don't have way how to set separate SNI and name for connection at this moment. + targetHost = ((DnsEndPoint)_remoteEndPoint).Host; + port = ((DnsEndPoint)_remoteEndPoint).Port; + } + else + { + throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'."); + } + + _state.ConnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + try { - uint status = MsQuicApi.Api.ConnectionStartDelegate( + status = MsQuicApi.Api.ConnectionStartDelegate( _state.Handle, _configuration, af, - address, + targetHost, (ushort)port); QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer."); @@ -665,10 +685,18 @@ private static uint NativeCallbackHandler( NetEventSource.Error(state, $"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}"); } - Debug.Fail($"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}"); + if (state.ConnectTcs != null) + { + state.ConnectTcs.SetException(ex); + state.ConnectTcs = null; + state.Connection = null; + } + else + { + Debug.Fail($"{state.TraceId} Exception occurred during handling {connectionEvent.Type} connection callback: {ex}"); + } // TODO: trigger an exception on any outstanding async calls. - return MsQuicStatusCodes.InternalError; } } @@ -709,7 +737,7 @@ private void Dispose(bool disposing) return; } - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Stream disposing {disposing}"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Connection disposing {disposing}"); // If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown. if (_state.Handle != null) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs index 391d26a93ddd3..b5daea6e0d0ed 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicListener.cs @@ -12,6 +12,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods; +using System.Security.Authentication; namespace System.Net.Quic.Implementations.MsQuic { @@ -31,21 +32,39 @@ private sealed class State public SafeMsQuicListenerHandle Handle = null!; public string TraceId = null!; // set in ctor. - public readonly SafeMsQuicConfigurationHandle ConnectionConfiguration; + public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration; public readonly Channel AcceptConnectionQueue; - public bool RemoteCertificateRequired; - public X509RevocationMode RevocationMode = X509RevocationMode.Offline; - public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback; + public QuicOptions ConnectionOptions = new QuicOptions(); + public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions(); public State(QuicListenerOptions options) { - ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options); + ConnectionOptions.IdleTimeout = options.IdleTimeout; + ConnectionOptions.MaxBidirectionalStreams = options.MaxBidirectionalStreams; + ConnectionOptions.MaxUnidirectionalStreams = options.MaxUnidirectionalStreams; + + bool delayConfiguration = false; + if (options.ServerAuthenticationOptions != null) { - RemoteCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired; - RevocationMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode; - RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback; + AuthenticationOptions.ClientCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired; + AuthenticationOptions.CertificateRevocationCheckMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode; + AuthenticationOptions.RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback; + AuthenticationOptions.ServerCertificateSelectionCallback = options.ServerAuthenticationOptions.ServerCertificateSelectionCallback; + AuthenticationOptions.ApplicationProtocols = options.ServerAuthenticationOptions.ApplicationProtocols; + + if (options.ServerAuthenticationOptions.ServerCertificate == null && options.ServerAuthenticationOptions.ServerCertificateContext == null && + options.ServerAuthenticationOptions.ServerCertificateSelectionCallback != null) + { + // We don't have any certificate but we have selection callback so we need to wait for SNI. + delayConfiguration = true; + } + } + + if (!delayConfiguration) + { + ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions); } AcceptConnectionQueue = Channel.CreateBounded(new BoundedChannelOptions(options.ListenBacklog) @@ -211,6 +230,7 @@ private static unsafe uint NativeCallbackHandler( var state = (State)gcHandle.Target; SafeMsQuicConnectionHandle? connectionHandle = null; + MsQuicConnection? msQuicConnection = null; try { @@ -218,24 +238,53 @@ private static unsafe uint NativeCallbackHandler( IPEndPoint localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.LocalAddress); IPEndPoint remoteEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.RemoteAddress); + string targetHost = string.Empty; // compat with SslStream + if (connectionInfo.ServerNameLength > 0 && connectionInfo.ServerName != IntPtr.Zero) + { + // TBD We should figure out what to do with international names. + targetHost = Marshal.PtrToStringAnsi(connectionInfo.ServerName, connectionInfo.ServerNameLength); + } - connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection); + SafeMsQuicConfigurationHandle? connectionConfiguration = state.ConnectionConfiguration; - uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, state.ConnectionConfiguration); - QuicExceptionHelpers.ThrowIfFailed(status, "ConnectionSetConfiguration failed."); + if (connectionConfiguration == null) + { + Debug.Assert(state.AuthenticationOptions.ServerCertificateSelectionCallback != null); + try + { + // ServerCertificateSelectionCallback is synchronous. We will call it as needed when building configuration + connectionConfiguration = SafeMsQuicConfigurationHandle.Create(state.ConnectionOptions, state.AuthenticationOptions, targetHost); + } + catch (Exception ex) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during creating configuration in connection callback: {ex}"); + } + } + + if (connectionConfiguration == null) + { + // We don't have safe handle yet so MsQuic will cleanup new connection. + return MsQuicStatusCodes.InternalError; + } + } - var msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.RemoteCertificateRequired, state.RevocationMode, state.RemoteCertificateValidationCallback); - msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength); + connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection); - if (!state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection)) + uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration); + if (MsQuicStatusHelper.SuccessfulStatusCode(status)) { - // This handle will be cleaned up by MsQuic. - connectionHandle.SetHandleAsInvalid(); - msQuicConnection.Dispose(); - return MsQuicStatusCodes.InternalError; + msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback); + msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength); + + if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection)) + { + return MsQuicStatusCodes.Success; + } } - return MsQuicStatusCodes.Success; + // If we fall-through here something wrong happened. } catch (Exception ex) { @@ -243,14 +292,12 @@ private static unsafe uint NativeCallbackHandler( { NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}"); } - - Debug.Fail($"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}"); - - // This handle will be cleaned up by MsQuic by returning InternalError. - connectionHandle?.SetHandleAsInvalid(); - state.AcceptConnectionQueue.Writer.TryComplete(ex); - return MsQuicStatusCodes.InternalError; } + + // This handle will be cleaned up by MsQuic by returning InternalError. + connectionHandle?.SetHandleAsInvalid(); + msQuicConnection?.Dispose(); + return MsQuicStatusCodes.InternalError; } private void ThrowIfDisposed() diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 83846baf6cd6e..e4486c92b8a26 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -6,8 +6,11 @@ using System.Diagnostics; using System.Linq; using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -115,6 +118,126 @@ public async Task ConnectWithCertificateChain() await clientTask; } + [Fact] + public async Task CertificateCallbackThrowPropagates() + { + using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout); + X509Certificate? receivedCertificate = null; + + var quicOptions = new QuicListenerOptions(); + quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); + quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); + + QuicClientConnectionOptions options = new QuicClientConnectionOptions() + { + RemoteEndPoint = listener.ListenEndPoint, + ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), + }; + + options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + { + receivedCertificate = cert; + throw new ArithmeticException("foobar"); + }; + + options.ClientAuthenticationOptions.TargetHost = "foobar1"; + + QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + + Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); + await Assert.ThrowsAsync(() => clientConnection.ConnectAsync(cts.Token).AsTask()); + QuicConnection serverConnection = await serverTask; + + Assert.Equal(quicOptions.ServerAuthenticationOptions.ServerCertificate, receivedCertificate); + + clientConnection.Dispose(); + serverConnection.Dispose(); + } + + [Fact] + public async Task ConnectWithCertificateCallback() + { + X509Certificate2 c1 = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate(); + X509Certificate2 c2 = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate(); // This 'wrong' certificate but should be sufficient + X509Certificate2 expectedCertificate = c1; + + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.CancelAfter(PassingTestTimeout); + string? receivedHostName = null; + X509Certificate? receivedCertificate = null; + + var quicOptions = new QuicListenerOptions(); + quicOptions.ListenEndPoint = new IPEndPoint( Socket.OSSupportsIPv6 ? IPAddress.IPv6Loopback : IPAddress.Loopback, 0); + quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + quicOptions.ServerAuthenticationOptions.ServerCertificate = null; + quicOptions.ServerAuthenticationOptions.ServerCertificateSelectionCallback = (sender, hostName) => + { + receivedHostName = hostName; + if (hostName == "foobar1") + { + return c1; + } + else if (hostName == "foobar2") + { + return c2; + } + + return null; + }; + + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); + + QuicClientConnectionOptions options = new QuicClientConnectionOptions() + { + RemoteEndPoint = listener.ListenEndPoint, + ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), + }; + + options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + { + receivedCertificate = cert; + return true; + }; + + options.ClientAuthenticationOptions.TargetHost = "foobar1"; + + QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + + Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); + await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); + QuicConnection serverConnection = serverTask.Result; + + Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName); + Assert.Equal(c1, receivedCertificate); + clientConnection.Dispose(); + serverConnection.Dispose(); + + // This should fail when callback return null. + options.ClientAuthenticationOptions.TargetHost = "foobar3"; + clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + Task clientTask = clientConnection.ConnectAsync(cts.Token).AsTask(); + + await Assert.ThrowsAsync(() => clientTask); + Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName); + clientConnection.Dispose(); + + // Do this last to make sure Listener is still functional. + options.ClientAuthenticationOptions.TargetHost = "foobar2"; + expectedCertificate = c2; + + clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); + await new Task[] { clientConnection.ConnectAsync().AsTask(), serverTask}.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); + serverConnection = serverTask.Result; + + Assert.Equal(options.ClientAuthenticationOptions.TargetHost, receivedHostName); + Assert.Equal(c2, receivedCertificate); + clientConnection.Dispose(); + serverConnection.Dispose(); + } + [Fact] [PlatformSpecific(TestPlatforms.Windows)] [ActiveIssue("https://github.com/microsoft/msquic/pull/1728")] diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index e76591f4e0436..9a6e43a863e5b 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -28,6 +28,9 @@ public abstract class QuicTestBase public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate(); public X509Certificate2 ClientCertificate = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate(); + public const int PassingTestTimeoutMilliseconds = 4 * 60 * 1000; + public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds); + public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) { Assert.Equal(ServerCertificate.GetCertHash(), certificate?.GetCertHash()); @@ -48,7 +51,8 @@ public SslClientAuthenticationOptions GetSslClientAuthenticationOptions() return new SslClientAuthenticationOptions() { ApplicationProtocols = new List() { ApplicationProtocol }, - RemoteCertificateValidationCallback = RemoteCertificateValidationCallback + RemoteCertificateValidationCallback = RemoteCertificateValidationCallback, + TargetHost = "localhost" }; }