diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index cffb127ec7..bd8a798f88 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -505,6 +505,9 @@ Resources\StringsHelper.cs + + Common\System\Diagnostics\CodeAnalysis.cs + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 4c3ad107b4..e69e8c3fab 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -2,18 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +#nullable enable + using System; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; namespace Microsoft.Data.SqlClient.SNI { - internal class TdsParserStateObjectManaged : TdsParserStateObject + internal sealed class TdsParserStateObjectManaged : TdsParserStateObject { - private SNIMarsConnection _marsConnection; - private SNIHandle _sessionHandle; - private SspiClientContextStatus _sspiClientContextStatus; + private SNIMarsConnection? _marsConnection; + private SNIHandle? _sessionHandle; + private SspiClientContextStatus? _sspiClientContextStatus; public TdsParserStateObjectManaged(TdsParser parser) : base(parser) { } @@ -21,8 +26,6 @@ internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject phys base(parser, physicalConnection, async) { } - internal SNIHandle Handle => _sessionHandle; - internal override uint Status => _sessionHandle != null ? _sessionHandle.Status : TdsEnums.SNI_UNINITIALIZED; internal override SessionHandle SessionHandle => SessionHandle.FromManagedSession(_sessionHandle); @@ -36,14 +39,24 @@ protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource _sessionHandle.Status != TdsEnums.SNI_SUCCESS; - - internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error) + internal override bool IsFailedHandle() { - SNIHandle handle = Handle; - if (handle == null) + SNIHandle? sessionHandle = _sessionHandle; + if (sessionHandle is not null) { - throw ADP.ClosedConnectionError(); + return sessionHandle.Status != TdsEnums.SNI_SUCCESS; } + return true; + } + - error = handle.Receive(out SNIPacket packet, timeoutRemaining); + internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error) + { + SNIHandle sessionHandle = GetSessionSNIHandleHandleOrThrow(); + + error = sessionHandle.Receive(out SNIPacket packet, timeoutRemaining); - SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | Info | State Object Id {0}, Session Id {1}", _objectID, _sessionHandle?.ConnectionId); + SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | Info | State Object Id {0}, Session Id {1}", _objectID, sessionHandle.ConnectionId); #if DEBUG - SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | TRC | State Object Id {0}, Session Id {1}, Packet {2} received, Packet owner Id {3}, Packet dataLeft {4}", _objectID, _sessionHandle?.ConnectionId, packet?._id, packet?._owner.ConnectionId, packet?.DataLeft); + SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObjectManaged.ReadSyncOverAsync | TRC | State Object Id {0}, Session Id {1}, Packet {2} received, Packet owner Id {3}, Packet dataLeft {4}", _objectID, sessionHandle.ConnectionId, packet?._id, packet?._owner.ConnectionId, packet?.DataLeft); #endif return PacketHandle.FromManagedPacket(packet); } @@ -195,22 +219,31 @@ internal override void ReleasePacket(PacketHandle syncReadPacket) #if DEBUG SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObjectManaged.ReleasePacket | TRC | State Object Id {0}, Session Id {1}, Packet {2} will be released, Packet Owner Id {3}, Packet dataLeft {4}", _objectID, _sessionHandle?.ConnectionId, packet?._id, packet?._owner.ConnectionId, packet?.DataLeft); #endif - if (packet != null) + if (packet is not null) { - SNIHandle handle = Handle; - handle.ReturnPacket(packet); + SNIHandle? sessionHandle = _sessionHandle; + if (sessionHandle is not null) + { + sessionHandle.ReturnPacket(packet); + } + else + { + // clear the packet and drop it to GC because we no longer know how to return it to the correct owner + // this can only happen if a packet is in-flight when the _sessionHandle is cleared + packet.Release(); + } } } internal override uint CheckConnection() { - SNIHandle handle = Handle; - return handle == null ? TdsEnums.SNI_SUCCESS : handle.CheckConnection(); + SNIHandle? handle = GetSessionSNIHandleHandleOrThrow(); + return handle is null ? TdsEnums.SNI_SUCCESS : handle.CheckConnection(); } internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) { - SNIPacket packet = null; + SNIPacket? packet = null; error = handle.ManagedHandle.ReceiveAsync(ref packet); SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.ReadAsync | Info | State Object Id {0}, Session Id {1}, Packet DataLeft {2}", _objectID, _sessionHandle?.ConnectionId, packet?.DataLeft); @@ -232,20 +265,21 @@ internal override PacketHandle CreateAndSetAttentionPacket() internal override uint WritePacket(PacketHandle packetHandle, bool sync) { - uint result; - SNIHandle handle = Handle; - SNIPacket packet = packetHandle.ManagedPacket; + uint result = TdsEnums.SNI_UNINITIALIZED; + SNIHandle sessionHandle = GetSessionSNIHandleHandleOrThrow(); + SNIPacket? packet = packetHandle.ManagedPacket; + if (sync) { - result = handle.Send(packet); - handle.ReturnPacket(packet); + result = sessionHandle.Send(packet); + sessionHandle.ReturnPacket(packet); } else { - result = handle.SendAsync(packet); + result = sessionHandle.SendAsync(packet); } - - SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.WritePacket | Info | Session Id {0}, SendAsync Result {1}", handle?.ConnectionId, result); + + SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.WritePacket | Info | Session Id {0}, SendAsync Result {1}", sessionHandle.ConnectionId, result); return result; } @@ -264,12 +298,12 @@ internal override bool IsValidPacket(PacketHandle packet) internal override PacketHandle GetResetWritePacket(int dataSize) { - SNIHandle handle = Handle; - SNIPacket packet = handle.RentPacket(headerSize: handle.ReserveHeaderSize, dataSize: dataSize); + SNIHandle sessionHandle = GetSessionSNIHandleHandleOrThrow(); + SNIPacket packet = sessionHandle.RentPacket(headerSize: sessionHandle.ReserveHeaderSize, dataSize: dataSize); #if DEBUG Debug.Assert(packet.IsActive, "packet is not active, a serious pooling error may have occurred"); #endif - Debug.Assert(packet.ReservedHeaderSize == handle.ReserveHeaderSize, "failed to reserve header"); + Debug.Assert(packet.ReservedHeaderSize == sessionHandle.ReserveHeaderSize, "failed to reserve header"); return PacketHandle.FromManagedPacket(packet); } @@ -285,23 +319,24 @@ internal override void SetPacketData(PacketHandle packet, byte[] buffer, int byt internal override uint SniGetConnectionId(ref Guid clientConnectionId) { - clientConnectionId = Handle.ConnectionId; + clientConnectionId = GetSessionSNIHandleHandleOrThrow().ConnectionId; SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GetConnectionId | Info | Session Id {0}", clientConnectionId); return TdsEnums.SNI_SUCCESS; } internal override uint DisableSsl() { - SNIHandle handle = Handle; - SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.DisableSsl | Info | Session Id {0}", handle?.ConnectionId); - handle.DisableSsl(); + SNIHandle sessionHandle = GetSessionSNIHandleHandleOrThrow(); + SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.DisableSsl | Info | Session Id {0}", sessionHandle.ConnectionId); + sessionHandle.DisableSsl(); return TdsEnums.SNI_SUCCESS; } internal override uint EnableMars(ref uint info) { - _marsConnection = new SNIMarsConnection(Handle); - SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableMars | Info | State Object Id {0}, Session Id {1}", _objectID, _sessionHandle?.ConnectionId); + SNIHandle sessionHandle = GetSessionSNIHandleHandleOrThrow(); + _marsConnection = new SNIMarsConnection(sessionHandle); + SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableMars | Info | State Object Id {0}, Session Id {1}", _objectID, sessionHandle.ConnectionId); if (_marsConnection.StartReceive() == TdsEnums.SNI_SUCCESS_IO_PENDING) { @@ -313,28 +348,28 @@ internal override uint EnableMars(ref uint info) internal override uint EnableSsl(ref uint info) { - SNIHandle handle = Handle; + SNIHandle sessionHandle = GetSessionSNIHandleHandleOrThrow(); try { - SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Info | Session Id {0}", handle?.ConnectionId); - return handle.EnableSsl(info); + SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Info | Session Id {0}", sessionHandle.ConnectionId); + return sessionHandle.EnableSsl(info); } catch (Exception e) { - SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", handle?.ConnectionId, e?.Message); + SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.EnableSsl | Err | Session Id {0}, SNI Handshake failed with exception: {1}", sessionHandle.ConnectionId, e.Message); return SNICommon.ReportSNIError(SNIProviders.SSL_PROV, SNICommon.HandshakeFailureError, e); } } internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) { - Handle.SetBufferSize((int)unsignedPacketSize); + GetSessionSNIHandleHandleOrThrow().SetBufferSize((int)unsignedPacketSize); return TdsEnums.SNI_SUCCESS; } internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) { - if (_sspiClientContextStatus == null) + if (_sspiClientContextStatus is null) { _sspiClientContextStatus = new SspiClientContextStatus(); } @@ -347,8 +382,22 @@ internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint recei internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion) { - protocolVersion = Handle.ProtocolVersion; + protocolVersion = GetSessionSNIHandleHandleOrThrow().ProtocolVersion; return 0; } + + private SNIHandle GetSessionSNIHandleHandleOrThrow() + { + SNIHandle? sessionHandle = _sessionHandle; + if (sessionHandle is null) + { + ThrowClosedConnection(); + } + return sessionHandle; + } + + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] // this forces the exception throwing code not to be inlined for performance + private void ThrowClosedConnection() => throw ADP.ClosedConnectionError(); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs index 9f3d2fec91..7a7ff4a367 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs @@ -1914,6 +1914,15 @@ internal static string SNI_ERROR_9 { } } + /// + /// Looks up a localized string similar to Incorrect physicalConnection type. + /// + internal static string SNI_IncorrectPhysicalConnectionType { + get { + return ResourceManager.GetString("SNI_IncorrectPhysicalConnectionType", resourceCulture); + } + } + /// /// Looks up a localized string similar to HTTP Provider. /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx index 433c6a7684..a77dd0aef1 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx @@ -1935,4 +1935,7 @@ Connection timed out while retrieving an access token using '{0}' authentication method. Last error: {1}: {2} + + Incorrect physicalConnection type. + diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs index 9f59d4a08a..e6b23a5c68 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs @@ -1565,6 +1565,9 @@ internal static ArgumentOutOfRangeException InvalidIsolationLevel(IsolationLevel return InvalidEnumerationValue(typeof(IsolationLevel), (int)value); } + // ConnectionUtil + internal static Exception IncorrectPhysicalConnectionType() => new ArgumentException(StringsHelper.GetString(StringsHelper.SNI_IncorrectPhysicalConnectionType)); + // IDataParameter.Direction internal static ArgumentOutOfRangeException InvalidParameterDirection(ParameterDirection value) { diff --git a/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs b/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs new file mode 100644 index 0000000000..256f7cd1e0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/System/Diagnostics/CodeAnalysis.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Diagnostics.CodeAnalysis +{ +#if NETSTANDARD2_0 + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] + internal sealed class AllowNullAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property)] + internal sealed class DisallowNullAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Method)] + internal sealed class DoesNotReturnAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Parameter)] + internal sealed class DoesNotReturnIfAttribute : Attribute + { + public DoesNotReturnIfAttribute(bool parameterValue) => ParameterValue = parameterValue; + public bool ParameterValue { get; } + } + + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue)] + internal sealed class MaybeNullAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Parameter)] + internal sealed class MaybeNullWhenAttribute : Attribute + { + public MaybeNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + public bool ReturnValue { get; } + } + + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue)] + internal sealed class NotNullAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, AllowMultiple = true)] + internal sealed class NotNullIfNotNullAttribute : Attribute + { + public NotNullIfNotNullAttribute(string parameterName) => ParameterName = parameterName; + public string ParameterName { get; } + } + + [AttributeUsage(AttributeTargets.Parameter)] + internal sealed class NotNullWhenAttribute : Attribute + { + public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + public bool ReturnValue { get; } + } +#endif + +#if !NET5_0_OR_GREATER + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = true, Inherited = false)] + internal sealed class MemberNotNullAttribute : Attribute + { + public MemberNotNullAttribute(string member) => Members = new string[] + { + member + }; + + public MemberNotNullAttribute(params string[] members) => Members = members; + + public string[] Members { get; } + } + + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = true, Inherited = false)] + internal sealed class MemberNotNullWhenAttribute : Attribute + { + public MemberNotNullWhenAttribute(bool returnValue, string member) + { + ReturnValue = returnValue; + Members = new string[1] { member }; + } + + public MemberNotNullWhenAttribute(bool returnValue, params string[] members) + { + ReturnValue = returnValue; + Members = members; + } + + public bool ReturnValue { get; } + + public string[] Members { get; } + } +#endif +}