Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Linux SPN port number using named instance and Kerberos authentication does not return port# #2240

Merged
merged 17 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
}
else if (!string.IsNullOrWhiteSpace(dataSource.InstanceName))
{
postfix = dataSource.InstanceName;
postfix = dataSource._connectionProtocol == DataSource.Protocol.TCP ? dataSource.ResolvedPort.ToString() : dataSource.InstanceName;
}

SqlClientEventSource.Log.TryTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerName {0}, InstanceName {1}, Port {2}, postfix {3}", dataSource?.ServerName, dataSource?.InstanceName, dataSource?.Port, postfix);
Expand Down Expand Up @@ -317,7 +317,7 @@ private static SNITCPHandle CreateTcpHandle(
{
try
{
port = isAdminConnection ?
details.ResolvedPort = port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference);
}
Expand Down Expand Up @@ -436,6 +436,11 @@ internal enum Protocol { TCP, NP, None, Admin };
/// </summary>
internal int Port { get; private set; } = -1;

/// <summary>
/// The port resolved by SSRP when InstanceName is specified
/// </summary>
internal int ResolvedPort { get; set; } = -1;

/// <summary>
/// Provides the inferred Instance Name from Server Data Source
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,21 @@ public static string GetMachineFQDN(string hostname)
return fqdn.ToString();
}

public static bool IsNotLocalhost()
{
// get the tcp connection string
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

string hostname = "";

// parse the datasource
ParseDataSource(builder.DataSource, out hostname, out _, out _);

// hostname must not be localhost, ., 127.0.0.1 nor ::1
return !(new string[] { "localhost", ".", "127.0.0.1", "::1" }).Contains(hostname.ToLowerInvariant());

}

private static bool RunningAsUWPApp()
{
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -83,6 +84,138 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
}
}

// Note: This Unit test was tested in a domain-joined VM connecting to a remote
// SQL Server using Kerberos in the same domain.
[ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false
[ConditionalFact(nameof(IsKerberos))]
public static void PortNumberInSPNTest()
{
string connStr = DataTestUtility.TCPConnectionString;
// If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true
if (DataTestUtility.IsIntegratedSecuritySetup())
{
string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" };
connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true";
}

SqlConnectionStringBuilder builder = new(connStr);

Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name");

bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName);
Assert.True(condition, "Browser service is not running or instance name is invalid");

if (condition)
{
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection);
using SqlDataReader reader = command.ExecuteReader();
Assert.True(reader.Read(), "Expected to receive one row data");
Assert.Equal("KERBEROS", reader.GetString(0));
int localTcpPort = reader.GetInt32(1);

int spnPort = -1;
string spnInfo = GetSPNInfo(builder.DataSource, out spnPort);

// sample output to validate = MSSQLSvc/machine.domain.tld:spnPort"
Assert.Contains($"MSSQLSvc/{hostname}", spnInfo);
// the local_tcp_port should be the same as the inferred SPN port from instance name
Assert.Equal(localTcpPort, spnPort);
}
}

private static string GetSPNInfo(string datasource, out int out_port)
{
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));

// Get all required types using reflection
Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy");
Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP");
Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource");
Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer");

// Used in Datasource constructor param type array
Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) };

// Used in GetSqlServerSPNs function param types array
Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) };

// GetPortByInstanceName parameters array
Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) };

// TimeoutTimer.StartSecondsTimeout params
Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) };

// Get all types constructors
ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

// Instantiate SNIProxy
object sniProxy = sniProxyCtor.Invoke(new object[] { });

// Instantiate datasource
object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource });

// Instantiate SSRP
object ssrp = SSRPCtor.Invoke(new object[] { });

// Instantiate TimeoutTimer
object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { });

// Get TimeoutTimer.StartSecondsTimeout Method
MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null);
// Create a timeoutTimer that expires in 30 seconds
timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 });

// Parse the datasource to separate the server name and instance name
MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null);
object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource });

// Get the GetPortByInstanceName method of SSRP
MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null);

// Get the server name
PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();

// Get the instance name
PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();

// Get the port number using the GetPortByInstanceName method of SSRP
object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 });

// Set the resolved port property of datasource
PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null);

// Prepare the GetSqlServerSPNs method
string serverSPN = "";
MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null);

// Finally call GetSqlServerSPNs
byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN });

// Example result: MSSQLSvc/machine.domain.tld:port"
string spnInfo = Encoding.Unicode.GetString(result[0]);

out_port = (int)port;

return spnInfo;
}

private static bool IsKerberos()
{
return (DataTestUtility.AreConnStringsSetup()
&& DataTestUtility.IsNotLocalhost()
&& DataTestUtility.IsKerberosTest
&& DataTestUtility.IsNotAzureServer()
&& DataTestUtility.IsNotAzureSynapse());
}

private static bool IsBrowserAlive(string browserHostname)
{
const byte ClntUcastEx = 0x03;
Expand Down
Loading