Skip to content

Commit

Permalink
Implement async support in Qdrant local mode (#8001)
Browse files Browse the repository at this point in the history
I've extended the support of async API to local Qdrant mode. It is faked
but allows prototyping without spinning a container. The tests are
improved to test the in-memory case as well.

@baskaryan @rlancemartin @eyurtsev @agola11
  • Loading branch information
kacperlukawski committed Jul 21, 2023
1 parent 7717c24 commit ed6a553
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
" \"rating\": 9.9,\n",
" \"director\": \"Andrei Tarkovsky\",\n",
" \"genre\": \"science fiction\",\n",
" \"rating\": 9.9,\n",
" },\n",
" ),\n",
"]\n",
Expand Down
120 changes: 92 additions & 28 deletions langchain/vectorstores/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Wrapper around Qdrant vector database."""
from __future__ import annotations

import asyncio
import functools
import uuid
import warnings
from itertools import islice
Expand Down Expand Up @@ -40,6 +42,30 @@ class QdrantException(Exception):
"""Base class for all the Qdrant related exceptions"""


def sync_call_fallback(method: Callable) -> Callable:
"""
Decorator to call the synchronous method of the class if the async method is not
implemented. This decorator might be only used for the methods that are defined
as async in the class.
"""

@functools.wraps(method)
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
return await method(self, *args, **kwargs)
except NotImplementedError:
# If the async method is not implemented, call the synchronous method
# by removing the first letter from the method name. For example,
# if the async method is called ``aaad_texts``, the synchronous method
# will be called ``aad_texts``.
sync_method = functools.partial(
getattr(self, method.__name__[1:]), *args, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, sync_method)

return wrapper


class Qdrant(VectorStore):
"""Wrapper around Qdrant vector database.
Expand Down Expand Up @@ -155,6 +181,7 @@ def add_texts(

return added_ids

@sync_call_fallback
async def aadd_texts(
self,
texts: Iterable[str],
Expand Down Expand Up @@ -250,6 +277,7 @@ def similarity_search(
)
return list(map(itemgetter(0), results))

@sync_call_fallback
async def asimilarity_search(
self,
query: str,
Expand Down Expand Up @@ -322,6 +350,7 @@ def similarity_search_with_score(
**kwargs,
)

@sync_call_fallback
async def asimilarity_search_with_score(
self,
query: str,
Expand Down Expand Up @@ -431,6 +460,7 @@ def similarity_search_by_vector(
)
return list(map(itemgetter(0), results))

@sync_call_fallback
async def asimilarity_search_by_vector(
self,
embedding: List[float],
Expand Down Expand Up @@ -567,6 +597,7 @@ def similarity_search_with_score_by_vector(
for result in results
]

@sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
embedding: List[float],
Expand Down Expand Up @@ -685,6 +716,7 @@ def max_marginal_relevance_search(
query_embedding, k, fetch_k, lambda_mult, **kwargs
)

@sync_call_fallback
async def amax_marginal_relevance_search(
self,
query: str,
Expand Down Expand Up @@ -739,33 +771,12 @@ def max_marginal_relevance_search_by_vector(
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]

results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
with_payload=True,
with_vectors=True,
limit=fetch_k,
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
results = self.max_marginal_relevance_search_with_score_by_vector(
embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
)
return [
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
)
for i in mmr_selected
]
return list(map(itemgetter(0), results))

@sync_call_fallback
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
Expand Down Expand Up @@ -795,6 +806,61 @@ async def amax_marginal_relevance_search_by_vector(
)
return list(map(itemgetter(0), results))

def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance and distance for
each.
"""
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]

results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
with_payload=True,
with_vectors=True,
limit=fetch_k,
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
(
self._document_from_scored_point(
results[i], self.content_payload_key, self.metadata_payload_key
),
results[i].score,
)
for i in mmr_selected
]

@sync_call_fallback
async def amax_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
Expand Down Expand Up @@ -1038,7 +1104,6 @@ def from_texts(
content_payload_key,
metadata_payload_key,
vector_name,
batch_size,
shard_number,
replication_factor,
write_consistency_factor,
Expand All @@ -1055,6 +1120,7 @@ def from_texts(
return qdrant

@classmethod
@sync_call_fallback
async def afrom_texts(
cls: Type[Qdrant],
texts: List[str],
Expand Down Expand Up @@ -1214,7 +1280,6 @@ async def afrom_texts(
content_payload_key,
metadata_payload_key,
vector_name,
batch_size,
shard_number,
replication_factor,
write_consistency_factor,
Expand Down Expand Up @@ -1253,7 +1318,6 @@ def _construct_instance(
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
Expand Down
13 changes: 13 additions & 0 deletions tests/integration_tests/vectorstores/qdrant/async_api/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import logging
from typing import List

from tests.integration_tests.vectorstores.qdrant.common import qdrant_is_not_running

logger = logging.getLogger(__name__)


def qdrant_locations() -> List[str]:
if qdrant_is_not_running():
logger.warning("Running Qdrant async tests in memory mode only.")
return [":memory:"]
return ["http://localhost:6333", ":memory:"]
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)

from .common import qdrant_is_not_running

# Skipping all the tests in the module if Qdrant is not running on localhost.
pytestmark = pytest.mark.skipif(
qdrant_is_not_running(), reason="Qdrant server is not running"
from tests.integration_tests.vectorstores.qdrant.async_api.fixtures import ( # noqa
qdrant_locations,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_returns_all_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts returns unique ids."""
docsearch: Qdrant = Qdrant.from_texts(
["foobar"],
ConsistentFakeEmbeddings(),
batch_size=batch_size,
location=qdrant_location,
)

ids = await docsearch.aadd_texts(["foo", "bar", "baz"])
Expand All @@ -33,14 +33,15 @@ async def test_qdrant_aadd_texts_returns_all_ids(batch_size: int) -> None:

@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", [None, "my-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_duplicated_texts(
vector_name: Optional[str],
vector_name: Optional[str], qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores duplicated texts separately."""
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient()
client = QdrantClient(location=qdrant_location)
collection_name = "test"
vectors_config = rest.VectorParams(size=10, distance=rest.Distance.COSINE)
if vector_name is not None:
Expand All @@ -61,7 +62,10 @@ async def test_qdrant_aadd_texts_stores_duplicated_texts(

@pytest.mark.asyncio
@pytest.mark.parametrize("batch_size", [1, 64])
async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_ids(
batch_size: int, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores provided ids."""
from qdrant_client import QdrantClient

Expand All @@ -70,7 +74,7 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:
"cdc1aa36-d6ab-4fb2-8a94-56674fd27484",
]

client = QdrantClient()
client = QdrantClient(location=qdrant_location)
collection_name = "test"
client.recreate_collection(
collection_name,
Expand All @@ -90,15 +94,16 @@ async def test_qdrant_aadd_texts_stores_ids(batch_size: int) -> None:

@pytest.mark.asyncio
@pytest.mark.parametrize("vector_name", ["custom-vector"])
@pytest.mark.parametrize("qdrant_location", qdrant_locations())
async def test_qdrant_aadd_texts_stores_embeddings_as_named_vectors(
vector_name: str,
vector_name: str, qdrant_location: str
) -> None:
"""Test end to end Qdrant.aadd_texts stores named vectors if name is provided."""
from qdrant_client import QdrantClient

collection_name = "test"

client = QdrantClient()
client = QdrantClient(location=qdrant_location)
client.recreate_collection(
collection_name,
vectors_config={
Expand Down
Loading

0 comments on commit ed6a553

Please sign in to comment.