Skip to content

Commit

Permalink
net fx draft
Browse files Browse the repository at this point in the history
  • Loading branch information
DavoudEshtehari committed May 11, 2022
1 parent 865ac03 commit 1beb887
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,7 @@ private void LoginNoFailover(ServerInfo serverInfo, string newPassword, SecureSt
throw SQL.ROR_TimeoutAfterRoutingInfo(this);
}

serverInfo = new ServerInfo(ConnectionOptions, _routingInfo, serverInfo.ResolvedServerName);
serverInfo = new ServerInfo(ConnectionOptions, _routingInfo, serverInfo.ResolvedServerName, serverInfo.ServerSPN);
timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.RoutingDestination);
_originalClientConnectionId = _clientConnectionId;
_routingDestination = serverInfo.UserServerName;
Expand Down Expand Up @@ -2047,7 +2047,7 @@ TimeoutTimer timeout
long timeoutUnitInterval;

string protocol = ConnectionOptions.NetworkLibrary;
ServerInfo failoverServerInfo = new ServerInfo(connectionOptions, failoverHost);
ServerInfo failoverServerInfo = new ServerInfo(connectionOptions, failoverHost, connectionOptions.FailoverPartnerSPN);

ResolveExtendedServerName(primaryServerInfo, !redirectedUserInstance, connectionOptions);
if (null == ServerProvidedFailOverPartner)
Expand Down Expand Up @@ -2150,7 +2150,7 @@ TimeoutTimer timeout
_parser = new TdsParser(ConnectionOptions.MARS, ConnectionOptions.Asynchronous);
Debug.Assert(SniContext.Undefined == Parser._physicalStateObj.SniContext, $"SniContext should be Undefined; actual Value: {Parser._physicalStateObj.SniContext}");

