From 12b531ef25a9c8d0255f1cf7de8381fd85e2543f Mon Sep 17 00:00:00 2001 From: Aniket Gupta <70100745+aniketg-21@users.noreply.github.com> Date: Sat, 14 Jun 2025 21:24:38 +0530 Subject: [PATCH] GH-3540: Allow user-provided embeddings in VectorStore Signed-off-by: Aniket Gupta <70100745+aniketg-21@users.noreply.github.com> --- .../ai/vectorstore/SimpleVectorStore.java | 5 +- .../ai/vectorstore/VectorStore.java | 13 +++ .../AbstractObservationVectorStore.java | 80 ++++++++++++++++++- .../vectorstore/SimpleVectorStoreTests.java | 70 ++++++++++++++-- .../SimpleVectorStoreWithFilterTests.java | 3 +- .../cosmosdb/CosmosDBVectorStore.java | 7 +- .../vectorstore/azure/AzureVectorStore.java | 6 +- .../cassandra/CassandraVectorStore.java | 6 +- .../chroma/vectorstore/ChromaVectorStore.java | 6 +- .../coherence/CoherenceVectorStore.java | 4 +- .../CouchbaseSearchVectorStore.java | 5 +- .../ElasticsearchVectorStore.java | 6 +- .../gemfire/GemFireVectorStore.java | 5 +- .../hanadb/HanaCloudVectorStore.java | 8 +- .../mariadb/MariaDBVectorStore.java | 7 +- .../mariadb/MariaDBStoreTests.java | 9 +-- .../vectorstore/milvus/MilvusVectorStore.java | 5 +- .../atlas/MongoDBAtlasVectorStore.java | 5 +- .../vectorstore/neo4j/Neo4jVectorStore.java | 6 +- .../opensearch/OpenSearchVectorStore.java | 7 +- .../vectorstore/oracle/OracleVectorStore.java | 5 +- .../vectorstore/pgvector/PgVectorStore.java | 6 +- .../pgvector/PgVectorStoreTests.java | 9 +-- .../pinecone/PineconeVectorStore.java | 20 +---- .../vectorstore/qdrant/QdrantVectorStore.java | 8 +- .../vectorstore/redis/RedisVectorStore.java | 6 +- .../typesense/TypesenseVectorStore.java | 6 +- .../weaviate/WeaviateVectorStore.java | 6 +- 28 files changed, 191 insertions(+), 138 deletions(-) diff --git a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index 8c7d3fe9687..3f9dff10416 100644 --- a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -102,7 +102,7 @@ public static SimpleVectorStoreBuilder builder(EmbeddingModel embeddingModel) { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { Objects.requireNonNull(documents, "Documents list cannot be null"); if (documents.isEmpty()) { throw new IllegalArgumentException("Documents list cannot be empty"); @@ -110,9 +110,8 @@ public void doAdd(List documents) { for (Document document : documents) { logger.info("Calling EmbeddingModel for document id = {}", document.getId()); - float[] embedding = this.embeddingModel.embed(document); SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(), document.getText(), - document.getMetadata(), embedding); + document.getMetadata(), embeddings.get(documents.indexOf(document))); this.store.put(document.getId(), storeContent); } } diff --git a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/VectorStore.java b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/VectorStore.java index df1a11f614d..9dca4c62408 100644 --- a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/VectorStore.java +++ b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/VectorStore.java @@ -51,6 +51,19 @@ default String getName() { */ void add(List documents); + /** + * Adds list of {@link Document}s with their corresponding embeddings to the vector store. + * @param documents the list of documents to store. Throws an exception if the + * underlying provider checks for duplicate IDs. + * @param embeddings the list of float[] embeddings corresponding to each document. + * @throws IllegalArgumentException if there is: + *
    + *
  • A mismatch between documents and embeddings + *
  • Dimensional inconsistency between embeddings + *
  • Embeddings contain {@code NaN}, {@code Infinity}, or null/empty vectors. + */ + void add(List documents, List embeddings); + @Override default void accept(List documents) { add(documents); diff --git a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java index ca3c3ae9185..6e8732ca73b 100644 --- a/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java +++ b/spring-ai-vector-store/src/main/java/org/springframework/ai/vectorstore/observation/AbstractObservationVectorStore.java @@ -23,11 +23,13 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * Abstract base class for {@link VectorStore} implementations that provides observation @@ -82,7 +84,29 @@ public void add(List documents) { VectorStoreObservationDocumentation.AI_VECTOR_STORE .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> this.doAdd(documents)); + .observe(() -> this.doAdd(documents, this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), + this.batchingStrategy))); + } + + /** + * Create a new {@link AbstractObservationVectorStore} instance. + * @param documents the documents to add + * @param embeddings the embeddings corresponding to each document + */ + @Override + public void add(List documents, List embeddings) { + + VectorStoreObservationContext observationContext = this + .createObservationContextBuilder(VectorStoreObservationContext.Operation.ADD.value()) + .build(); + + VectorStoreObservationDocumentation.AI_VECTOR_STORE + .observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + this.validateEmbeddings(documents, embeddings); + this.doAdd(documents, embeddings); + }); } @Override @@ -132,8 +156,9 @@ public List similaritySearch(SearchRequest request) { /** * Perform the actual add operation. * @param documents the documents to add + * @param embeddings the embeddings corresponding to each document */ - public abstract void doAdd(List documents); + public abstract void doAdd(List documents, List embeddings); /** * Perform the actual delete operation. @@ -167,4 +192,55 @@ protected void doDelete(Filter.Expression filterExpression) { */ public abstract VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName); + /** + * Validates a list of documents and their corresponding embeddings. + * + * @param documents The list of documents. Must not be null. + * @param embeddings The list of float[] embeddings corresponding to each document. + * @throws IllegalArgumentException if validation fails for: + *
      + *
    • A mismatch between documents and embeddings + *
    • Dimensional inconsistency between embeddings + *
    • Embeddings contain {@code NaN}, {@code Infinity}, or null/empty vectors. + */ + protected void validateEmbeddings(List documents, List embeddings) { + Assert.notNull(documents, "Documents list cannot be null."); + Assert.notNull(embeddings, "Embeddings list cannot be null."); + + int docSize = documents.size(); + int embSize = embeddings.size(); + + if (docSize != embSize) { + throw new IllegalArgumentException( + String.format("Mismatch between documents (%d) and embeddings (%d).", docSize, embSize)); + } + if (embSize == 0) return; + + float[] first = embeddings.get(0); + if (first == null || first.length == 0) { + throw new IllegalArgumentException("First embedding is null or empty."); + } + + final int expectedDim = first.length; + + for (int i = 0; i < embSize; i++) { + float[] emb = embeddings.get(i); + + if (emb == null) { + throw new IllegalArgumentException("Embedding at index " + i + " is null."); + } + if (emb.length != expectedDim) { + throw new IllegalArgumentException(String.format( + "Embedding at index %d has dimension %d, expected %d.", i, emb.length, expectedDim)); + } + + for (float val : emb) { + if (Float.isNaN(val) || Float.isInfinite(val)) { + throw new IllegalArgumentException(String.format( + "Embedding at index %d contains NaN or Infinite value.", i)); + } + } + } + } + } diff --git a/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java index 85fe0b384c6..11ca374a3d4 100644 --- a/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java +++ b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreTests.java @@ -19,12 +19,7 @@ import java.io.File; import java.io.IOException; import java.nio.file.Path; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import org.junit.jupiter.api.BeforeEach; @@ -57,7 +52,8 @@ void setUp() { this.mockEmbeddingModel = mock(EmbeddingModel.class); when(this.mockEmbeddingModel.dimensions()).thenReturn(3); when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); - when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); + when(this.mockEmbeddingModel.embed(any(), any(), any())) + .thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.1f, 0.2f, 0.3f })); this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder(this.mockEmbeddingModel)); } @@ -86,6 +82,66 @@ void shouldAddMultipleDocuments() { assertThat(results).hasSize(2).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); } + @Test + void shouldAddMultipleDocsWithProvidedEmbeddings() { + List docs = Arrays.asList(Document.builder().id("1").text("first").build(), + Document.builder().id("2").text("second").build()); + List embeddings = List.of(new float[] {0.1f, 0.2f, 0.3f}, new float[] {0.4f, 0.5f, 0.6f}); + + this.vectorStore.add(docs, embeddings); + + List results = this.vectorStore.similaritySearch("first"); + assertThat(results).hasSize(2).extracting(Document::getId).containsExactlyInAnyOrder("1", "2"); + } + + @Test + void shouldHandleNullEmbeddingsList() { + assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Embeddings list cannot be null."); + } + + @Test + void shouldHandleMismatchDocsAndEmbeddingsList() { + List embeddings = List.of(new float[] {0.1f, 0.2f, 0.3f}); + + assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList(), embeddings)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Mismatch between documents (0) and embeddings (1)."); + } + + @Test + void shouldHandleInvalidEmbeddings() { + List docs = List.of(Document.builder().id("1").text("first").build()); + + assertThatThrownBy(() -> this.vectorStore.add(docs, List.of(new float[] {Float.NaN}))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Embedding at index 0 contains NaN or Infinite value."); + + List nullEmbeddings = new ArrayList<>(); + nullEmbeddings.add(null); + + assertThatThrownBy(() -> this.vectorStore.add(docs, nullEmbeddings)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("First embedding is null or empty."); + + List newDocs = Arrays.asList(Document.builder().id("1").text("first").build(), + Document.builder().id("2").text("second").build()); + List invalidEmbeddings = new ArrayList<>(); + invalidEmbeddings.add(new float[] {0.1f, 0.2f, 0.3f}); + invalidEmbeddings.add(null); + + assertThatThrownBy(() -> this.vectorStore.add(newDocs, invalidEmbeddings)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Embedding at index 1 is null."); + + List invalidEmbeddingsDimensions = List.of(new float[] {0.1f, 0.2f, 0.3f}, new float[] {0.1f, 0.2f}); + + assertThatThrownBy(() -> this.vectorStore.add(newDocs, invalidEmbeddingsDimensions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Embedding at index 1 has dimension 2, expected 3."); + } + @Test void shouldHandleEmptyDocumentList() { assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList())) diff --git a/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreWithFilterTests.java b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreWithFilterTests.java index 27e7ac6079c..c3346125ce1 100644 --- a/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreWithFilterTests.java +++ b/spring-ai-vector-store/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreWithFilterTests.java @@ -57,7 +57,8 @@ void setUp() { this.mockEmbeddingModel = mock(EmbeddingModel.class); when(this.mockEmbeddingModel.dimensions()).thenReturn(3); when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); - when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f }); + when(this.mockEmbeddingModel.embed(any(), any(), any())) + .thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.1f, 0.2f, 0.3f })); this.vectorStore = SimpleVectorStore.builder(this.mockEmbeddingModel).build(); } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java index ca2a817001a..701c26d8436 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/cosmosdb/CosmosDBVectorStore.java @@ -60,7 +60,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; import org.springframework.ai.vectorstore.SearchRequest; @@ -226,11 +225,7 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) { } @Override - public void doAdd(List documents) { - - // Batch the documents based on the batching strategy - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + public void doAdd(List documents, List embeddings) { // Create a list to hold both the CosmosItemOperation and the corresponding // document ID diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index 0f86bd10c9f..a11da756949 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -48,7 +48,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -151,16 +150,13 @@ public static Builder builder(SearchIndexClient searchIndexClient, EmbeddingMode } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { Assert.notNull(documents, "The document list should not be null."); if (CollectionUtils.isEmpty(documents)) { return; // nothing to do; } - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - final var searchDocuments = documents.stream().map(document -> { SearchDocument searchDocument = new SearchDocument(); searchDocument.put(ID_FIELD_NAME, document.getId()); diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java index 46bb76d0330..f42003a90d0 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java @@ -68,7 +68,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -267,12 +266,9 @@ private static Float[] toFloatArray(float[] embedding) { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { var futures = new CompletableFuture[documents.size()]; - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - int i = 0; for (Document d : documents) { futures[i++] = CompletableFuture.runAsync(() -> { diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java index b56b99673a5..4abb4311b59 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStore.java @@ -34,7 +34,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.util.JacksonUtils; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -147,7 +146,7 @@ public void afterPropertiesSet() throws Exception { } @Override - public void doAdd(@NonNull List documents) { + public void doAdd(@NonNull List documents, List documentEmbeddings) { Assert.notNull(documents, "Documents must not be null"); if (CollectionUtils.isEmpty(documents)) { return; @@ -158,9 +157,6 @@ public void doAdd(@NonNull List documents) { List contents = new ArrayList<>(); List embeddings = new ArrayList<>(); - List documentEmbeddings = this.embeddingModel.embed(documents, - EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - for (Document document : documents) { ids.add(document.getId()); metadatas.add(document.getMetadata()); diff --git a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java index ef73947bbb0..dbcf1055460 100644 --- a/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java +++ b/vector-stores/spring-ai-coherence-store/src/main/java/org/springframework/ai/vectorstore/coherence/CoherenceVectorStore.java @@ -166,12 +166,12 @@ public static Builder builder(Session session, EmbeddingModel embeddingModel) { } @Override - public void doAdd(final List documents) { + public void doAdd(final List documents, List embeddings) { Map chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f)); for (Document doc : documents) { var id = toChunkId(doc.getId()); var chunk = new DocumentChunk(doc.getText(), doc.getMetadata(), - toFloat32Vector(this.embeddingModel.embed(doc))); + toFloat32Vector(embeddings.get(documents.indexOf(doc)))); chunks.put(id, chunk); } this.documentChunks.putAll(chunks); diff --git a/vector-stores/spring-ai-couchbase-store/src/main/java/org/springframework/ai/vectorstore/CouchbaseSearchVectorStore.java b/vector-stores/spring-ai-couchbase-store/src/main/java/org/springframework/ai/vectorstore/CouchbaseSearchVectorStore.java index ad0ebc95a1d..6d44c44a217 100644 --- a/vector-stores/spring-ai-couchbase-store/src/main/java/org/springframework/ai/vectorstore/CouchbaseSearchVectorStore.java +++ b/vector-stores/spring-ai-couchbase-store/src/main/java/org/springframework/ai/vectorstore/CouchbaseSearchVectorStore.java @@ -42,7 +42,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; @@ -133,12 +132,10 @@ public void afterPropertiesSet() { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { logger.info("Trying Add"); logger.info(this.bucketName); logger.info(this.scopeName); - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); for (Document document : documents) { CouchbaseDocument cbDoc = new CouchbaseDocument(document.getId(), document.getText(), document.getMetadata(), embeddings.get(documents.indexOf(document))); diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java index 06f0ba38a2d..a55e349213a 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStore.java @@ -42,7 +42,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -177,7 +176,7 @@ protected ElasticsearchVectorStore(Builder builder) { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { // For the index to be present, either it must be pre-created or set the // initializeSchema to true. if (!indexExists()) { @@ -185,9 +184,6 @@ public void doAdd(List documents) { } BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - for (int i = 0; i < embeddings.size(); i++) { Document document = documents.get(i); float[] embedding = embeddings.get(i); diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/gemfire/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/gemfire/GemFireVectorStore.java index 6d90043fde4..43cc96536d8 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/gemfire/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/gemfire/GemFireVectorStore.java @@ -32,7 +32,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.util.JacksonUtils; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -200,9 +199,7 @@ public String getIndex() { } @Override - public void doAdd(List documents) { - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + public void doAdd(List documents, List embeddings) { UploadRequest upload = new UploadRequest(documents.stream() .map(document -> new UploadRequest.Embedding(document.getId(), embeddings.get(documents.indexOf(document)), DOCUMENT_FIELD, document.getText(), document.getMetadata())) diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStore.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStore.java index 3efdd517ff4..3c2d4edac7a 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStore.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStore.java @@ -112,13 +112,13 @@ public static Builder builder(HanaVectorRepository r } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { int count = 1; for (Document document : documents) { logger.info("[{}/{}] Calling EmbeddingModel for document id = {}", count++, documents.size(), document.getId()); String content = document.getText().replaceAll("\\s+", " "); - String embedding = getEmbedding(document); + String embedding = getEmbedding(embeddings.get(documents.indexOf(document))); this.repository.save(this.tableName, document.getId(), embedding, content); } logger.info("Embeddings saved in HanaCloudVectorStore for {} documents", count - 1); @@ -171,8 +171,8 @@ private String getEmbedding(SearchRequest searchRequest) { .collect(Collectors.joining(", ")) + "]"; } - private String getEmbedding(Document document) { - return "[" + EmbeddingUtils.toList(this.embeddingModel.embed(document)) + private String getEmbedding(float[] embedding) { + return "[" + EmbeddingUtils.toList(embedding) .stream() .map(String::valueOf) .collect(Collectors.joining(", ")) + "]"; diff --git a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java index 9223bec60d6..664b1781fc3 100644 --- a/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java +++ b/vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java @@ -33,7 +33,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.util.JacksonUtils; @@ -249,11 +248,7 @@ public MariaDBDistanceType getDistanceType() { } @Override - public void doAdd(List documents) { - // Batch the documents based on the batching strategy - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - + public void doAdd(List documents, List embeddings) { List> batchedDocuments = batchDocuments(documents, embeddings); batchedDocuments.forEach(this::insertOrUpdateBatch); } diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java index 952c6c98a32..3e25d7a887f 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreTests.java @@ -30,11 +30,8 @@ import org.springframework.jdbc.core.JdbcTemplate; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.only; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -74,12 +71,10 @@ void shouldAddDocumentsInBatchesAndEmbedOnce() { // Testing with 9989 documents var documents = Collections.nCopies(9989, new Document("foo")); + var embeddings = Collections.nCopies(9989, new float[] { 0.1f, 0.2f, 0.3f }); // When - mariadbVectorStore.doAdd(documents); - - // Then - verify(embeddingModel, only()).embed(eq(documents), any(), any()); + mariadbVectorStore.doAdd(documents, embeddings); var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class); verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture()); diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java index 0b8a938f430..badf713573a 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java @@ -58,7 +58,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; @@ -231,7 +230,7 @@ public static Builder builder(MilvusServiceClient milvusServiceClient, Embedding } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { Assert.notNull(documents, "Documents must not be null"); @@ -241,8 +240,6 @@ public void doAdd(List documents) { List> embeddingArray = new ArrayList<>(); // TODO: Need to customize how we pass the embedding options - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); for (Document document : documents) { docIdArray.add(document.getId()); diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java index c6b515aa60f..0b3d820debc 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStore.java @@ -30,7 +30,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -257,9 +256,7 @@ private Document mapMongoDocument(org.bson.Document mongoDocument, float[] query } @Override - public void doAdd(List documents) { - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + public void doAdd(List documents, List embeddings) { for (Document document : documents) { MongoDBDocument mdbDocument = new MongoDBDocument(document.getId(), document.getText(), document.getMetadata(), embeddings.get(documents.indexOf(document))); diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java index 7d44aa2a9ad..96dfb6f0922 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java @@ -31,7 +31,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -203,10 +202,7 @@ protected Neo4jVectorStore(Builder builder) { } @Override - public void doAdd(List documents) { - - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + public void doAdd(List documents, List embeddings) { var rows = documents.stream() .map(document -> documentToRecord(document, embeddings.get(documents.indexOf(document)))) diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java index 13a3f59cf6c..b93cb98cc52 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStore.java @@ -44,7 +44,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -203,13 +202,11 @@ public OpenSearchVectorStore withSimilarityFunction(String similarityFunction) { } @Override - public void doAdd(List documents) { - List embedding = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + public void doAdd(List documents, List embeddings) { BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); for (Document document : documents) { OpenSearchDocument openSearchDocument = new OpenSearchDocument(document.getId(), document.getText(), - document.getMetadata(), embedding.get(documents.indexOf(document))); + document.getMetadata(), embeddings.get(documents.indexOf(document))); bulkRequestBuilder.operations(op -> op .index(idx -> idx.index(this.index).id(openSearchDocument.id()).document(openSearchDocument))); } diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/oracle/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/oracle/OracleVectorStore.java index 99332f311f0..e6a4e19e7f3 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/oracle/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/oracle/OracleVectorStore.java @@ -40,7 +40,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -167,9 +166,7 @@ public static Builder builder(JdbcTemplate jdbcTemplate, EmbeddingModel embeddin } @Override - public void doAdd(final List documents) { - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + public void doAdd(final List documents, List embeddings) { this.jdbcTemplate.batchUpdate(getIngestStatement(), new BatchPreparedStatementSetter() { @Override diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java index cc69b06ab45..94e36f5ea0e 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java @@ -36,7 +36,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.util.JacksonUtils; @@ -251,10 +250,7 @@ public static PgVectorStoreBuilder builder(JdbcTemplate jdbcTemplate, EmbeddingM } @Override - public void doAdd(List documents) { - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - + public void doAdd(List documents, List embeddings) { List> batchedDocuments = batchDocuments(documents); batchedDocuments.forEach(batchDocument -> insertOrUpdateBatch(batchDocument, documents, embeddings)); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreTests.java index cfe63b81e5a..51ed33d4f5c 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreTests.java @@ -29,11 +29,8 @@ import org.springframework.jdbc.core.JdbcTemplate; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.only; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -83,12 +80,10 @@ void shouldAddDocumentsInBatchesAndEmbedOnce() { // Testing with 9989 documents var documents = Collections.nCopies(9989, new Document("foo")); + var embeddings = Collections.nCopies(9989, new float[] { 0.1f, 0.2f, 0.3f }); // When - pgVectorStore.doAdd(documents); - - // Then - verify(embeddingModel, only()).embed(eq(documents), any(), any()); + pgVectorStore.doAdd(documents, embeddings); var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class); verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture()); diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/pinecone/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/pinecone/PineconeVectorStore.java index 0f21192686e..1572e0c3a5a 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/pinecone/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/pinecone/PineconeVectorStore.java @@ -37,7 +37,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -132,29 +131,18 @@ public static Builder.BuilderWithApiKey builder(EmbeddingModel embeddingModel) { } /** - * Adds a list of documents to the vector store based on the namespace. + * Adds a list of documents to the vector store. * @param documents The list of documents to be added. - * @param namespace The namespace to add the documents to */ - public void add(List documents, String namespace) { - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); + @Override + public void doAdd(List documents, List embeddings) { List upsertVectors = new ArrayList<>(); for (Document document : documents) { upsertVectors.add(io.pinecone.commons.IndexInterface.buildUpsertVectorWithUnsignedIndices(document.getId(), EmbeddingUtils.toList(embeddings.get(documents.indexOf(document))), null, null, metadataToStruct(document))); } - this.pinecone.getIndexConnection(this.pineconeIndexName).upsert(upsertVectors, namespace); - } - - /** - * Adds a list of documents to the vector store. - * @param documents The list of documents to be added. - */ - @Override - public void doAdd(List documents) { - add(documents, this.pineconeNamespace); + this.pinecone.getIndexConnection(this.pineconeIndexName).upsert(upsertVectors, this.pineconeNamespace); } /** diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index e06da8a4f9d..af3f09667c4 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -37,7 +37,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -172,13 +171,8 @@ public static Builder builder(QdrantClient qdrantClient, EmbeddingModel embeddin * @param documents The list of documents to be added. */ @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { try { - - // Compute and assign an embedding to the document. - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - List points = documents.stream() .map(document -> PointStruct.newBuilder() .setId(io.qdrant.client.PointIdFactory.id(UUID.fromString(document.getId()))) diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java index 67d033fb2cf..79cd150c944 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/RedisVectorStore.java @@ -48,7 +48,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -250,12 +249,9 @@ public JedisPooled getJedis() { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { try (Pipeline pipeline = this.jedis.pipelined()) { - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - for (Document document : documents) { var fields = new HashMap(); fields.put(this.embeddingFieldName, embeddings.get(documents.indexOf(document))); diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java index 21cad303c11..0ff623e0cb3 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/typesense/TypesenseVectorStore.java @@ -40,7 +40,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -138,12 +137,9 @@ public static Builder builder(Client client, EmbeddingModel embeddingModel) { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { Assert.notNull(documents, "Documents must not be null"); - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - List> documentList = documents.stream().map(document -> { HashMap typesenseDoc = new HashMap<>(); typesenseDoc.put(DOC_ID_FIELD_NAME, document.getId()); diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java index 6628d2eff52..e36d9676545 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/weaviate/WeaviateVectorStore.java @@ -49,7 +49,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.document.DocumentMetadata; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.model.EmbeddingUtils; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder; @@ -196,15 +195,12 @@ private Field[] buildWeaviateSimilaritySearchFields() { } @Override - public void doAdd(List documents) { + public void doAdd(List documents, List embeddings) { if (CollectionUtils.isEmpty(documents)) { return; } - List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), - this.batchingStrategy); - List weaviateObjects = documents.stream() .map(document -> toWeaviateObject(document, documents, embeddings)) .toList();