diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIHandle.cs index 15466f6418..019ecf2b23 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIHandle.cs @@ -86,6 +86,7 @@ internal abstract class SNIHandle /// public abstract Guid ConnectionId { get; } + public virtual int ReserveHeaderSize => 0; #if DEBUG /// /// Test handle for killing underlying connection diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index f0629c8eac..481f3ce10b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -106,6 +106,12 @@ public uint SendAsync(SNIPacket packet, SNIAsyncCallback callback) /// SNI error code public uint ReceiveAsync(ref SNIPacket packet) { + if (packet != null) + { + packet.Release(); + packet = null; + } + lock (this) { return _lowerHandle.ReceiveAsync(ref packet); @@ -137,7 +143,7 @@ public void HandleReceiveError(SNIPacket packet) handle.HandleReceiveError(packet); } } - packet?.Dispose(); + packet?.Release(); } /// @@ -187,8 +193,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) if (bytesTaken == 0) { - packet.Dispose(); - packet = null; sniErrorCode = ReceiveAsync(ref packet); if (sniErrorCode == TdsEnums.SNI_SUCCESS_IO_PENDING) @@ -204,7 +208,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) _currentHeader.Read(_headerBytes); _dataBytesLeft = (int)_currentHeader.length; - _currentPacket = new SNIPacket((int)_currentHeader.length); + _currentPacket = new SNIPacket(headerSize: 0, dataSize: (int)_currentHeader.length); } currentHeader = _currentHeader; @@ -219,8 +223,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) if (_dataBytesLeft > 0) { - packet.Dispose(); - packet = null; sniErrorCode = ReceiveAsync(ref packet); if (sniErrorCode == TdsEnums.SNI_SUCCESS_IO_PENDING) @@ -276,8 +278,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { if (packet.DataLeft == 0) { - packet.Dispose(); - packet = null; sniErrorCode = ReceiveAsync(ref packet); if (sniErrorCode == TdsEnums.SNI_SUCCESS_IO_PENDING) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs index b9cfce6250..00af84d656 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsHandle.cs @@ -12,7 +12,7 @@ namespace Microsoft.Data.SqlClient.SNI /// /// MARS handle /// - internal class SNIMarsHandle : SNIHandle + internal sealed class SNIMarsHandle : SNIHandle { private const uint ACK_THRESHOLD = 2; @@ -37,24 +37,14 @@ internal class SNIMarsHandle : SNIHandle /// /// Connection ID /// - public override Guid ConnectionId - { - get - { - return _connectionId; - } - } + public override Guid ConnectionId => _connectionId; /// /// Handle status /// - public override uint Status - { - get - { - return _status; - } - } + public override uint Status => _status; + + public override int ReserveHeaderSize => SNISMUXHeader.HEADER_LENGTH; /// /// Dispose object @@ -93,21 +83,22 @@ public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, object call /// SMUX header flags private void SendControlPacket(SNISMUXFlags flags) { - Span headerBytes = stackalloc byte[SNISMUXHeader.HEADER_LENGTH]; + SNIPacket packet = new SNIPacket(headerSize: SNISMUXHeader.HEADER_LENGTH, dataSize: 0); lock (this) { - GetSMUXHeaderBytes(0, flags, headerBytes); + SetupSMUXHeader(0, flags); + _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); + packet.SetHeaderActive(); } - SNIPacket packet = new SNIPacket(SNISMUXHeader.HEADER_LENGTH); - packet.AppendData(headerBytes); - _connection.Send(packet); } - private void GetSMUXHeaderBytes(int length, SNISMUXFlags flags, Span bytes) + private void SetupSMUXHeader(int length, SNISMUXFlags flags) { + Debug.Assert(Monitor.IsEntered(this), "must take lock on self before updating mux header"); + _currentHeader.SMID = 83; _currentHeader.flags = (byte)flags; _currentHeader.sessionId = _sessionId; @@ -115,27 +106,22 @@ private void GetSMUXHeaderBytes(int length, SNISMUXFlags flags, Span bytes _currentHeader.sequenceNumber = ((flags == SNISMUXFlags.SMUX_FIN) || (flags == SNISMUXFlags.SMUX_ACK)) ? _sequenceNumber - 1 : _sequenceNumber++; _currentHeader.highwater = _receiveHighwater; _receiveHighwaterLastAck = _currentHeader.highwater; - - _currentHeader.Write(bytes); } /// /// Generate a packet with SMUX header /// /// SNI packet - /// Encapsulated SNI packet - private SNIPacket GetSMUXEncapsulatedPacket(SNIPacket packet) + /// The packet with the SMUx header set. + private SNIPacket SetPacketSMUXHeader(SNIPacket packet) { - uint xSequenceNumber = _sequenceNumber; - Span header = stackalloc byte[SNISMUXHeader.HEADER_LENGTH]; - GetSMUXHeaderBytes(packet.Length, SNISMUXFlags.SMUX_DATA, header); + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to mux packet without mux reservation"); - SNIPacket smuxPacket = new SNIPacket(SNISMUXHeader.HEADER_LENGTH + packet.Length); - smuxPacket.AppendData(header); - smuxPacket.AppendPacket(packet); - packet.Dispose(); + SetupSMUXHeader(packet.Length, SNISMUXFlags.SMUX_DATA); + _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); + packet.SetHeaderActive(); - return smuxPacket; + return packet; } /// @@ -145,6 +131,8 @@ private SNIPacket GetSMUXEncapsulatedPacket(SNIPacket packet) /// SNI error code public override uint Send(SNIPacket packet) { + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without mux reservation in Send"); + while (true) { lock (this) @@ -163,9 +151,12 @@ public override uint Send(SNIPacket packet) } } - SNIPacket encapsulatedPacket = GetSMUXEncapsulatedPacket(packet); - - return _connection.Send(encapsulatedPacket); + SNIPacket muxedPacket = null; + lock (this) + { + muxedPacket = SetPacketSMUXHeader(packet); + } + return _connection.Send(muxedPacket); } /// @@ -176,6 +167,8 @@ public override uint Send(SNIPacket packet) /// SNI error code private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback) { + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without mux reservation in InternalSendAsync"); + lock (this) { if (_sequenceNumber >= _sendHighwater) @@ -183,18 +176,9 @@ private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback) return TdsEnums.SNI_QUEUE_FULL; } - SNIPacket encapsulatedPacket = GetSMUXEncapsulatedPacket(packet); - - if (callback != null) - { - encapsulatedPacket.SetCompletionCallback(callback); - } - else - { - encapsulatedPacket.SetCompletionCallback(HandleSendComplete); - } - - return _connection.SendAsync(encapsulatedPacket, callback); + SNIPacket muxedPacket = SetPacketSMUXHeader(packet); + muxedPacket.SetCompletionCallback(callback ?? HandleSendComplete); + return _connection.SendAsync(muxedPacket, callback); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs index 2716bde270..2957881405 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs @@ -15,7 +15,7 @@ namespace Microsoft.Data.SqlClient.SNI /// /// Named Pipe connection handle /// - internal class SNINpHandle : SNIHandle + internal sealed class SNINpHandle : SNIHandle { internal const string DefaultPipePath = @"sql\query"; // e.g. \\HOSTNAME\pipe\sql\query private const int MAX_PIPE_INSTANCES = 255; @@ -26,6 +26,7 @@ internal class SNINpHandle : SNIHandle private Stream _stream; private NamedPipeClientStream _pipeStream; private SslOverTdsStream _sslOverTdsStream; + private SslStream _sslStream; private SNIAsyncCallback _receiveCallback; private SNIAsyncCallback _sendCallback; @@ -150,7 +151,7 @@ public override uint Receive(out SNIPacket packet, int timeout) packet = null; try { - packet = new SNIPacket(_bufferSize); + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); packet.ReadFromStream(_stream); if (packet.Length == 0) @@ -174,8 +175,8 @@ public override uint Receive(out SNIPacket packet, int timeout) public override uint ReceiveAsync(ref SNIPacket packet) { - packet = new SNIPacket(_bufferSize); - + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); + try { packet.ReadFromStreamAsync(_stream, _receiveCallback); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs index e6a35caeda..f6fa20907d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs @@ -24,8 +24,8 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< bool error = false; try { - packet._length = await valueTask.ConfigureAwait(false); - if (packet._length == 0) + packet._dataLength = await valueTask.ConfigureAwait(false); + if (packet._dataLength == 0) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); error = true; @@ -45,13 +45,13 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - ValueTask vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None); + ValueTask vt = stream.ReadAsync(new Memory(_data, _headerLength, _dataCapacity), CancellationToken.None); if (vt.IsCompletedSuccessfully) { - _length = vt.Result; + _dataLength = vt.Result; // Zero length to go via async local function as is error condition - if (_length > 0) + if (_dataLength > 0) { callback(this, TdsEnums.SNI_SUCCESS); @@ -61,7 +61,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< else { // Avoid consuming the same instance twice. - vt = new ValueTask(_length); + vt = new ValueTask(_dataLength); } } @@ -96,11 +96,11 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfter) { - packet.Dispose(); + packet.Release(); } } - ValueTask vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None); + ValueTask vt = stream.WriteAsync(new Memory(_data, _headerLength, _dataLength), CancellationToken.None); if (vt.IsCompletedSuccessfully) { @@ -111,7 +111,7 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfterWriteAsync) { - Dispose(); + Release(); } // Completed diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs index 2a3cf12670..68b1ba7313 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs @@ -24,8 +24,8 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task bool error = false; try { - packet._length = await task.ConfigureAwait(false); - if (packet._length == 0) + packet._dataLength = await task.ConfigureAwait(false); + if (packet._dataLength == 0) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); error = true; @@ -45,13 +45,13 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - Task t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None); + Task t = stream.ReadAsync(_data, _headerLength, _dataCapacity, CancellationToken.None); if ((t.Status & TaskStatus.RanToCompletion) != 0) { - _length = t.Result; + _dataLength = t.Result; // Zero length to go via async local function as is error condition - if (_length > 0) + if (_dataLength > 0) { callback(this, TdsEnums.SNI_SUCCESS); @@ -91,11 +91,11 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfter) { - packet.Dispose(); + packet.Release(); } } - Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None); + Task t = stream.WriteAsync(_data, _headerLength, _dataLength, CancellationToken.None); if ((t.Status & TaskStatus.RanToCompletion) != 0) { @@ -106,7 +106,7 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfterWriteAsync) { - Dispose(); + Release(); } // Completed diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs index 17437880a2..0ff733d26a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; +using System.Diagnostics; using System.IO; namespace Microsoft.Data.SqlClient.SNI @@ -11,38 +12,19 @@ namespace Microsoft.Data.SqlClient.SNI /// /// SNI Packet /// - internal partial class SNIPacket : IDisposable, IEquatable + internal sealed partial class SNIPacket { + private int _dataLength; // the length of the data in the data segment, advanced by Append-ing data, does not include smux header length + private int _dataCapacity; // the total capacity requested, if the array is rented this may be less than the _data.Length, does not include smux header length + private int _dataOffset; // the start point of the data in the data segment, advanced by Take-ing data + private int _headerLength; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader + // _headerOffset is not needed because it is always 0 private byte[] _data; - private int _length; - private int _capacity; - private int _offset; - private string _description; private SNIAsyncCallback _completionCallback; - private bool _isBufferFromArrayPool; - - public SNIPacket() { } - - public SNIPacket(int capacity) - { - Allocate(capacity); - } - - /// - /// Packet description (used for debugging) - /// - public string Description + public SNIPacket(int headerSize, int dataSize) { - get - { - return _description; - } - - set - { - _description = value; - } + Allocate(headerSize, dataSize); } public bool HasCompletionCallback => !(_completionCallback is null); @@ -54,17 +36,19 @@ public string Description /// /// Length of data left to process /// - public int DataLeft => (_length - _offset); + public int DataLeft => (_dataLength - _dataOffset); /// /// Length of data /// - public int Length => _length; + public int Length => _dataLength; /// /// Packet validity /// - public bool IsInvalid => (_data == null); + public bool IsInvalid => _data is null; + + public int ReservedHeaderSize => _headerLength; /// /// Set async completion callback @@ -87,67 +71,26 @@ public void InvokeCompletionCallback(uint sniErrorCode) /// /// Allocate space for data /// - /// Length of byte array to be allocated - public void Allocate(int capacity) + /// Length of packet header + /// Length of byte array to be allocated + private void Allocate(int headerLength, int dataLength) { - if (_data != null && _data.Length < capacity) - { - if (_isBufferFromArrayPool) - { - ArrayPool.Shared.Return(_data); - } - _data = null; - } - - if (_data == null) - { - _data = ArrayPool.Shared.Rent(capacity); - _isBufferFromArrayPool = true; - } - - _capacity = capacity; - _length = 0; - _offset = 0; + _data = ArrayPool.Shared.Rent(headerLength + dataLength); + _dataCapacity = dataLength; + _dataLength = 0; + _dataOffset = 0; + _headerLength = headerLength; } /// - /// Clone packet - /// - /// Cloned packet - public SNIPacket Clone() - { - SNIPacket packet = new SNIPacket(_capacity); - Buffer.BlockCopy(_data, 0, packet._data, 0, _capacity); - packet._length = _length; - packet._description = _description; - packet._completionCallback = _completionCallback; - - return packet; - } - - /// - /// Get packet data + /// Read packet data into a buffer without removing it from the packet /// /// Buffer - /// Data in packet + /// Number of bytes read from the packet into the buffer public void GetData(byte[] buffer, ref int dataSize) { - Buffer.BlockCopy(_data, 0, buffer, 0, _length); - dataSize = _length; - } - - /// - /// Set packet data - /// - /// Data - /// Length - public void SetData(byte[] data, int length) - { - _data = data; - _length = length; - _capacity = data.Length; - _offset = 0; - _isBufferFromArrayPool = false; + Buffer.BlockCopy(_data, _headerLength, buffer, 0, _dataLength); + dataSize = _dataLength; } /// @@ -158,8 +101,8 @@ public void SetData(byte[] data, int length) /// Amount of data taken public int TakeData(SNIPacket packet, int size) { - int dataSize = TakeData(packet._data, packet._length, size); - packet._length += dataSize; + int dataSize = TakeData(packet._data, packet._headerLength + packet._dataLength, size); + packet._dataLength += dataSize; return dataSize; } @@ -170,50 +113,50 @@ public int TakeData(SNIPacket packet, int size) /// Size public void AppendData(byte[] data, int size) { - Buffer.BlockCopy(data, 0, _data, _length, size); - _length += size; - } - - public void AppendData(ReadOnlySpan data) - { - data.CopyTo(_data.AsSpan(_length)); - _length += data.Length; - } - - /// - /// Append another packet - /// - /// Packet - public void AppendPacket(SNIPacket packet) - { - Buffer.BlockCopy(packet._data, 0, _data, _length, packet._length); - _length += packet._length; + Buffer.BlockCopy(data, 0, _data, _headerLength + _dataLength, size); + _dataLength += size; } /// - /// Take data from packet and advance offset + /// Read data from the packet into the buffer at dataOffset for zize and then remove that data from the packet /// /// Buffer - /// Data offset - /// Size + /// Data offset to write data at + /// Number of bytes to read from the packet into the buffer /// public int TakeData(byte[] buffer, int dataOffset, int size) { - if (_offset >= _length) + if (_dataOffset >= _dataLength) { return 0; } - if (_offset + size > _length) + if (_dataOffset + size > _dataLength) { - size = _length - _offset; + size = _dataLength - _dataOffset; } - Buffer.BlockCopy(_data, _offset, buffer, dataOffset, size); - _offset += size; + Buffer.BlockCopy(_data, _headerLength + _dataOffset, buffer, dataOffset, size); + _dataOffset += size; return size; } + public Span GetHeaderBuffer(int headerSize) + { + Debug.Assert(_dataOffset == 0, "requested packet header buffer from partially consumed packet"); + Debug.Assert(headerSize > 0, "requested packet header buffer of 0 length"); + Debug.Assert(_headerLength == headerSize, "requested packet header of headerSize which is not equal to the _headerSize reservation"); + return _data.AsSpan(0, headerSize); + } + + public void SetHeaderActive() + { + Debug.Assert(_headerLength > 0, "requested to set header active when it is not reserved or is already active"); + _dataCapacity += _headerLength; + _dataLength += _headerLength; + _headerLength = 0; + } + /// /// Release packet /// @@ -221,24 +164,15 @@ public void Release() { if (_data != null) { - if (_isBufferFromArrayPool) - { - ArrayPool.Shared.Return(_data); - } + Array.Clear(_data, 0, _headerLength + _dataLength); + ArrayPool.Shared.Return(_data, clearArray: false); + _data = null; - _capacity = 0; + _dataCapacity = 0; } - Reset(); - } - - /// - /// Reset packet - /// - public void Reset() - { - _length = 0; - _offset = 0; - _description = null; + _dataLength = 0; + _dataOffset = 0; + _headerLength = 0; _completionCallback = null; } @@ -248,7 +182,7 @@ public void Reset() /// Stream to read from public void ReadFromStream(Stream stream) { - _length = stream.Read(_data, 0, _capacity); + _dataLength = stream.Read(_data, _headerLength, _dataCapacity); } /// @@ -257,7 +191,7 @@ public void ReadFromStream(Stream stream) /// Stream to write to public void WriteToStream(Stream stream) { - stream.Write(_data, 0, _length); + stream.Write(_data, _headerLength, _dataLength); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 78762c195b..1d8b13a6b2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -236,16 +236,15 @@ public uint GetConnectionId(SNIHandle handle, ref Guid clientConnectionId) /// SNI error status public uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync) { - SNIPacket clonedPacket = packet.Clone(); uint result; if (sync) { - result = handle.Send(clonedPacket); - clonedPacket.Dispose(); + result = handle.Send(packet); + packet.Release(); } else { - result = handle.SendAsync(clonedPacket, true); + result = handle.SendAsync(packet, true); } return result; @@ -456,7 +455,7 @@ public uint ReadAsync(SNIHandle handle, out SNIPacket packet) /// Length public void PacketSetData(SNIPacket packet, byte[] data, int length) { - packet.SetData(data, length); + packet.AppendData(data, length); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index e877f27ca2..91ebec426a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -20,7 +20,7 @@ namespace Microsoft.Data.SqlClient.SNI /// /// TCP connection handle /// - internal class SNITCPHandle : SNIHandle + internal sealed class SNITCPHandle : SNIHandle { private readonly string _targetServer; private readonly object _callbackObject; @@ -482,7 +482,7 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds) return TdsEnums.SNI_WAIT_TIMEOUT; } - packet = new SNIPacket(_bufferSize); + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); packet.ReadFromStream(_stream); if (packet.Length == 0) @@ -553,7 +553,7 @@ public override uint SendAsync(SNIPacket packet, bool disposePacketAfterSendAsyn /// SNI error code public override uint ReceiveAsync(ref SNIPacket packet) { - packet = new SNIPacket(_bufferSize); + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); try { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs index 6fc4ec0268..e6f0a30c94 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.IO; using System.IO.Pipes; using System.Threading; @@ -90,10 +91,14 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) { int readBytes = 0; - byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count]; - + byte[] packetData = null; + byte[] readTarget = buffer; + int readOffset = offset; if (_encapsulate) { + packetData = ArrayPool.Shared.Rent(count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count); + readTarget = packetData; + readOffset = 0; if (_packetBytes == 0) { // Account for split packets @@ -115,15 +120,18 @@ await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, } readBytes = async ? - await _stream.ReadAsync(packetData, 0, count, token).ConfigureAwait(false) : - _stream.Read(packetData, 0, count); + await _stream.ReadAsync(readTarget, readOffset, count, token).ConfigureAwait(false) : + _stream.Read(readTarget, readOffset, count); if (_encapsulate) { _packetBytes -= readBytes; } - - Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes); + if (packetData != null) + { + Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes); + ArrayPool.Shared.Return(packetData, clearArray: true); + } return readBytes; } @@ -154,11 +162,12 @@ private async Task WriteInternal(byte[] buffer, int offset, int count, Cancellat count -= currentCount; // Prepend buffer data with TDS prelogin header - byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount]; + int combinedLength = TdsEnums.HEADER_LEN + currentCount; + byte[] combinedBuffer = ArrayPool.Shared.Rent(combinedLength); // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a // partial packet (whether or not count != 0). - // + combinedBuffer[7] = 0; // touch this first for the jit bounds check combinedBuffer[0] = PRELOGIN_PACKET_TYPE; combinedBuffer[1] = (byte)(count > 0 ? 0 : 1); combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100); @@ -166,21 +175,20 @@ private async Task WriteInternal(byte[] buffer, int offset, int count, Cancellat combinedBuffer[4] = 0; combinedBuffer[5] = 0; combinedBuffer[6] = 0; - combinedBuffer[7] = 0; - for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++) - { - combinedBuffer[i] = buffer[currentOffset + (i - TdsEnums.HEADER_LEN)]; - } + Array.Copy(buffer, currentOffset, combinedBuffer, TdsEnums.HEADER_LEN, (combinedLength - TdsEnums.HEADER_LEN)); if (async) { - await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false); + await _stream.WriteAsync(combinedBuffer, 0, combinedLength, token).ConfigureAwait(false); } else { - _stream.Write(combinedBuffer, 0, combinedBuffer.Length); + _stream.Write(combinedBuffer, 0, combinedLength); } + + Array.Clear(combinedBuffer, 0, combinedLength); + ArrayPool.Shared.Return(combinedBuffer); } else { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 931bb53f7e..0a6d80e7c7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -788,7 +788,7 @@ private void ResetCancelAndProcessAttention() protected abstract uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize); - internal abstract PacketHandle GetResetWritePacket(); + internal abstract PacketHandle GetResetWritePacket(int dataSize); internal abstract void ClearAllWritePackets(); @@ -3515,7 +3515,7 @@ internal void SendAttention(bool mustTakeWriteLock = false) private Task WriteSni(bool canAccumulate) { // Prepare packet, and write to packet. - PacketHandle packet = GetResetWritePacket(); + PacketHandle packet = GetResetWritePacket(_outBytesUsed); SetBufferSecureStrings(); SetPacketData(packet, _outBuff, _outBytesUsed); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 51c81340a6..23fde757ca 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -12,18 +12,12 @@ namespace Microsoft.Data.SqlClient.SNI { internal class TdsParserStateObjectManaged : TdsParserStateObject { - private SNIMarsConnection _marsConnection = null; - private SNIHandle _sessionHandle = null; // the SNI handle we're to work on - private SNIPacket _sniPacket = null; // Will have to re-vamp this for MARS - internal SNIPacket _sniAsyncAttnPacket = null; // Packet to use to send Attn - private readonly Dictionary _pendingWritePackets = new Dictionary(); // Stores write packets that have been sent to SNI, but have not yet finished writing (i.e. we are waiting for SNI's callback) - - private readonly WritePacketCache _writePacketCache = new WritePacketCache(); // Store write packets that are ready to be re-used + private SNIMarsConnection _marsConnection; + private SNIHandle _sessionHandle; + private SspiClientContextStatus _sspiClientContextStatus; public TdsParserStateObjectManaged(TdsParser parser) : base(parser) { } - internal SspiClientContextStatus sspiClientContextStatus = new SspiClientContextStatus(); - internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject physicalConnection, bool async) : base(parser, physicalConnection, async) { } @@ -83,39 +77,23 @@ protected override void RemovePacketFromPendingList(PacketHandle packet) internal override void Dispose() { - SNIPacket packetHandle = _sniPacket; SNIHandle sessionHandle = _sessionHandle; - SNIPacket asyncAttnPacket = _sniAsyncAttnPacket; - _sniPacket = null; _sessionHandle = null; - _sniAsyncAttnPacket = null; _marsConnection = null; DisposeCounters(); - if (null != sessionHandle || null != packetHandle) + if (null != sessionHandle) { - packetHandle?.Dispose(); - asyncAttnPacket?.Dispose(); - - if (sessionHandle != null) - { - sessionHandle.Dispose(); - DecrementPendingCallbacks(true); // Will dispose of GC handle. - } + sessionHandle.Dispose(); + DecrementPendingCallbacks(true); // Will dispose of GC handle. } - - DisposePacketCache(); } internal override void DisposePacketCache() { - lock (_writePacketLockObject) - { - _writePacketCache.Dispose(); - // Do not set _writePacketCache to null, just in case a WriteAsyncCallback completes after this point - } + // No - op } protected override void FreeGcHandle(int remaining, bool release) @@ -141,7 +119,7 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint internal override bool IsPacketEmpty(PacketHandle packet) => packet.ManagedPacket == null; - internal override void ReleasePacket(PacketHandle syncReadPacket) => syncReadPacket.ManagedPacket?.Dispose(); + internal override void ReleasePacket(PacketHandle syncReadPacket) => syncReadPacket.ManagedPacket?.Release(); internal override uint CheckConnection() { @@ -157,19 +135,14 @@ internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) internal override PacketHandle CreateAndSetAttentionPacket() { - if (_sniAsyncAttnPacket == null) - { - SNIPacket attnPacket = new SNIPacket(); - SetPacketData(PacketHandle.FromManagedPacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN); - _sniAsyncAttnPacket = attnPacket; - } - return PacketHandle.FromManagedPacket(_sniAsyncAttnPacket); + PacketHandle packetHandle = GetResetWritePacket(TdsEnums.HEADER_LEN); + SetPacketData(packetHandle, SQL.AttentionHeader, TdsEnums.HEADER_LEN); + return packetHandle; } internal override uint WritePacket(PacketHandle packet, bool sync) => SNIProxy.Singleton.WritePacket(Handle, packet.ManagedPacket, sync); - // No- Op in managed SNI internal override PacketHandle AddPacketToPendingList(PacketHandle packet) => packet; @@ -183,34 +156,16 @@ internal override bool IsValidPacket(PacketHandle packet) ); } - internal override PacketHandle GetResetWritePacket() + internal override PacketHandle GetResetWritePacket(int dataSize) { - if (_sniPacket != null) - { - _sniPacket.Reset(); - } - else - { - lock (_writePacketLockObject) - { - _sniPacket = _writePacketCache.Take(Handle); - } - } - return PacketHandle.FromManagedPacket(_sniPacket); + var packet = new SNIPacket(headerSize: _sessionHandle.ReserveHeaderSize, dataSize: dataSize); + Debug.Assert(packet.ReservedHeaderSize == _sessionHandle.ReserveHeaderSize, "failed to reserve header"); + return PacketHandle.FromManagedPacket(packet); } internal override void ClearAllWritePackets() { - if (_sniPacket != null) - { - _sniPacket.Dispose(); - _sniPacket = null; - } - lock (_writePacketLockObject) - { - Debug.Assert(_pendingWritePackets.Count == 0 && _asyncWriteCount == 0, "Should not clear all write packets if there are packets pending"); - _writePacketCache.Clear(); - } + Debug.Assert(_asyncWriteCount == 0, "Should not clear all write packets if there are packets pending"); } internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) => SNIProxy.Singleton.PacketSetData(packet.ManagedPacket, buffer, bytesUsed); @@ -236,70 +191,15 @@ internal override uint EnableMars(ref uint info) internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer) { - SNIProxy.Singleton.GenSspiClientContext(sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer); + if (_sspiClientContextStatus == null) + { + _sspiClientContextStatus = new SspiClientContextStatus(); + } + SNIProxy.Singleton.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer); sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0); return 0; } internal override uint WaitForSSLHandShakeToComplete() => 0; - - internal sealed class WritePacketCache : IDisposable - { - private bool _disposed; - private Stack _packets; - - public WritePacketCache() - { - _disposed = false; - _packets = new Stack(); - } - - public SNIPacket Take(SNIHandle sniHandle) - { - SNIPacket packet; - if (_packets.Count > 0) - { - // Success - reset the packet - packet = _packets.Pop(); - packet.Reset(); - } - else - { - // Failed to take a packet - create a new one - packet = new SNIPacket(); - } - return packet; - } - - public void Add(SNIPacket packet) - { - if (!_disposed) - { - _packets.Push(packet); - } - else - { - // If we're disposed, then get rid of any packets added to us - packet.Dispose(); - } - } - - public void Clear() - { - while (_packets.Count > 0) - { - _packets.Pop().Dispose(); - } - } - - public void Dispose() - { - if (!_disposed) - { - _disposed = true; - Clear(); - } - } - } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index 1e9740cf75..cbf2a78d83 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -252,7 +252,7 @@ internal override bool IsValidPacket(PacketHandle packetPointer) ); } - internal override PacketHandle GetResetWritePacket() + internal override PacketHandle GetResetWritePacket(int dataSize) { if (_sniPacket != null) {