Skip to content

Commit

Permalink
.Net: Add DI registration helpers for collections and search (#9007)
Browse files Browse the repository at this point in the history
### Motivation and Context

If someone wants to consume just the IVectorizedSearch interface from DI
it's useful to have helpers to allow easy registration of collections
for these.

#8974

### Description

- Adding a DI helper for each collection type to register the collection
interface and all search interfaces on it.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
westey-m committed Sep 27, 2024
1 parent d08eef3 commit 21e37a0
Show file tree
Hide file tree
Showing 28 changed files with 1,800 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,69 @@ public void AddVectorStoreWithUriAndTokenCredsRegistersClass()
this.AssertVectorStoreCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionRegistersClass()
{
// Arrange.
this._kernelBuilder.Services.AddSingleton<SearchIndexClient>(Mock.Of<SearchIndexClient>());

// Act.
this._kernelBuilder.AddAzureAISearchVectorStoreRecordCollection<TestRecord>("testcollection");

// Assert.
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithUriAndCredsRegistersClass()
{
// Act.
this._kernelBuilder.AddAzureAISearchVectorStoreRecordCollection<TestRecord>("testcollection", new Uri("https://localhost"), new AzureKeyCredential("fakeKey"));

// Assert.
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithUriAndTokenCredsRegistersClass()
{
// Act.
this._kernelBuilder.AddAzureAISearchVectorStoreRecordCollection<TestRecord>("testcollection", new Uri("https://localhost"), Mock.Of<TokenCredential>());

// Assert.
this.AssertVectorStoreRecordCollectionCreated();
}

private void AssertVectorStoreCreated()
{
var kernel = this._kernelBuilder.Build();
var vectorStore = kernel.Services.GetRequiredService<IVectorStore>();
Assert.NotNull(vectorStore);
Assert.IsType<AzureAISearchVectorStore>(vectorStore);
}

private void AssertVectorStoreRecordCollectionCreated()
{
var kernel = this._kernelBuilder.Build();

var collection = kernel.Services.GetRequiredService<IVectorStoreRecordCollection<string, TestRecord>>();
Assert.NotNull(collection);
Assert.IsType<AzureAISearchVectorStoreRecordCollection<TestRecord>>(collection);

var vectorizedSearch = kernel.Services.GetRequiredService<IVectorizedSearch<TestRecord>>();
Assert.NotNull(vectorizedSearch);
Assert.IsType<AzureAISearchVectorStoreRecordCollection<TestRecord>>(vectorizedSearch);

var vectorizableSearch = kernel.Services.GetRequiredService<IVectorizableTextSearch<TestRecord>>();
Assert.NotNull(vectorizableSearch);
Assert.IsType<AzureAISearchVectorStoreRecordCollection<TestRecord>>(vectorizableSearch);
}

#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class TestRecord
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[VectorStoreRecordKey]
public string Id { get; set; } = string.Empty;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,69 @@ public void AddVectorStoreWithUriAndTokenCredsRegistersClass()
this.AssertVectorStoreCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionRegistersClass()
{
// Arrange.
this._serviceCollection.AddSingleton<SearchIndexClient>(Mock.Of<SearchIndexClient>());

// Act.
this._serviceCollection.AddAzureAISearchVectorStoreRecordCollection<TestRecord>("testcollection");

// Assert.
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithUriAndCredsRegistersClass()
{
// Act.
this._serviceCollection.AddAzureAISearchVectorStoreRecordCollection<TestRecord>("testcollection", new Uri("https://localhost"), new AzureKeyCredential("fakeKey"));

// Assert.
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithUriAndTokenCredsRegistersClass()
{
// Act.
this._serviceCollection.AddAzureAISearchVectorStoreRecordCollection<TestRecord>("testcollection", new Uri("https://localhost"), Mock.Of<TokenCredential>());

// Assert.
this.AssertVectorStoreRecordCollectionCreated();
}

private void AssertVectorStoreCreated()
{
var serviceProvider = this._serviceCollection.BuildServiceProvider();
var vectorStore = serviceProvider.GetRequiredService<IVectorStore>();
Assert.NotNull(vectorStore);
Assert.IsType<AzureAISearchVectorStore>(vectorStore);
}

private void AssertVectorStoreRecordCollectionCreated()
{
var serviceProvider = this._serviceCollection.BuildServiceProvider();

var collection = serviceProvider.GetRequiredService<IVectorStoreRecordCollection<string, TestRecord>>();
Assert.NotNull(collection);
Assert.IsType<AzureAISearchVectorStoreRecordCollection<TestRecord>>(collection);

var vectorizedSearch = serviceProvider.GetRequiredService<IVectorizedSearch<TestRecord>>();
Assert.NotNull(vectorizedSearch);
Assert.IsType<AzureAISearchVectorStoreRecordCollection<TestRecord>>(vectorizedSearch);

var vectorizableSearch = serviceProvider.GetRequiredService<IVectorizableTextSearch<TestRecord>>();
Assert.NotNull(vectorizableSearch);
Assert.IsType<AzureAISearchVectorStoreRecordCollection<TestRecord>>(vectorizableSearch);
}

#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class TestRecord
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[VectorStoreRecordKey]
public string Id { get; set; } = string.Empty;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,48 @@ public void AddVectorStoreWithConnectionStringRegistersClass()
var database = (IMongoDatabase)vectorStore.GetType().GetField("_mongoDatabase", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(vectorStore)!;
Assert.Equal(HttpHeaderConstant.Values.UserAgent, database.Client.Settings.ApplicationName);
}

[Fact]
public void AddVectorStoreRecordCollectionRegistersClass()
{
// Arrange
this._kernelBuilder.Services.AddSingleton<IMongoDatabase>(Mock.Of<IMongoDatabase>());

// Act
this._kernelBuilder.AddAzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>("testcollection");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithConnectionStringRegistersClass()
{
// Act
this._kernelBuilder.AddAzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>("testcollection", "mongodb://localhost:27017", "mydb");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

private void AssertVectorStoreRecordCollectionCreated()
{
var kernel = this._kernelBuilder.Build();

var collection = kernel.Services.GetRequiredService<IVectorStoreRecordCollection<string, TestRecord>>();
Assert.NotNull(collection);
Assert.IsType<AzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>>(collection);

var vectorizedSearch = kernel.Services.GetRequiredService<IVectorizedSearch<TestRecord>>();
Assert.NotNull(vectorizedSearch);
Assert.IsType<AzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>>(vectorizedSearch);
}

#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class TestRecord
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[VectorStoreRecordKey]
public string Id { get; set; } = string.Empty;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,48 @@ public void AddVectorStoreWithConnectionStringRegistersClass()
var database = (IMongoDatabase)vectorStore.GetType().GetField("_mongoDatabase", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(vectorStore)!;
Assert.Equal(HttpHeaderConstant.Values.UserAgent, database.Client.Settings.ApplicationName);
}

[Fact]
public void AddVectorStoreRecordCollectionRegistersClass()
{
// Arrange
this._serviceCollection.AddSingleton<IMongoDatabase>(Mock.Of<IMongoDatabase>());

// Act
this._serviceCollection.AddAzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>("testcollection");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithConnectionStringRegistersClass()
{
// Act
this._serviceCollection.AddAzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>("testcollection", "mongodb://localhost:27017", "mydb");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

private void AssertVectorStoreRecordCollectionCreated()
{
var serviceProvider = this._serviceCollection.BuildServiceProvider();

var collection = serviceProvider.GetRequiredService<IVectorStoreRecordCollection<string, TestRecord>>();
Assert.NotNull(collection);
Assert.IsType<AzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>>(collection);

var vectorizedSearch = serviceProvider.GetRequiredService<IVectorizedSearch<TestRecord>>();
Assert.NotNull(vectorizedSearch);
Assert.IsType<AzureCosmosDBMongoDBVectorStoreRecordCollection<TestRecord>>(vectorizedSearch);
}

#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class TestRecord
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[VectorStoreRecordKey]
public string Id { get; set; } = string.Empty;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,47 @@ public void AddVectorStoreWithConnectionStringRegistersClass()
var database = (Database)vectorStore.GetType().GetField("_database", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(vectorStore)!;
Assert.Equal(HttpHeaderConstant.Values.UserAgent, database.Client.ClientOptions.ApplicationName);
}
[Fact]
public void AddVectorStoreRecordCollectionRegistersClass()
{
// Arrange
this._kernelBuilder.Services.AddSingleton<Database>(Mock.Of<Database>());

// Act
this._kernelBuilder.AddAzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>("testcollection");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithConnectionStringRegistersClass()
{
// Act
this._kernelBuilder.AddAzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>("testcollection", "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=mock;", "mydb");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

private void AssertVectorStoreRecordCollectionCreated()
{
var kernel = this._kernelBuilder.Build();

var collection = kernel.Services.GetRequiredService<IVectorStoreRecordCollection<string, TestRecord>>();
Assert.NotNull(collection);
Assert.IsType<AzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>>(collection);

var vectorizedSearch = kernel.Services.GetRequiredService<IVectorizedSearch<TestRecord>>();
Assert.NotNull(vectorizedSearch);
Assert.IsType<AzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>>(vectorizedSearch);
}

#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class TestRecord
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[VectorStoreRecordKey]
public string Id { get; set; } = string.Empty;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,48 @@ public void AddVectorStoreWithConnectionStringRegistersClass()
var database = (Database)vectorStore.GetType().GetField("_database", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(vectorStore)!;
Assert.Equal(HttpHeaderConstant.Values.UserAgent, database.Client.ClientOptions.ApplicationName);
}

[Fact]
public void AddVectorStoreRecordCollectionRegistersClass()
{
// Arrange
this._serviceCollection.AddSingleton<Database>(Mock.Of<Database>());

// Act
this._serviceCollection.AddAzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>("testcollection");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

[Fact]
public void AddVectorStoreRecordCollectionWithConnectionStringRegistersClass()
{
// Act
this._serviceCollection.AddAzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>("testcollection", "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=mock;", "mydb");

// Assert
this.AssertVectorStoreRecordCollectionCreated();
}

private void AssertVectorStoreRecordCollectionCreated()
{
var serviceProvider = this._serviceCollection.BuildServiceProvider();

var collection = serviceProvider.GetRequiredService<IVectorStoreRecordCollection<string, TestRecord>>();
Assert.NotNull(collection);
Assert.IsType<AzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>>(collection);

var vectorizedSearch = serviceProvider.GetRequiredService<IVectorizedSearch<TestRecord>>();
Assert.NotNull(vectorizedSearch);
Assert.IsType<AzureCosmosDBNoSQLVectorStoreRecordCollection<TestRecord>>(vectorizedSearch);
}

#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class TestRecord
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[VectorStoreRecordKey]
public string Id { get; set; } = string.Empty;
}
}
Loading

0 comments on commit 21e37a0

Please sign in to comment.