Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix SNI handling in quic #55468

Merged
merged 8 commits into from
Jul 23, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal
internal sealed class SafeMsQuicConfigurationHandle : SafeHandle
{
private static readonly FieldInfo _contextCertificate = typeof(SslStreamCertificateContext).GetField("Certificate", BindingFlags.NonPublic | BindingFlags.Instance)!;
private static readonly FieldInfo _contextChain= typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!;
private static readonly FieldInfo _contextChain = typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!;

public override bool IsInvalid => handle == IntPtr.Zero;

Expand All @@ -33,7 +33,7 @@ protected override bool ReleaseHandle()
}

// TODO: consider moving the static code from here to keep all the handle classes small and simple.
public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
public static SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
{
X509Certificate? certificate = null;
if (options.ClientAuthenticationOptions?.ClientCertificates != null)
Expand All @@ -56,7 +56,7 @@ public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOp
return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: certificate, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols);
}

public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions options)
public static SafeMsQuicConfigurationHandle Create(QuicListenerOptions options)
{
QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE;
if (options.ServerAuthenticationOptions != null && options.ServerAuthenticationOptions.ClientCertificateRequired)
Expand All @@ -67,6 +67,23 @@ public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions op
return Create(options, flags, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols);
}

public static SafeMsQuicConfigurationHandle Create(QuicOptions options, SslServerAuthenticationOptions? serverAuthenticationOptions, string targetHost)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE;
X509Certificate? certificate = serverAuthenticationOptions?.ServerCertificate;
if (serverAuthenticationOptions != null && serverAuthenticationOptions.ClientCertificateRequired)
{
flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION;
}

if (certificate == null && serverAuthenticationOptions?.ServerCertificateSelectionCallback != null)
{
certificate = serverAuthenticationOptions.ServerCertificateSelectionCallback(options, targetHost);
}

return Create(options, flags, certificate, serverAuthenticationOptions?.ServerCertificateContext, serverAuthenticationOptions?.ApplicationProtocols);
}

// TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer.
// Consider moving bigger logic like this outside of constructor call chains.
private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, SslStreamCertificateContext? certificateContext, List<SslApplicationProtocol>? alpnProtocols)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ internal sealed class MsQuicConnection : QuicConnectionProvider
private bool _remoteCertificateRequired;
private X509RevocationMode _revocationMode = X509RevocationMode.Offline;
private RemoteCertificateValidationCallback? _remoteCertificateValidationCallback;
private string? _targetHost;

internal sealed class State
{
Expand Down Expand Up @@ -131,7 +132,7 @@ public void SetClosing()
internal string TraceId() => _state.TraceId;

// constructor for inbound connections
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null)
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null, ServerCertificateSelectionCallback? serverCertificateSelectionCallback = null)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
Expand Down Expand Up @@ -183,6 +184,7 @@ public MsQuicConnection(QuicClientConnectionOptions options)
{
_revocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode;
_remoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback;
_targetHost = options.ClientAuthenticationOptions.TargetHost;
}

_state.StateGCHandle = GCHandle.Alloc(_state);
Expand Down Expand Up @@ -547,10 +549,11 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
throw new Exception($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener.");
}

