diff --git a/src/libraries/Common/tests/System/Net/Sockets/SocketTestServer.cs b/src/libraries/Common/tests/System/Net/Sockets/SocketTestServer.cs index 9950c7de9643a..fc92fbd8ac272 100644 --- a/src/libraries/Common/tests/System/Net/Sockets/SocketTestServer.cs +++ b/src/libraries/Common/tests/System/Net/Sockets/SocketTestServer.cs @@ -12,6 +12,7 @@ public abstract partial class SocketTestServer : IDisposable protected abstract int Port { get; } public abstract EndPoint EndPoint { get; } + public event Action Accepted; public static SocketTestServer SocketTestServerFactory(SocketImplementationType type, EndPoint endpoint, ProtocolType protocolType = ProtocolType.Tcp) { @@ -23,6 +24,9 @@ public static SocketTestServer SocketTestServerFactory(SocketImplementationType return SocketTestServerFactory(type, DefaultNumConnections, DefaultReceiveBufferSize, address, out port); } + public static SocketTestServer SocketTestServerFactory(SocketImplementationType type, IPAddress address) + => SocketTestServerFactory(type, address, out _); + public static SocketTestServer SocketTestServerFactory( SocketImplementationType type, int numConnections, @@ -60,5 +64,7 @@ public void Dispose() } protected abstract void Dispose(bool disposing); + + protected void NotifyAccepted(Socket socket) => Accepted?.Invoke(socket); } } diff --git a/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAPM.cs b/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAPM.cs index d8d484f47a7ac..50b747ecf1432 100644 --- a/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAPM.cs +++ b/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAPM.cs @@ -67,6 +67,7 @@ private void OnAccept(IAsyncResult result) return; } + NotifyAccepted(client); ServerSocketState state = new ServerSocketState(client, _receiveBufferSize); try { diff --git a/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs b/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs index 5db3fa0edd29f..e42d4b55cd810 100644 --- a/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs +++ b/src/libraries/Common/tests/System/Net/Sockets/SocketTestServerAsync.cs @@ -200,6 +200,7 @@ private void ProcessAccept(SocketAsyncEventArgs e) // Get the socket for the accepted client connection and put it into the ReadEventArg object user token. SocketAsyncEventArgs readEventArgs = _readWritePool.Pop(); + NotifyAccepted(e.AcceptSocket); ((AsyncUserToken)readEventArgs.UserToken).Socket = e.AcceptSocket; // As soon as the client is connected, post a receive to the connection. diff --git a/src/libraries/Common/tests/TestUtilities/System/PlatformDetection.cs b/src/libraries/Common/tests/TestUtilities/System/PlatformDetection.cs index 3db016a56c9cf..59d21ac14c5f4 100644 --- a/src/libraries/Common/tests/TestUtilities/System/PlatformDetection.cs +++ b/src/libraries/Common/tests/TestUtilities/System/PlatformDetection.cs @@ -20,6 +20,7 @@ public static partial class PlatformDetection public static bool IsNetCore => Environment.Version.Major >= 5 || RuntimeInformation.FrameworkDescription.StartsWith(".NET Core", StringComparison.OrdinalIgnoreCase); public static bool IsMonoRuntime => Type.GetType("Mono.RuntimeStructs") != null; + public static bool IsNotMonoRuntime => !IsMonoRuntime; public static bool IsMonoInterpreter => GetIsRunningOnMonoInterpreter(); public static bool IsFreeBSD => RuntimeInformation.IsOSPlatform(OSPlatform.Create("FREEBSD")); public static bool IsNetBSD => RuntimeInformation.IsOSPlatform(OSPlatform.Create("NETBSD")); diff --git a/src/libraries/System.Net.Connections/ref/System.Net.Connections.cs b/src/libraries/System.Net.Connections/ref/System.Net.Connections.cs index 6eaa384fdc7ee..ae1b0027967c8 100644 --- a/src/libraries/System.Net.Connections/ref/System.Net.Connections.cs +++ b/src/libraries/System.Net.Connections/ref/System.Net.Connections.cs @@ -71,4 +71,11 @@ public partial interface IConnectionProperties { bool TryGet(System.Type propertyKey, [System.Diagnostics.CodeAnalysis.NotNullWhenAttribute(true)] out object? property); } + public partial class SocketsConnectionFactory : System.Net.Connections.ConnectionFactory + { + public SocketsConnectionFactory(System.Net.Sockets.AddressFamily addressFamily, System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { } + public SocketsConnectionFactory(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType) { } + public override System.Threading.Tasks.ValueTask ConnectAsync(System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties? options = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + protected virtual System.Net.Sockets.Socket CreateSocket(System.Net.Sockets.AddressFamily addressFamily, System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties? options) { throw null; } + } } diff --git a/src/libraries/System.Net.Connections/ref/System.Net.Connections.csproj b/src/libraries/System.Net.Connections/ref/System.Net.Connections.csproj index b239637616b4f..fa3a6d8bd7437 100644 --- a/src/libraries/System.Net.Connections/ref/System.Net.Connections.csproj +++ b/src/libraries/System.Net.Connections/ref/System.Net.Connections.csproj @@ -9,6 +9,7 @@ + diff --git a/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj b/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj index c53e55947afb4..7ccd9e6f600ab 100644 --- a/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj +++ b/src/libraries/System.Net.Connections/src/System.Net.Connections.csproj @@ -4,6 +4,7 @@ enable + @@ -13,14 +14,20 @@ + + + + + + diff --git a/src/libraries/System.Net.Connections/src/System/Net/Connections/Connection.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/Connection.cs index 8f7d479ce2a1a..f68bac27de6c5 100644 --- a/src/libraries/System.Net.Connections/src/System/Net/Connections/Connection.cs +++ b/src/libraries/System.Net.Connections/src/System/Net/Connections/Connection.cs @@ -90,22 +90,6 @@ protected virtual IDuplexPipe CreatePipe() } } - private sealed class DuplexStreamPipe : IDuplexPipe - { - private static readonly StreamPipeReaderOptions s_readerOpts = new StreamPipeReaderOptions(leaveOpen: true); - private static readonly StreamPipeWriterOptions s_writerOpts = new StreamPipeWriterOptions(leaveOpen: true); - - public DuplexStreamPipe(Stream stream) - { - Input = PipeReader.Create(stream, s_readerOpts); - Output = PipeWriter.Create(stream, s_writerOpts); - } - - public PipeReader Input { get; } - - public PipeWriter Output { get; } - } - /// /// Creates a connection for a . /// diff --git a/src/libraries/System.Net.Connections/src/System/Net/Connections/DuplexStreamPipe.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/DuplexStreamPipe.cs new file mode 100644 index 0000000000000..7b620c2b8bf15 --- /dev/null +++ b/src/libraries/System.Net.Connections/src/System/Net/Connections/DuplexStreamPipe.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.IO.Pipelines; + +namespace System.Net.Connections +{ + internal sealed class DuplexStreamPipe : IDuplexPipe + { + private static readonly StreamPipeReaderOptions s_readerOpts = new StreamPipeReaderOptions(leaveOpen: true); + private static readonly StreamPipeWriterOptions s_writerOpts = new StreamPipeWriterOptions(leaveOpen: true); + + public DuplexStreamPipe(Stream stream) + { + Input = PipeReader.Create(stream, s_readerOpts); + Output = PipeWriter.Create(stream, s_writerOpts); + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + } +} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Connections/SocketConnection.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketConnection.cs similarity index 75% rename from src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Connections/SocketConnection.cs rename to src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketConnection.cs index 130cf634ec299..2eb3e6bca9063 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Connections/SocketConnection.cs +++ b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketConnection.cs @@ -3,7 +3,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO; -using System.Net.Http; +using System.IO.Pipelines; using System.Net.Sockets; using System.Runtime.ExceptionServices; using System.Threading; @@ -13,15 +13,16 @@ namespace System.Net.Connections { internal sealed class SocketConnection : Connection, IConnectionProperties { - private readonly NetworkStream _stream; + private readonly Socket _socket; + private Stream? _stream; - public override EndPoint? RemoteEndPoint => _stream.Socket.RemoteEndPoint; - public override EndPoint? LocalEndPoint => _stream.Socket.LocalEndPoint; + public override EndPoint? RemoteEndPoint => _socket.RemoteEndPoint; + public override EndPoint? LocalEndPoint => _socket.LocalEndPoint; public override IConnectionProperties ConnectionProperties => this; public SocketConnection(Socket socket) { - _stream = new NetworkStream(socket, ownsSocket: true); + _socket = socket; } protected override ValueTask CloseAsyncCore(ConnectionCloseMethod method, CancellationToken cancellationToken) @@ -37,10 +38,11 @@ protected override ValueTask CloseAsyncCore(ConnectionCloseMethod method, Cancel { // Dispose must be called first in order to cause a connection reset, // as NetworkStream.Dispose() will call Shutdown(Both). - _stream.Socket.Dispose(); + _socket.Dispose(); } - _stream.Dispose(); + // Since CreatePipe() calls CreateStream(), so _stream should be present even in the pipe case: + _stream?.Dispose(); } catch (SocketException socketException) { @@ -54,18 +56,18 @@ protected override ValueTask CloseAsyncCore(ConnectionCloseMethod method, Cancel return default; } - protected override Stream CreateStream() => _stream; - bool IConnectionProperties.TryGet(Type propertyKey, [NotNullWhen(true)] out object? property) { if (propertyKey == typeof(Socket)) { - property = _stream.Socket; + property = _socket; return true; } property = null; return false; } + + protected override Stream CreateStream() => _stream ??= new NetworkStream(_socket, true); } } diff --git a/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs new file mode 100644 index 0000000000000..2936a04e0484c --- /dev/null +++ b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/SocketsConnectionFactory.cs @@ -0,0 +1,137 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.IO.Pipelines; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Connections +{ + /// + /// A to establish socket-based connections. + /// + /// + /// When constructed with , this factory will create connections with enabled. + /// In case of IPv6 sockets is also enabled. + /// + public class SocketsConnectionFactory : ConnectionFactory + { + private readonly AddressFamily _addressFamily; + private readonly SocketType _socketType; + private readonly ProtocolType _protocolType; + + /// + /// Initializes a new instance of the class. + /// + /// The to forward to the socket. + /// The to forward to the socket. + /// The to forward to the socket. + public SocketsConnectionFactory( + AddressFamily addressFamily, + SocketType socketType, + ProtocolType protocolType) + { + _addressFamily = addressFamily; + _socketType = socketType; + _protocolType = protocolType; + } + + /// + /// Initializes a new instance of the class + /// that will forward to the Socket constructor. + /// + /// The to forward to the socket. + /// The to forward to the socket. + /// The created socket will be an IPv6 socket with enabled. + public SocketsConnectionFactory(SocketType socketType, ProtocolType protocolType) + : this(AddressFamily.InterNetworkV6, socketType, protocolType) + { + } + + /// + /// When is . + public override async ValueTask ConnectAsync( + EndPoint? endPoint, + IConnectionProperties? options = null, + CancellationToken cancellationToken = default) + { + if (endPoint == null) throw new ArgumentNullException(nameof(endPoint)); + cancellationToken.ThrowIfCancellationRequested(); + + Socket socket = CreateSocket(_addressFamily, _socketType, _protocolType, endPoint, options); + + try + { + using var args = new TaskSocketAsyncEventArgs(); + args.RemoteEndPoint = endPoint; + + if (socket.ConnectAsync(args)) + { + using (cancellationToken.UnsafeRegister(static o => Socket.CancelConnectAsync((SocketAsyncEventArgs)o!), args)) + { + await args.Task.ConfigureAwait(false); + } + } + + if (args.SocketError != SocketError.Success) + { + if (args.SocketError == SocketError.OperationAborted) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw NetworkErrorHelper.MapSocketException(new SocketException((int)args.SocketError)); + } + + return new SocketConnection(socket); + } + catch (SocketException socketException) + { + socket.Dispose(); + throw NetworkErrorHelper.MapSocketException(socketException); + } + catch + { + socket.Dispose(); + throw; + } + } + + /// + /// Creates the socket that shall be used with the connection. + /// + /// The to forward to the socket. + /// The to forward to the socket. + /// The to forward to the socket. + /// The this socket will be connected to. + /// Properties, if any, that might change how the socket is initialized. + /// A new unconnected . + /// + /// In case of TCP sockets, the default implementation of this method will create a socket with enabled. + /// In case of IPv6 sockets is also be enabled. + /// + protected virtual Socket CreateSocket( + AddressFamily addressFamily, + SocketType socketType, + ProtocolType protocolType, + EndPoint? endPoint, + IConnectionProperties? options) + { + Socket socket = new Socket(addressFamily, socketType, protocolType); + + if (protocolType == ProtocolType.Tcp) + { + socket.NoDelay = true; + } + + if (addressFamily == AddressFamily.InterNetworkV6) + { + socket.DualMode = true; + } + + return socket; + } + } +} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Connections/TaskSocketAsyncEventArgs.cs b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs similarity index 94% rename from src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Connections/TaskSocketAsyncEventArgs.cs rename to src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs index 7bf8947215a40..a1bb69d3f0501 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Connections/TaskSocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Connections/src/System/Net/Connections/Sockets/TaskSocketAsyncEventArgs.cs @@ -17,7 +17,6 @@ internal sealed class TaskSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTas public void GetResult(short token) => _valueTaskSource.GetResult(token); public ValueTaskSourceStatus GetStatus(short token) => _valueTaskSource.GetStatus(token); public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _valueTaskSource.OnCompleted(continuation, state, token, flags); - public void Complete() => _valueTaskSource.SetResult(0); public TaskSocketAsyncEventArgs() : base(unsafeSuppressExecutionContextFlow: true) diff --git a/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/Sockets/SocketsConnectionFactoryTests.cs b/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/Sockets/SocketsConnectionFactoryTests.cs new file mode 100644 index 0000000000000..ce45bfb3e9ed5 --- /dev/null +++ b/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/Sockets/SocketsConnectionFactoryTests.cs @@ -0,0 +1,430 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Concurrent; +using System.IO; +using System.IO.Pipelines; +using System.Net.Connections; +using System.Net.Sockets; +using System.Net.Sockets.Tests; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Connections.Tests +{ + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotMonoRuntime))] + public class SocketsConnectionFactoryTests + { + public static TheoryData GetConnectData() + { + var result = new TheoryData() + { + { new IPEndPoint(IPAddress.Loopback, 0), SocketType.Stream, ProtocolType.Tcp }, + { new IPEndPoint(IPAddress.IPv6Loopback, 0), SocketType.Stream, ProtocolType.Tcp }, + }; + + if (Socket.OSSupportsUnixDomainSockets) + { + result.Add(new UnixDomainSocketEndPoint("/replaced/in/test"), SocketType.Stream, ProtocolType.Unspecified); + } + + return result; + } + + // to avoid random names in TheoryData, we replace the path in test code: + private static EndPoint RecreateUdsEndpoint(EndPoint endPoint) + { + if (endPoint is UnixDomainSocketEndPoint) + { + endPoint = new UnixDomainSocketEndPoint($"{Path.GetTempPath()}/{Guid.NewGuid()}"); + } + return endPoint; + } + + private static Socket ValidateSocket(Connection connection, SocketType? socketType = null, ProtocolType? protocolType = null, AddressFamily? addressFamily = null) + { + Assert.True(connection.ConnectionProperties.TryGet(out Socket socket)); + Assert.True(socket.Connected); + if (addressFamily != null) Assert.Equal(addressFamily, socket.AddressFamily); + if (socketType != null) Assert.Equal(socketType, socket.SocketType); + if (protocolType != null) Assert.Equal(protocolType, socket.ProtocolType); + return socket; + } + + [Theory] + [MemberData(nameof(GetConnectData))] + public async Task Constructor3_ConnectAsync_Success_PropagatesConstructorArgumentsToSocket(EndPoint endPoint, SocketType socketType, ProtocolType protocolType) + { + endPoint = RecreateUdsEndpoint(endPoint); + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, endPoint, protocolType); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(endPoint.AddressFamily, socketType, protocolType); + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + ValidateSocket(connection, socketType, protocolType, endPoint.AddressFamily); + } + + [Fact] + public async Task Constructor2_ConnectAsync_Success_CreatesIPv6DualModeSocket() + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.IPv6Loopback); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + Socket socket = ValidateSocket(connection, SocketType.Stream, ProtocolType.Tcp, AddressFamily.InterNetworkV6); + Assert.True(socket.DualMode); + } + + [Fact] + public async Task ConnectAsync_Success_SocketNoDelayIsTrue() + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + connection.ConnectionProperties.TryGet(out Socket socket); + Assert.True(socket.NoDelay); + } + + [Fact] + public void ConnectAsync_NullEndpoint_ThrowsArgumentNullException() + { + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + Assert.ThrowsAsync(() => factory.ConnectAsync(null).AsTask()); + } + + // TODO: On OSX and Windows7 connection failures seem to fail with unexpected SocketErrors that are mapped to NetworkError.Unknown. This needs an investigation. + // Related: https://github.com/dotnet/runtime/pull/40565 + public static bool PlatformHasReliableConnectionFailures => !PlatformDetection.IsOSX && !PlatformDetection.IsWindows7 && !PlatformDetection.IsFreeBSD; + + [ConditionalFact(nameof(PlatformHasReliableConnectionFailures))] + public async Task ConnectAsync_WhenRefused_ThrowsNetworkException() + { + using Socket notListening = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + int port = notListening.BindToAnonymousPort(IPAddress.Loopback); + var endPoint = new IPEndPoint(IPAddress.Loopback, port); + + using SocketsConnectionFactory factory = new SocketsConnectionFactory(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + NetworkException ex = await Assert.ThrowsAsync(() => factory.ConnectAsync(endPoint).AsTask()); + Assert.Equal(NetworkError.ConnectionRefused, ex.NetworkError); + } + + [OuterLoop] // TimedOut and HostNotFound is slow on Windows + [ConditionalFact(nameof(PlatformHasReliableConnectionFailures))] + public async Task ConnectAsync_WhenHostNotFound_ThrowsNetworkException() + { + using SocketsConnectionFactory factory = new SocketsConnectionFactory(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + // Unassigned as per https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.txt + int unusedPort = 8; + DnsEndPoint endPoint = new DnsEndPoint(System.Net.Test.Common.Configuration.Sockets.InvalidHost, unusedPort); + + NetworkException ex = await Assert.ThrowsAsync(() => factory.ConnectAsync(endPoint).AsTask()); + Assert.Equal(NetworkError.HostNotFound, ex.NetworkError); + } + + [OuterLoop] // TimedOut and HostNotFound is slow on Windows + [ConditionalFact(nameof(PlatformHasReliableConnectionFailures))] + public async Task ConnectAsync_TimedOut_ThrowsNetworkException() + { + using SocketsConnectionFactory factory = new SocketsConnectionFactory(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + IPEndPoint doesNotExist = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 23); + + // SocketError.TimedOut currently maps to SocketError.Unknown, so no asserion + await Assert.ThrowsAsync(() => factory.ConnectAsync(doesNotExist).AsTask()); + } + + // On Windows, connection timeout takes 21 seconds. Abusing this behavior to test the cancellation logic + [Fact] + [PlatformSpecific(TestPlatforms.Windows)] + public async Task ConnectAsync_WhenCancelled_ThrowsTaskCancelledException() + { + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + IPEndPoint doesNotExist = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 23); + + CancellationTokenSource cts = new CancellationTokenSource(); + cts.CancelAfter(100); + + OperationCanceledException ex = await Assert.ThrowsAsync(() => factory.ConnectAsync(doesNotExist, cancellationToken: cts.Token).AsTask()); + Assert.Equal(cts.Token, ex.CancellationToken); + } + + [Fact] + public async Task ConnectAsync_WhenCancelledBeforeInvocation_ThrowsTaskCancelledException() + { + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + IPEndPoint doesNotExist = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 23); + + CancellationToken cancellationToken = new CancellationToken(true); + + OperationCanceledException ex = await Assert.ThrowsAsync(() => factory.ConnectAsync(doesNotExist, cancellationToken: cancellationToken).AsTask()); + Assert.Equal(cancellationToken, ex.CancellationToken); + } + + [Theory] + [MemberData(nameof(GetConnectData))] + public async Task Connection_Stream_ReadWrite_Success(EndPoint endPoint, SocketType socketType, ProtocolType protocolType) + { + endPoint = RecreateUdsEndpoint(endPoint); + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, endPoint, protocolType); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(endPoint.AddressFamily, socketType, protocolType); + + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + Stream stream = connection.Stream; + + byte[] sendData = { 1, 2, 3 }; + byte[] receiveData = new byte[sendData.Length]; + + await stream.WriteAsync(sendData); + await stream.FlushAsync(); + await stream.ReadAsync(receiveData); + + // The test server should echo the data: + Assert.Equal(sendData, receiveData); + } + + [Theory] + [MemberData(nameof(GetConnectData))] + public async Task Connection_EndpointsAreCorrect(EndPoint endPoint, SocketType socketType, ProtocolType protocolType) + { + endPoint = RecreateUdsEndpoint(endPoint); + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, endPoint, protocolType); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(endPoint.AddressFamily, socketType, protocolType); + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + // Checking for .ToString() result, because UnixDomainSocketEndPoint equality doesn't seem to be implemented + Assert.Equal(server.EndPoint.ToString(), connection.RemoteEndPoint.ToString()); + Assert.IsType(endPoint.GetType(), connection.LocalEndPoint); + } + + [Theory] + [MemberData(nameof(GetConnectData))] + public async Task Connection_Pipe_ReadWrite_Success(EndPoint endPoint, SocketType socketType, ProtocolType protocolType) + { + endPoint = RecreateUdsEndpoint(endPoint); + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, endPoint, protocolType); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(endPoint.AddressFamily, socketType, protocolType); + + + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + IDuplexPipe pipe = connection.Pipe; + + byte[] sendData = { 1, 2, 3 }; + using MemoryStream receiveTempStream = new MemoryStream(); + + await pipe.Output.WriteAsync(sendData); + ReadResult rr = await pipe.Input.ReadAsync(); + + // The test server should echo the data: + Assert.True(rr.Buffer.FirstSpan.SequenceEqual(sendData)); + } + + [Fact] + public async Task Connection_Stream_FailingOperation_ThowsNetworkException() + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + connection.ConnectionProperties.TryGet(out Socket socket); + Stream stream = connection.Stream; + socket.Dispose(); + + Assert.Throws(() => stream.Read(new byte[1], 0, 1)); + Assert.Throws(() => stream.Write(new byte[1], 0, 1)); + } + + [Fact] + public async Task Connection_Pipe_FailingOperation_ThowsNetworkException() + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + using Connection connection = await factory.ConnectAsync(server.EndPoint); + + connection.ConnectionProperties.TryGet(out Socket socket); + IDuplexPipe pipe = connection.Pipe; + socket.Dispose(); + + await Assert.ThrowsAsync(() => pipe.Input.ReadAsync().AsTask()); + await Assert.ThrowsAsync(() => pipe.Output.WriteAsync(new byte[1]).AsTask()); + } + + [Theory] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task Connection_Dispose_ClosesSocket(bool disposeAsync, bool usePipe) + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + Connection connection = await factory.ConnectAsync(server.EndPoint); + + Stream stream = usePipe ? null : connection.Stream; + if (usePipe) _ = connection.Pipe; + connection.ConnectionProperties.TryGet(out Socket socket); + + if (disposeAsync) await connection.DisposeAsync(); + else connection.Dispose(); + + Assert.False(socket.Connected); + + if (!usePipe) + { + // In this case we can also verify if the stream is disposed + Assert.Throws(() => stream.Write(new byte[1])); + } + } + + [Theory] + [InlineData(true, ConnectionCloseMethod.GracefulShutdown)] + [InlineData(true, ConnectionCloseMethod.Immediate)] + [InlineData(true, ConnectionCloseMethod.Abort)] + [InlineData(false, ConnectionCloseMethod.GracefulShutdown)] + [InlineData(false, ConnectionCloseMethod.Immediate)] + [InlineData(false, ConnectionCloseMethod.Abort)] + public async Task Connection_CloseAsync_ClosesSocket(bool usePipe, ConnectionCloseMethod method) + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback); + using SocketsConnectionFactory factory = new SocketsConnectionFactory(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + Connection connection = await factory.ConnectAsync(server.EndPoint); + + Stream stream = null; + if (usePipe) + { + _ = connection.Pipe; + } + else + { + stream = connection.Stream; + } + + connection.ConnectionProperties.TryGet(out Socket socket); + + await connection.CloseAsync(method); + + Assert.Throws(() => socket.Send(new byte[1])); + + if (!usePipe) // No way to observe the stream if we work with the pipe + { + Assert.Throws(() => stream.Write(new byte[1])); + } + } + + // Test scenario based on: + // https://devblogs.microsoft.com/dotnet/system-io-pipelines-high-performance-io-in-net/ + [Theory(Timeout = 60000)] // Give 1 minute to fail, in case of a hang + [InlineData(30)] + [InlineData(500)] + [OuterLoop("Might run long")] + public async Task Connection_Pipe_ReadWrite_Integration(int totalLines) + { + using SocketsConnectionFactory factory = new SocketsConnectionFactory(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + using SocketTestServer echoServer = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, IPAddress.Loopback); + + Socket serverSocket = null; + echoServer.Accepted += s => serverSocket = s; + + using Connection connection = await factory.ConnectAsync(echoServer.EndPoint); + + IDuplexPipe pipe = connection.Pipe; + + ConcurrentQueue linesSent = new ConcurrentQueue(); + Task writerTask = Task.Factory.StartNew(async () => + { + byte[] endl = Encoding.ASCII.GetBytes("\n"); + StringBuilder expectedLine = new StringBuilder(); + + for (int i = 0; i < totalLines; i++) + { + int words = i % 10 + 1; + for (int j = 0; j < words; j++) + { + string word = Guid.NewGuid() + " "; + Encoding.ASCII.GetBytes(word, pipe.Output); + expectedLine.Append(word); + } + linesSent.Enqueue(expectedLine.ToString()); + await pipe.Output.WriteAsync(endl); + expectedLine.Clear(); + } + + await pipe.Output.FlushAsync(); + + // This will also trigger completion on the reader. TODO: Fix + // await pipe.Output.CompleteAsync(); + }, TaskCreationOptions.LongRunning); + + // The test server should echo the data sent to it + + PipeReader reader = pipe.Input; + + int lineIndex = 0; + + void ProcessLine(ReadOnlySequence lineBuffer) + { + string line = Encoding.ASCII.GetString(lineBuffer); + Assert.True(linesSent.TryDequeue(out string expectedLine)); + Assert.Equal(expectedLine, line); + lineIndex++; + + // Received everything, shut down the server, so next read will complete: + if (lineIndex == totalLines) serverSocket.Shutdown(SocketShutdown.Both); + } + + while (true) + { + try + { + ReadResult result = await reader.ReadAsync(); + + ReadOnlySequence buffer = result.Buffer; + SequencePosition? position = null; + + // Stop reading if there's no more data coming + if (result.IsCompleted) + { + break; + } + + do + { + // Look for a EOL in the buffer + position = buffer.PositionOf((byte)'\n'); + + if (position != null) + { + // Process the line + ProcessLine(buffer.Slice(0, position.Value)); + + // Skip the line + the \n character (basically position) + buffer = buffer.Slice(buffer.GetPosition(1, position.Value)); + } + } + while (position != null); + + // Tell the PipeReader how much of the buffer we have consumed + reader.AdvanceTo(buffer.Start, buffer.End); + } + catch (SocketException) + { + // terminate + } + + } + + // Mark the PipeReader as complete + await reader.CompleteAsync(); + await writerTask; + + // TODO: If this is done in the end of writerTask the socket stops working + Assert.Equal(totalLines, lineIndex); + } + } +} diff --git a/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/Sockets/SocketsConnectionFactoryTests_DerivedFactory.cs b/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/Sockets/SocketsConnectionFactoryTests_DerivedFactory.cs new file mode 100644 index 0000000000000..766d21c988e4f --- /dev/null +++ b/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/Sockets/SocketsConnectionFactoryTests_DerivedFactory.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.IO.Pipelines; +using System.Net.Connections; +using System.Net.Sockets; +using System.Net.Sockets.Tests; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Connections.Tests +{ + [ConditionalClass(typeof(PlatformDetection), nameof(PlatformDetection.IsNotMonoRuntime))] + public class SocketsConnectionFactoryTests_DerivedFactory + { + private class CustomConnectionOptionsValues + { + public bool NoDelay { get; set; } + + public bool DualMode { get; set; } + } + + private class CustomConnectionOptions : IConnectionProperties + { + public CustomConnectionOptionsValues Values { get; } = new CustomConnectionOptionsValues(); + + public CustomConnectionOptions() + { + } + + public bool TryGet(Type propertyKey, [NotNullWhen(true)] out object property) + { + if (propertyKey == typeof(CustomConnectionOptionsValues)) + { + property = Values; + return true; + } + + property = null; + return false; + } + } + + private sealed class CustomFactory : SocketsConnectionFactory + { + public CustomFactory() : base(SocketType.Stream, ProtocolType.Tcp) + { + } + + protected override Socket CreateSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, EndPoint endPoint, IConnectionProperties options) + { + Socket socket = new Socket(addressFamily, socketType, protocolType); + + if (options.TryGet(out CustomConnectionOptionsValues vals)) + { + socket.NoDelay = vals.NoDelay; + socket.DualMode = vals.DualMode; + } + + return socket; + } + } + + private readonly CustomFactory _factory = new CustomFactory(); + private readonly CustomConnectionOptions _options = new CustomConnectionOptions(); + private readonly IPEndPoint _endPoint = new IPEndPoint(IPAddress.IPv6Loopback, 0); + + [Fact] + public async Task DerivedFactory_CanShimSocket() + { + using var server = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, _endPoint); + using Connection connection = await _factory.ConnectAsync(server.EndPoint, _options); + connection.ConnectionProperties.TryGet(out Socket socket); + + Assert.Equal(_options.Values.NoDelay, socket.NoDelay); + Assert.Equal(_options.Values.DualMode, socket.DualMode); + } + } +} diff --git a/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/System.Net.Connections.Tests.csproj b/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/System.Net.Connections.Tests.csproj index 9aa01126533b0..6426b8ea3149b 100644 --- a/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/System.Net.Connections.Tests.csproj +++ b/src/libraries/System.Net.Connections/tests/System.Net.Connections.Tests/System.Net.Connections.Tests.csproj @@ -1,9 +1,10 @@ + $(NetCoreAppCurrent) + true - @@ -14,6 +15,20 @@ + + + + + + + + + + + + + + diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index 6a7cf68cc12ab..85997fa96cc5c 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -325,15 +325,6 @@ protected override void SerializeToStream(System.IO.Stream stream, System.Net.Tr protected override System.Threading.Tasks.Task SerializeToStreamAsync(System.IO.Stream stream, System.Net.TransportContext? context, System.Threading.CancellationToken cancellationToken) { throw null; } protected internal override bool TryComputeLength(out long length) { throw null; } } - public partial class SocketsHttpConnectionFactory : System.Net.Connections.ConnectionFactory - { - public SocketsHttpConnectionFactory() { } - public sealed override System.Threading.Tasks.ValueTask ConnectAsync(System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties? options = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public virtual System.Net.Sockets.Socket CreateSocket(System.Net.Http.HttpRequestMessage message, System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties options) { throw null; } - protected override void Dispose(bool disposing) { } - protected override System.Threading.Tasks.ValueTask DisposeAsyncCore() { throw null; } - public virtual System.Threading.Tasks.ValueTask EstablishConnectionAsync(System.Net.Http.HttpRequestMessage message, System.Net.EndPoint? endPoint, System.Net.Connections.IConnectionProperties options, System.Threading.CancellationToken cancellationToken) { throw null; } - } public sealed partial class SocketsHttpHandler : System.Net.Http.HttpMessageHandler { public SocketsHttpHandler() { } diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index ec91aeb408c9e..7794c3985bcf3 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -1,4 +1,4 @@ - + win true @@ -177,10 +177,7 @@ - - - - diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpConnectionFactory.cs b/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpConnectionFactory.cs deleted file mode 100644 index 664de02694ba6..0000000000000 --- a/src/libraries/System.Net.Http/src/System/Net/Http/BrowserHttpHandler/SocketsHttpConnectionFactory.cs +++ /dev/null @@ -1,19 +0,0 @@ -using System.Threading; -using System.Threading.Tasks; -using System.Net.Connections; -using System.Net.Sockets; - -namespace System.Net.Http -{ - public class SocketsHttpConnectionFactory : ConnectionFactory - { - public sealed override ValueTask ConnectAsync(EndPoint? endPoint, IConnectionProperties? options = null, CancellationToken cancellationToken = default) - => throw new NotImplementedException(); - - public virtual Socket CreateSocket(HttpRequestMessage message, EndPoint? endPoint, IConnectionProperties options) - => throw new NotImplementedException(); - - public virtual ValueTask EstablishConnectionAsync(HttpRequestMessage message, EndPoint? endPoint, IConnectionProperties options, CancellationToken cancellationToken) - => throw new NotImplementedException(); - } -} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index 97ec60c5c7cd2..587913e228933 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -3,10 +3,12 @@ using System.Diagnostics; using System.IO; +using System.IO.Pipelines; using System.Net.Connections; using System.Net.Quic; using System.Net.Security; using System.Net.Sockets; +using System.Runtime.ExceptionServices; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -76,7 +78,8 @@ public static Connection Connect(string host, int port, CancellationToken cancel throw CreateWrappedException(e, host, port, cancellationToken); } - return new SocketConnection(socket); + // Since we only do GracefulShutdown in SocketsHttpHandler code, Connection.FromStream() should match SocketConnection's behavior: + return Connection.FromStream(new NetworkStream(socket, ownsSocket: true), localEndPoint: socket.LocalEndPoint, remoteEndPoint: socket.RemoteEndPoint); } public static ValueTask EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DnsEndPointWithProperties.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DnsEndPointWithProperties.cs index b99801c1fac15..ef4ee0b510297 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DnsEndPointWithProperties.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/DnsEndPointWithProperties.cs @@ -9,18 +9,18 @@ namespace System.Net.Http // Passed to a connection factory, merges allocations for the DnsEndPoint and connection properties. internal sealed class DnsEndPointWithProperties : DnsEndPoint, IConnectionProperties { - public HttpRequestMessage InitialRequest { get; } + private readonly HttpRequestMessage _initialRequest; public DnsEndPointWithProperties(string host, int port, HttpRequestMessage initialRequest) : base(host, port) { - InitialRequest = initialRequest; + _initialRequest = initialRequest; } bool IConnectionProperties.TryGet(Type propertyKey, [NotNullWhen(true)] out object? property) { - if (propertyKey == typeof(DnsEndPointWithProperties)) + if (propertyKey == typeof(HttpRequestMessage)) { - property = this; + property = _initialRequest; return true; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index 33a93b0755012..fedcbab4cba35 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -1265,11 +1265,13 @@ public ValueTask SendAsync(HttpRequestMessage request, bool } } + private static readonly SocketsConnectionFactory s_defaultConnectionFactory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); + private ValueTask ConnectToTcpHostAsync(string host, int port, HttpRequestMessage initialRequest, bool async, CancellationToken cancellationToken) { if (async) { - ConnectionFactory connectionFactory = Settings._connectionFactory ?? SocketsHttpConnectionFactory.Default; + ConnectionFactory connectionFactory = Settings._connectionFactory ?? s_defaultConnectionFactory; var endPoint = new DnsEndPointWithProperties(host, port, initialRequest); return ConnectHelper.ConnectAsync(connectionFactory, endPoint, endPoint, cancellationToken); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionFactory.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionFactory.cs deleted file mode 100644 index 34018a14170ac..0000000000000 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpConnectionFactory.cs +++ /dev/null @@ -1,94 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Net.Connections; -using System.Net.Sockets; -using System.Runtime.ExceptionServices; -using System.Threading; -using System.Threading.Tasks; - -namespace System.Net.Http -{ - /// - /// The default connection factory used by , opening TCP connections. - /// - public class SocketsHttpConnectionFactory : ConnectionFactory - { - internal static SocketsHttpConnectionFactory Default { get; } = new SocketsHttpConnectionFactory(); - - /// - public sealed override ValueTask ConnectAsync(EndPoint? endPoint, IConnectionProperties? options = null, CancellationToken cancellationToken = default) - { - if (options == null || !options.TryGet(out DnsEndPointWithProperties? httpOptions)) - { - return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new HttpRequestException($"{nameof(SocketsHttpConnectionFactory)} requires a {nameof(DnsEndPointWithProperties)} property."))); - } - - return EstablishConnectionAsync(httpOptions!.InitialRequest, endPoint, options, cancellationToken); - } - - /// - /// Creates the socket to be used for a request. - /// - /// The request causing this socket to be opened. Once opened, it may be reused for many subsequent requests. - /// The EndPoint this socket will be connected to. - /// Properties, if any, that might change how the socket is initialized. - /// A new unconnected socket. - public virtual Socket CreateSocket(HttpRequestMessage message, EndPoint? endPoint, IConnectionProperties options) - { - return new Socket(SocketType.Stream, ProtocolType.Tcp); - } - - /// - /// Establishes a new connection for a request. - /// - /// The request causing this connection to be established. Once connected, it may be reused for many subsequent requests. - /// The EndPoint to connect to. - /// Properties, if any, that might change how the connection is made. - /// A cancellation token for the asynchronous operation. - /// A new open connection. - public virtual async ValueTask EstablishConnectionAsync(HttpRequestMessage message, EndPoint? endPoint, IConnectionProperties options, CancellationToken cancellationToken) - { - if (message == null) throw new ArgumentNullException(nameof(message)); - if (endPoint == null) throw new ArgumentNullException(nameof(endPoint)); - - Socket socket = CreateSocket(message, endPoint, options); - - try - { - using var args = new TaskSocketAsyncEventArgs(); - args.RemoteEndPoint = endPoint; - - if (socket.ConnectAsync(args)) - { - using (cancellationToken.UnsafeRegister(o => Socket.CancelConnectAsync((SocketAsyncEventArgs)o!), args)) - { - await args.Task.ConfigureAwait(false); - } - } - - if (args.SocketError != SocketError.Success) - { - Exception ex = args.SocketError == SocketError.OperationAborted && cancellationToken.IsCancellationRequested - ? (Exception)new OperationCanceledException(cancellationToken) - : new SocketException((int)args.SocketError); - - throw ex; - } - - socket.NoDelay = true; - return new SocketConnection(socket); - } - catch (SocketException socketException) - { - socket.Dispose(); - throw NetworkErrorHelper.MapSocketException(socketException); - } - catch - { - socket.Dispose(); - throw; - } - } - } -} diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs index fbe84c6a4a667..dbe8388b05c98 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs @@ -365,7 +365,7 @@ public bool EnableMultipleHttp2Connections /// /// When non-null, a custom factory used to open new TCP connections. - /// When null, a will be used. + /// When null, a will be used. /// public ConnectionFactory? ConnectionFactory { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 49266bd125ff0..c2ab2dfb63720 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -152,7 +152,7 @@ public async Task CustomConnectionFactory_AsyncRequest_Success() [Fact] public async Task CustomConnectionFactory_SyncRequest_Fails() { - await using ConnectionFactory connectionFactory = new SocketsHttpConnectionFactory(); + await using ConnectionFactory connectionFactory = new SocketsConnectionFactory(SocketType.Stream, ProtocolType.Tcp); using SocketsHttpHandler handler = new SocketsHttpHandler { ConnectionFactory = connectionFactory @@ -161,7 +161,46 @@ public async Task CustomConnectionFactory_SyncRequest_Fails() using HttpClient client = CreateHttpClient(handler); HttpRequestException e = await Assert.ThrowsAnyAsync(() => client.GetStringAsync($"http://{Guid.NewGuid():N}.com/foo")); - NetworkException networkException = Assert.IsType(e.InnerException); + Assert.IsType(e.InnerException); + } + + class CustomConnectionFactory : SocketsConnectionFactory + { + public CustomConnectionFactory() : base(SocketType.Stream, ProtocolType.Tcp) { } + + public HttpRequestMessage LastHttpRequestMessage { get; private set; } + + public override ValueTask ConnectAsync(EndPoint endPoint, IConnectionProperties options = null, CancellationToken cancellationToken = default) + { + if (options.TryGet(out HttpRequestMessage message)) + { + LastHttpRequestMessage = message; + } + + return base.ConnectAsync(endPoint, options, cancellationToken); + } + } + + [Fact] + public Task CustomConnectionFactory_ConnectAsync_CanCaptureHttpRequestMessage() + { + return LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var connectionFactory = new CustomConnectionFactory(); + using var handler = new SocketsHttpHandler() + { + ConnectionFactory = connectionFactory + }; + using HttpClient client = CreateHttpClient(handler); + + using var request = new HttpRequestMessage(HttpMethod.Get, uri); + + using HttpResponseMessage response = await client.SendAsync(request); + string content = await response.Content.ReadAsStringAsync(); + + Assert.Equal("OK", content); + Assert.Same(request, connectionFactory.LastHttpRequestMessage); + }, server => server.HandleRequestAsync(content: "OK")); } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj index ef710218ea2ad..43662fa98f8df 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/System.Net.Sockets.Tests.csproj @@ -90,4 +90,4 @@ - + \ No newline at end of file