diff --git a/src/Databases/Mongo/src/MongoVectorCollection.cs b/src/Databases/Mongo/src/MongoVectorCollection.cs
new file mode 100644
index 00000000..d9c38a77
--- /dev/null
+++ b/src/Databases/Mongo/src/MongoVectorCollection.cs
@@ -0,0 +1,119 @@
+using LangChain.Databases.Mongo.Client;
+using MongoDB.Bson.Serialization;
+using MongoDB.Driver;
+
+namespace LangChain.Databases.Mongo;
+
+/// <summary>
+/// <see cref="IVectorCollection"/> backed by a MongoDB collection and queried through
+/// the aggregation <c>VectorSearch</c> stage.
+/// </summary>
+/// <param name="mongoContext">Context that resolves the MongoDB collection named <paramref name="name"/>.</param>
+/// <param name="indexName">Index name handed to the vector search stage.</param>
+/// <param name="name">Collection name; also forwarded to the <see cref="VectorCollection"/> base.</param>
+/// <param name="id">Optional id forwarded to the <see cref="VectorCollection"/> base.</param>
+public class MongoVectorCollection(
+    IMongoContext mongoContext,
+    string indexName,
+    string name = VectorCollection.DefaultName,
+    string? id = null)
+: VectorCollection(name, id), IVectorCollection
+{
+    // Atlas Vector Search rejects numCandidates values above this limit.
+    private const int MaxNumberOfCandidates = 10000;
+
+    // Number of candidates examined per requested result.
+    private const int CandidatesPerResult = 10;
+
+    private readonly IMongoCollection<Vector> _mongoCollection = mongoContext.GetCollection<Vector>(name);
+
+    /// <summary>Inserts <paramref name="items"/> and returns their ids in input order.</summary>
+    /// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
+    public async Task<IReadOnlyCollection<string>> AddAsync(IReadOnlyCollection<Vector> items, CancellationToken cancellationToken = default)
+    {
+        items = items ?? throw new ArgumentNullException(nameof(items));
+
+        // The driver rejects an empty insert batch, so skip the round trip entirely.
+        if (items.Count == 0)
+        {
+            return Array.Empty<string>();
+        }
+
+        await _mongoCollection.InsertManyAsync(items, cancellationToken: cancellationToken).ConfigureAwait(false);
+        return items.Select(i => i.Id).ToList();
+    }
+
+    /// <summary>Deletes every document whose id is in <paramref name="ids"/>.</summary>
+    /// <returns>Whether the server acknowledged the delete (not whether anything matched).</returns>
+    /// <exception cref="ArgumentNullException"><paramref name="ids"/> is null.</exception>
+    public async Task<bool> DeleteAsync(IEnumerable<string> ids, CancellationToken cancellationToken = default)
+    {
+        ids = ids ?? throw new ArgumentNullException(nameof(ids));
+
+        var filter = Builders<Vector>.Filter.In(i => i.Id, ids);
+        var result = await _mongoCollection.DeleteManyAsync(filter, cancellationToken).ConfigureAwait(false);
+        return result.IsAcknowledged;
+    }
+
+    /// <summary>Fetches the first document whose id equals <paramref name="id"/>.</summary>
+    /// <returns>The matching vector, or the default value when nothing matches.</returns>
+    public async Task<Vector?> GetAsync(string id, CancellationToken cancellationToken = default)
+    {
+        var filter = Builders<Vector>.Filter.Eq(i => i.Id, id);
+        using var cursor = await _mongoCollection.FindAsync(filter, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+        // Read the cursor asynchronously; the synchronous FirstOrDefault would block the calling thread.
+        return await cursor.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false);
+    }
+
+    /// <summary>Reports whether the estimated document count of the collection is zero.</summary>
+    public async Task<bool> IsEmptyAsync(CancellationToken cancellationToken = default)
+    {
+        var count = await _mongoCollection.EstimatedDocumentCountAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
+        return count == 0;
+    }
+
+    /// <summary>
+    /// Runs a vector search with the first embedding of <paramref name="request"/> and returns up to
+    /// <c>settings.NumberOfResults</c> items.
+    /// </summary>
+    /// <remarks>
+    /// Each item's <c>Distance</c> is filled from the projected <c>vectorSearchScore</c> metadata.
+    /// NOTE(review): that value is a search score rather than a raw distance - confirm callers expect this.
+    /// </remarks>
+    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
+    public async Task<VectorSearchResponse> SearchAsync(VectorSearchRequest request, VectorSearchSettings? settings = null, CancellationToken cancellationToken = default)
+    {
+        request = request ?? throw new ArgumentNullException(nameof(request));
+        settings ??= new VectorSearchSettings();
+
+        var options = new VectorSearchOptions<Vector>()
+        {
+            IndexName = indexName,
+            // Over-sample for recall, but stay within the server limit; long math avoids int overflow.
+            NumberOfCandidates = (int)Math.Min((long)settings.NumberOfResults * CandidatesPerResult, MaxNumberOfCandidates),
+        };
+
+        // Drop the stored Distance and surface the search score under "score" instead.
+        var projectionDefinition = Builders<Vector>.Projection
+            .Exclude(a => a.Distance)
+            .Meta("score", "vectorSearchScore");
+
+        var results = await _mongoCollection.Aggregate()
+            .VectorSearch(nameof(Vector.Embedding), request.Embeddings.First(), settings.NumberOfResults, options)
+            .Project(projectionDefinition)
+            .ToListAsync(cancellationToken)
+            .ConfigureAwait(false);
+
+        return new VectorSearchResponse
+        {
+            Items = results.Select(result =>
+                {
+                    var output = BsonSerializer.Deserialize<Vector>(result);
+                    output.Distance = (float)result["score"].ToDouble();
+                    return output;
+                })
+                .ToArray(),
+        };
+    }
+}