currentServerInfo = new ServerInfo(ConnectionOptions, _routingInfo, currentServerInfo.ResolvedServerName);
currentServerInfo = new ServerInfo(ConnectionOptions, _routingInfo, currentServerInfo.ResolvedServerName, currentServerInfo.ServerSPN);
timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.RoutingDestination);
_originalClientConnectionId = _clientConnectionId;
_routingDestination = currentServerInfo.UserServerName;
Expand Down Expand Up @@ -2296,13 +2296,9 @@ private void AttemptOneLogin(ServerInfo serverInfo, string newPassword, SecureSt
this,
ignoreSniOpenTimeout,
timeout.LegacyTimerExpire,
ConnectionOptions.Encrypt,
ConnectionOptions.TrustServerCertificate,
ConnectionOptions.IntegratedSecurity,
ConnectionOptions,
withFailover,
isFirstTransparentAttempt,
ConnectionOptions.Authentication,
ConnectionOptions.Certificate,
_serverCallback,
_clientCallback,
_originalNetworkAddressInfo != null,
Expand Down Expand Up @@ -3244,6 +3240,7 @@ internal sealed class ServerInfo
internal string ResolvedServerName { get; private set; } // the resolved servername only
internal string ResolvedDatabaseName { get; private set; } // name of target database after resolution
internal string UserProtocol { get; private set; } // the user specified protocol
internal string ServerSPN { get; private set; } // the server SPN

// The original user-supplied server name from the connection string.
// If connection string has no Data Source, the value is set to string.Empty.
Expand All @@ -3264,10 +3261,16 @@ private set
internal readonly string PreRoutingServerName;

// Initialize server info from connection options,
internal ServerInfo(SqlConnectionString userOptions) : this(userOptions, userOptions.DataSource) { }
internal ServerInfo(SqlConnectionString userOptions) : this(userOptions, userOptions.DataSource, userOptions.ServerSPN) { }

// Initialize server info from connection options, but override DataSource and ServerSPN with given server name and server SPN
internal ServerInfo(SqlConnectionString userOptions, string serverName, string serverSPN) : this(userOptions, serverName)
{
ServerSPN = serverSPN;
}

// Initialize server info from connection options, but override DataSource with given server name
internal ServerInfo(SqlConnectionString userOptions, string serverName)
private ServerInfo(SqlConnectionString userOptions, string serverName)
{
//-----------------
// Preconditions
Expand All @@ -3286,7 +3289,7 @@ internal ServerInfo(SqlConnectionString userOptions, string serverName)


// Initialize server info from connection options, but override DataSource with given server name
internal ServerInfo(SqlConnectionString userOptions, RoutingInfo routing, string preRoutingServerName)
internal ServerInfo(SqlConnectionString userOptions, RoutingInfo routing, string preRoutingServerName, string serverSPN)
{
//-----------------
// Preconditions
Expand All @@ -3307,6 +3310,7 @@ internal ServerInfo(SqlConnectionString userOptions, RoutingInfo routing, string
UserProtocol = TdsEnums.TCP;
SetDerivedNames(UserProtocol, UserServerName);
ResolvedDatabaseName = userOptions.InitialCatalog;
ServerSPN = serverSPN;
}

internal void SetDerivedNames(string protocol, string serverName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,18 +493,20 @@ internal void Connect(ServerInfo serverInfo,
SqlInternalConnectionTds connHandler,
bool ignoreSniOpenTimeout,
long timerExpire,
bool encrypt,
bool trustServerCert,
bool integratedSecurity,
SqlConnectionString connectionOptions,
bool withFailover,
bool isFirstTransparentAttempt,
SqlAuthenticationMethod authType,
string certificate,
ServerCertificateValidationCallback serverCallback,
ClientCertificateRetrievalCallback clientCallback,
bool useOriginalAddressInfo,
bool disableTnir)
{
bool encrypt = connectionOptions.Encrypt;
bool trustServerCert = connectionOptions.TrustServerCertificate;
bool integratedSecurity = connectionOptions.IntegratedSecurity;
SqlAuthenticationMethod authType = connectionOptions.Authentication;
string certificate = connectionOptions.Certificate;

if (_state != TdsParserState.Closed)
{
Debug.Fail("TdsParser.Connect called on non-closed connection!");
Expand Down Expand Up @@ -542,8 +544,18 @@ internal void Connect(ServerInfo serverInfo,
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
LoadSSPILibrary();
// now allocate proper length of buffer
_sniSpnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
{
byte[] srvSPN = Encoding.Unicode.GetBytes(serverInfo.ServerSPN);
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "The provider SPN length exceeded the buffer size.");
_sniSpnBuffer = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN);
}
else
{
// now allocate proper length of buffer
_sniSpnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
}
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> SSPI or Active Directory Authentication Library for SQL Server based integrated authentication");
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,8 @@ internal static class DbConnectionStringDefaults
internal const SqlConnectionAttestationProtocol AttestationProtocol = SqlConnectionAttestationProtocol.NotSpecified;
internal const SqlConnectionIPAddressPreference IPAddressPreference = SqlConnectionIPAddressPreference.IPv4First;
internal const PoolBlockingPeriod PoolBlockingPeriod = SqlClient.PoolBlockingPeriod.Auto;
internal const string ServerSPN = "";
internal const string FailoverPartnerSPN = "";
}

internal static class DbConnectionStringKeywords
Expand Down Expand Up @@ -1029,6 +1031,8 @@ internal static class DbConnectionStringKeywords
internal const string EnclaveAttestationUrl = "Enclave Attestation Url";
internal const string AttestationProtocol = "Attestation Protocol";
internal const string IPAddressPreference = "IP Address Preference";
internal const string ServerSPN = "Server SPN";
internal const string FailoverPartnerSPN = "Failover Partner SPN";

// common keywords (OleDb, OracleClient, SqlClient)
internal const string DataSource = "Data Source";
Expand Down Expand Up @@ -1122,5 +1126,9 @@ internal static class DbConnectionStringSynonyms

//internal const string WorkstationID = WSID;
internal const string WSID = "wsid";

//internal const string server SPNs
internal const string ServerSPN = "ServerSPN";
internal const string FailoverPartnerSPN = "FailoverPartnerSPN";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ internal static class DEFAULT
internal static readonly SqlAuthenticationMethod Authentication = DbConnectionStringDefaults.Authentication;
internal static readonly SqlConnectionAttestationProtocol AttestationProtocol = DbConnectionStringDefaults.AttestationProtocol;
internal static readonly SqlConnectionIPAddressPreference IpAddressPreference = DbConnectionStringDefaults.IPAddressPreference;
internal const string ServerSPN = DbConnectionStringDefaults.ServerSPN;
internal const string FailoverPartnerSPN = DbConnectionStringDefaults.FailoverPartnerSPN;
#if NETFRAMEWORK
internal static readonly bool TransparentNetworkIPResolution = DbConnectionStringDefaults.TransparentNetworkIPResolution;
internal const bool Connection_Reset = DbConnectionStringDefaults.ConnectionReset;
Expand Down Expand Up @@ -113,6 +115,8 @@ internal static class KEY
internal const string Connect_Retry_Count = DbConnectionStringKeywords.ConnectRetryCount;
internal const string Connect_Retry_Interval = DbConnectionStringKeywords.ConnectRetryInterval;
internal const string Authentication = DbConnectionStringKeywords.Authentication;
internal const string Server_SPN = DbConnectionStringKeywords.ServerSPN;
internal const string Failover_Partner_SPN = DbConnectionStringKeywords.FailoverPartnerSPN;
#if NETFRAMEWORK
internal const string TransparentNetworkIPResolution = DbConnectionStringKeywords.TransparentNetworkIPResolution;
#if ADONET_CERT_AUTH
Expand Down Expand Up @@ -173,6 +177,9 @@ private static class SYNONYM
internal const string User = DbConnectionStringSynonyms.User;
// workstation id
internal const string WSID = DbConnectionStringSynonyms.WSID;
// server SPNs
internal const string ServerSPN = DbConnectionStringSynonyms.ServerSPN;
internal const string FailoverPartnerSPN = DbConnectionStringSynonyms.FailoverPartnerSPN;

#if NETFRAMEWORK
internal const string TRANSPARENTNETWORKIPRESOLUTION = DbConnectionStringSynonyms.TRANSPARENTNETWORKIPRESOLUTION;
Expand Down Expand Up @@ -212,9 +219,9 @@ internal static class TRANSACTIONBINDING
}

#if NETFRAMEWORK
internal const int SynonymCount = 29;
internal const int SynonymCount = 31;
#else
internal const int SynonymCount = 26;
internal const int SynonymCount = 28;
internal const int DeprecatedSynonymCount = 2;
#endif // NETFRAMEWORK

Expand Down Expand Up @@ -257,6 +264,8 @@ internal static class TRANSACTIONBINDING
private readonly string _initialCatalog;
private readonly string _password;
private readonly string _userID;
private readonly string _serverSPN;
private readonly string _failoverPartnerSPN;

private readonly string _workstationId;

Expand Down Expand Up @@ -322,6 +331,8 @@ internal SqlConnectionString(string connectionString) : base(connectionString, G
_enclaveAttestationUrl = ConvertValueToString(KEY.EnclaveAttestationUrl, DEFAULT.EnclaveAttestationUrl);
_attestationProtocol = ConvertValueToAttestationProtocol();
_ipAddressPreference = ConvertValueToIPAddressPreference();
_serverSPN = ConvertValueToString(KEY.Server_SPN, DEFAULT.ServerSPN);
_failoverPartnerSPN = ConvertValueToString(KEY.Failover_Partner_SPN, DEFAULT.FailoverPartnerSPN);

// Temporary string - this value is stored internally as an enum.
string typeSystemVersionString = ConvertValueToString(KEY.Type_System_Version, null);
Expand Down Expand Up @@ -675,6 +686,8 @@ internal SqlConnectionString(SqlConnectionString connectionOptions, string dataS
_columnEncryptionSetting = connectionOptions._columnEncryptionSetting;
_enclaveAttestationUrl = connectionOptions._enclaveAttestationUrl;
_attestationProtocol = connectionOptions._attestationProtocol;
_serverSPN = connectionOptions._serverSPN;
_failoverPartnerSPN = connectionOptions._failoverPartnerSPN;
#if NETFRAMEWORK
_connectionReset = connectionOptions._connectionReset;
_contextConnection = connectionOptions._contextConnection;
Expand Down Expand Up @@ -732,7 +745,8 @@ internal SqlConnectionString(SqlConnectionString connectionOptions, string dataS
internal string UserID => _userID;
internal string WorkstationId => _workstationId;
internal PoolBlockingPeriod PoolBlockingPeriod => _poolBlockingPeriod;

internal string ServerSPN => _serverSPN;
internal string FailoverPartnerSPN => _failoverPartnerSPN;

internal TypeSystem TypeSystemVersion => _typeSystemVersion;
internal Version TypeSystemAssemblyVersion => _typeSystemAssemblyVersion;
Expand Down Expand Up @@ -843,6 +857,8 @@ internal static Dictionary<string, string> GetParseSynonyms()
{ KEY.Connect_Retry_Interval, KEY.Connect_Retry_Interval },
{ KEY.Authentication, KEY.Authentication },
{ KEY.IPAddressPreference, KEY.IPAddressPreference },
{ KEY.Server_SPN, KEY.Server_SPN },
{ KEY.Failover_Partner_SPN, KEY.Failover_Partner_SPN },

{ SYNONYM.APP, KEY.Application_Name },
{ SYNONYM.APPLICATIONINTENT, KEY.ApplicationIntent },
Expand Down Expand Up @@ -871,6 +887,8 @@ internal static Dictionary<string, string> GetParseSynonyms()
{ SYNONYM.UID, KEY.User_ID },
{ SYNONYM.User, KEY.User_ID },
{ SYNONYM.WSID, KEY.Workstation_Id },
{ SYNONYM.ServerSPN, KEY.Server_SPN },
{ SYNONYM.FailoverPartnerSPN, KEY.Failover_Partner_SPN },
#if NETFRAMEWORK
#if ADONET_CERT_AUTH
{ KEY.Certificate, KEY.Certificate },
Expand Down
Loading

0 comments on commit 1beb887

Please sign in to comment.