Skip to content

Commit

Permalink
add RegisterColumnEncryptionKeyStoreProvidersOnConnection (#1045)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnnypham committed Apr 28, 2021
1 parent 37f7fbd commit 2f19bc4
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 1 deletion.
18 changes: 18 additions & 0 deletions doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,24 @@ GO
This function was called more than once.
</exception>
</RegisterColumnEncryptionKeyStoreProviders>
<RegisterColumnEncryptionKeyStoreProvidersOnConnection>
<param name="customProviders">Dictionary of custom column encryption key providers</param>
<summary>Registers the encryption key store providers on the <see cref="T:Microsoft.Data.SqlClient.SqlConnection" /> instance. If this function has been called, any providers registered using the static <see cref="T:Microsoft.Data.SqlClient.SqlConnection.RegisterColumnEncryptionKeyStoreProviders" /> methods will be ignored. This function can be called more than once. This does shallow copying of the dictionary so that the app cannot alter the custom provider list once it has been set.</summary>
<exception cref="T:System.ArgumentNullException">
A null dictionary was provided.

-or-

A string key in the dictionary was null or empty.

-or-

An EncryptionKeyStoreProvider value in the dictionary was null.
</exception>
<exception cref="T:System.ArgumentException">
A string key in the dictionary started with "MSSQL_". This prefix is reserved for system providers.
</exception>
</RegisterColumnEncryptionKeyStoreProvidersOnConnection>
<RetryLogicProvider>
<summary> Gets or sets a value that specifies the
<see cref="T:Microsoft.Data.SqlClient.SqlRetryLogicBaseProvider" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,9 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
public static System.Collections.Generic.IDictionary<string, System.Collections.Generic.IList<string>> ColumnEncryptionTrustedMasterKeyPaths { get { throw null; } }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProviders/*'/>
public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessToken/*'/>
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessToken/*'/>
[System.ComponentModel.BrowsableAttribute(false)]
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public string AccessToken { get { throw null; } set { } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// </summary>
private static IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> s_globalCustomColumnEncryptionKeyStoreProviders;

/// <summary>
/// Per-connection custom providers. It can be provided by the user and can be set more than once.
/// </summary>
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

/// <summary>
/// Dictionary object holding trusted key paths for various SQL Servers.
/// Key to the dictionary is a SQL Server Name
Expand Down Expand Up @@ -234,6 +239,13 @@ internal static bool TryGetColumnEncryptionKeyStoreProvider(string providerName,
return true;
}

// instance-level custom provider cache takes precedence over global cache
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.TryGetValue(providerName, out columnKeyStoreProvider);
}

lock (s_globalCustomColumnEncryptionKeyProvidersLock)
{
// If custom provider is not set, then return false
Expand Down Expand Up @@ -264,6 +276,11 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProviders()
/// <returns>Combined list of provider names</returns>
internal static List<string> GetColumnEncryptionCustomKeyStoreProviders(SqlConnection connection)
{
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.Keys.ToList();
}
if (s_globalCustomColumnEncryptionKeyStoreProviders != null)
{
return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList();
Expand Down Expand Up @@ -306,6 +323,24 @@ public static void RegisterColumnEncryptionKeyStoreProviders(IDictionary<string,
s_globalCustomColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
ValidateCustomProviders(customProviders);

// Create a temporary dictionary and then add items from the provided dictionary.
// Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs
// in the provided customerProviders dictionary.
Dictionary<string, SqlColumnEncryptionKeyStoreProvider> customColumnEncryptionKeyStoreProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}

private static void ValidateCustomProviders(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
// Throw when the provided dictionary is null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ public void Open(SqlConnectionOverrides overrides) { }
public override System.Threading.Tasks.Task OpenAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProviders/*'/>
public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ResetStatistics/*'/>
public void ResetStatistics() { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RetrieveStatistics/*'/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ static private readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// </summary>
private static IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> s_globalCustomColumnEncryptionKeyStoreProviders;

/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

// Lock to control setting of s_globalCustomColumnEncryptionKeyStoreProviders
private static readonly object s_globalCustomColumnEncryptionKeyProvidersLock = new object();

Expand Down Expand Up @@ -164,6 +167,23 @@ static public void RegisterColumnEncryptionKeyStoreProviders(IDictionary<string,
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
ValidateCustomProviders(customProviders);

// Create a temporary dictionary and then add items from the provided dictionary.
// Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs
// in the provided customerProviders dictionary.
Dictionary<string, SqlColumnEncryptionKeyStoreProvider> customColumnEncryptionKeyStoreProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}

private static void ValidateCustomProviders(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
// Throw when the provided dictionary is null.
Expand Down Expand Up @@ -216,6 +236,13 @@ static internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName,
return true;
}

// instance-level custom provider cache takes precedence over global cache
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.TryGetValue(providerName, out columnKeyStoreProvider);
}

lock (s_globalCustomColumnEncryptionKeyProvidersLock)
{
// If custom provider is not set, then return false
Expand Down Expand Up @@ -246,6 +273,11 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProviders()
/// <returns>Combined list of provider names</returns>
internal static List<string> GetColumnEncryptionCustomKeyStoreProviders(SqlConnection connection)
{
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.Keys.ToList();
}
if (s_globalCustomColumnEncryptionKeyStoreProviders != null)
{
return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public void TestNullDictionary()

ArgumentNullException e = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentNullException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -35,6 +38,9 @@ public void TestInvalidProviderName()

ArgumentException e = Assert.Throws<ArgumentException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -48,6 +54,9 @@ public void TestNullProviderValue()

ArgumentNullException e = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentNullException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -60,6 +69,9 @@ public void TestEmptyProviderName()

ArgumentNullException e = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentNullException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -81,5 +93,47 @@ public void TestCanSetGlobalProvidersOnlyOnce()

Utility.ClearSqlConnectionGlobalProviders();
}

[Fact]
public void TestCanSetInstanceProvidersMoreThanOnce()
{
const string dummyProviderName1 = "DummyProvider1";
const string dummyProviderName2 = "DummyProvider2";
const string dummyProviderName3 = "DummyProvider3";
IDictionary<string, SqlColumnEncryptionKeyStoreProvider> singleKeyStoreProvider =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>()
{
{dummyProviderName1, new DummyKeyStoreProvider() }
};

IDictionary<string, SqlColumnEncryptionKeyStoreProvider> multipleKeyStoreProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>()
{
{ dummyProviderName2, new DummyKeyStoreProvider() },
{ dummyProviderName3, new DummyKeyStoreProvider() }
};

using (SqlConnection connection = new SqlConnection())
{
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(singleKeyStoreProvider);
IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> instanceCache =
GetInstanceCacheFromConnection(connection);
Assert.Single(instanceCache);
Assert.True(instanceCache.ContainsKey(dummyProviderName1));

connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(multipleKeyStoreProviders);
instanceCache = GetInstanceCacheFromConnection(connection);
Assert.Equal(2, instanceCache.Count);
Assert.True(instanceCache.ContainsKey(dummyProviderName2));
Assert.True(instanceCache.ContainsKey(dummyProviderName3));
}

IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> GetInstanceCacheFromConnection(SqlConnection conn)
{
FieldInfo instanceCacheField = conn.GetType().GetField(
"_customColumnEncryptionKeyStoreProviders", BindingFlags.NonPublic | BindingFlags.Instance);
return instanceCacheField.GetValue(conn) as IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider>;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2160,6 +2160,28 @@ public void TestCustomKeyStoreProviderDuringAeQuery(string connectionString)
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Contains(failedToDecryptMessage, ex.Message);
Assert.True(ex.InnerException is NotImplementedException);

// not required provider in instance cache
// it should not fall back to the global cache so the right provider will not be found
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(notRequiredProvider);
ex = Assert.Throws<ArgumentException>(
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Equal(providerNotFoundMessage, ex.Message);

// required provider in instance cache
// if the instance cache is not empty, it is always checked for the provider.
// => if the provider is found, it must have been retrieved from the instance cache and not the global cache
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(requiredProvider);
ex = Assert.Throws<SqlException>(
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Contains(failedToDecryptMessage, ex.Message);
Assert.True(ex.InnerException is NotImplementedException);

// not required provider will replace the previous entry so required provider will not be found
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(notRequiredProvider);
ex = Assert.Throws<ArgumentException>(
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Equal(providerNotFoundMessage, ex.Message);
}

void ExecuteQueryThatRequiresCustomKeyStoreProvider(SqlConnection connection)
Expand Down

0 comments on commit 2f19bc4

Please sign in to comment.