(string address, int port) = _remoteEndPoint switch
string? targetHost = _targetHost;
wfurt marked this conversation as resolved.
Show resolved Hide resolved
int port = _remoteEndPoint switch
{
DnsEndPoint dnsEp => (dnsEp.Host, dnsEp.Port),
IPEndPoint ipEp => (ipEp.Address.ToString(), ipEp.Port),
DnsEndPoint dnsEp => dnsEp.Port,
IPEndPoint ipEp => ipEp.Port,
_ => throw new Exception($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'.")
};
wfurt marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -565,13 +568,34 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
Debug.Assert(_state.StateGCHandle.IsAllocated);

_state.Connection = this;
uint status;
if (_remoteEndPoint is IPEndPoint)
{
SOCKADDR_INET address = MsQuicAddressHelpers.IPEndPointToINet((IPEndPoint)_remoteEndPoint);
unsafe
{
status = MsQuicApi.Api.SetParamDelegate(_state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.REMOTE_ADDRESS, (uint)sizeof(SOCKADDR_INET), (byte*)&address);
QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
}

if (targetHost == null)
{
targetHost = ((IPEndPoint)_remoteEndPoint).Address.ToString();
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
// We don't have way how to set separate SNI and name for connection at this moment.
targetHost = ((DnsEndPoint)_remoteEndPoint).Host;
}

try
{
uint status = MsQuicApi.Api.ConnectionStartDelegate(
status = MsQuicApi.Api.ConnectionStartDelegate(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
status = MsQuicApi.Api.ConnectionStartDelegate(
uint status = MsQuicApi.Api.ConnectionStartDelegate(

It shouldn't clash if you put the declaration above next to the SetParamDelegate as well. They should be within different scopes.

_state.Handle,
_configuration,
af,
address,
targetHost,
(ushort)port);

QuicExceptionHelpers.ThrowIfFailed(status, "Failed to connect to peer.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Threading.Channels;
using System.Threading.Tasks;
using static System.Net.Quic.Implementations.MsQuic.Internal.MsQuicNativeMethods;
using System.Security.Authentication;

namespace System.Net.Quic.Implementations.MsQuic
{
Expand All @@ -31,21 +32,39 @@ private sealed class State
public SafeMsQuicListenerHandle Handle = null!;
public string TraceId = null!; // set in ctor.

public readonly SafeMsQuicConfigurationHandle ConnectionConfiguration;
public readonly SafeMsQuicConfigurationHandle? ConnectionConfiguration;
public readonly Channel<MsQuicConnection> AcceptConnectionQueue;

public bool RemoteCertificateRequired;
public X509RevocationMode RevocationMode = X509RevocationMode.Offline;
public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback;
public QuicOptions ConnectionOptions = new QuicOptions();
public SslServerAuthenticationOptions AuthenticationOptions = new SslServerAuthenticationOptions();

public State(QuicListenerOptions options)
{
ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options);
ConnectionOptions.IdleTimeout = options.IdleTimeout;
ConnectionOptions.MaxBidirectionalStreams = options.MaxBidirectionalStreams;
ConnectionOptions.MaxUnidirectionalStreams = options.MaxUnidirectionalStreams;

bool delayConfiguration = false;

if (options.ServerAuthenticationOptions != null)
{
RemoteCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
RevocationMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
AuthenticationOptions.ClientCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
AuthenticationOptions.CertificateRevocationCheckMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
AuthenticationOptions.RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
AuthenticationOptions.ServerCertificateSelectionCallback = options.ServerAuthenticationOptions.ServerCertificateSelectionCallback;
AuthenticationOptions.ApplicationProtocols = options.ServerAuthenticationOptions.ApplicationProtocols;

if (options.ServerAuthenticationOptions.ServerCertificate == null && options.ServerAuthenticationOptions.ServerCertificateContext == null &&
options.ServerAuthenticationOptions.ServerCertificateSelectionCallback != null)
{
// We don't have any certificate but we have selection callback so we need to wait for SNI.
delayConfiguration = true;
}
}

if (!delayConfiguration)
{
ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options);
}

AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
Expand Down Expand Up @@ -211,46 +230,74 @@ private static unsafe uint NativeCallbackHandler(
var state = (State)gcHandle.Target;

SafeMsQuicConnectionHandle? connectionHandle = null;
MsQuicConnection? msQuicConnection = null;

try
{
ref NewConnectionInfo connectionInfo = ref *evt.Data.NewConnection.Info;

IPEndPoint localEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.LocalAddress);
IPEndPoint remoteEndPoint = MsQuicAddressHelpers.INetToIPEndPoint(ref *(SOCKADDR_INET*)connectionInfo.RemoteAddress);
string targetHost = string.Empty; // compat with SslStream
if (connectionInfo.ServerNameLength > 0 && connectionInfo.ServerName != IntPtr.Zero)
{
// TBD We should figure out what to do with international names.
wfurt marked this conversation as resolved.
Show resolved Hide resolved
targetHost = Marshal.PtrToStringAnsi(connectionInfo.ServerName, connectionInfo.ServerNameLength);
}

connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection);
SafeMsQuicConfigurationHandle? connectionConfiguration = state.ConnectionConfiguration;

uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, state.ConnectionConfiguration);
QuicExceptionHelpers.ThrowIfFailed(status, "ConnectionSetConfiguration failed.");
if (connectionConfiguration == null)
{
Debug.Assert(state.AuthenticationOptions.ServerCertificateSelectionCallback != null);
try
{
// ServerCertificateSelectionCallback is synchronous. We will call it as needed when building configuration
connectionConfiguration = SafeMsQuicConfigurationHandle.Create(state.ConnectionOptions, state.AuthenticationOptions, targetHost);
}
catch (Exception ex)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during creating configuration in connection callback: {ex}");
}
}

if (connectionConfiguration == null)
{
// We don't have safe handle yet so MsQuic will cleanup new connection.
return MsQuicStatusCodes.InternalError;
}
}

var msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.RemoteCertificateRequired, state.RevocationMode, state.RemoteCertificateValidationCallback);
msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);
connectionHandle = new SafeMsQuicConnectionHandle(evt.Data.NewConnection.Connection);

if (!state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, connectionConfiguration);
if (MsQuicStatusHelper.SuccessfulStatusCode(status))
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
// This handle will be cleaned up by MsQuic.
connectionHandle.SetHandleAsInvalid();
msQuicConnection.Dispose();
return MsQuicStatusCodes.InternalError;
msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.AuthenticationOptions.ClientCertificateRequired, state.AuthenticationOptions.CertificateRevocationCheckMode, state.AuthenticationOptions.RemoteCertificateValidationCallback);
msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);

if (state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
{
return MsQuicStatusCodes.Success;
}
}

return MsQuicStatusCodes.Success;
// If we fall-through here something wrong happened.
}
catch (Exception ex)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(state, $"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}");
}

Debug.Fail($"[Listener#{state.GetHashCode()}] Exception occurred during handling {(QUIC_LISTENER_EVENT)evt.Type} connection callback: {ex}");

// This handle will be cleaned up by MsQuic by returning InternalError.
connectionHandle?.SetHandleAsInvalid();
state.AcceptConnectionQueue.Writer.TryComplete(ex);
return MsQuicStatusCodes.InternalError;
}

// This handle will be cleaned up by MsQuic by returning InternalError.
connectionHandle?.SetHandleAsInvalid();
msQuicConnection?.Dispose();
return MsQuicStatusCodes.InternalError;
}

private void ThrowIfDisposed()
Expand Down
Loading