Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
jbdlb committed Nov 7, 2023
1 parent 8e05f88 commit 0b0b8fe
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions libs/langchain/langchain/vectorstores/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ def __init__(
f"got {type(client)}"
)

if type(content_payload_key) == str: # Ensuring Backward compatibility
if isinstance(content_payload_key, str): # Ensuring Backward compatibility
content_payload_key = [content_payload_key]

if type(metadata_payload_key) == str: # Ensuring Backward compatibility
if isinstance(metadata_payload_key, str): # Ensuring Backward compatibility
metadata_payload_key = [metadata_payload_key]

if embeddings is None and embedding_function is None:
Expand All @@ -133,8 +133,14 @@ def __init__(
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.content_payload_key = (
content_payload_key if content_payload_key is not None else self.CONTENT_KEY
)
self.metadata_payload_key = (
metadata_payload_key
if metadata_payload_key is not None
else self.METADATA_KEY
)
self.vector_name = vector_name or self.VECTOR_NAME

if embedding_function is not None:
Expand Down Expand Up @@ -1184,8 +1190,8 @@ def from_texts(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
Expand Down Expand Up @@ -1360,8 +1366,8 @@ async def afrom_texts(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
Expand Down Expand Up @@ -1533,8 +1539,8 @@ def construct_instance(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
Expand Down Expand Up @@ -1697,8 +1703,8 @@ async def aconstruct_instance(
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
content_payload_key: List[str] = CONTENT_KEY,
metadata_payload_key: List[str] = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
Expand Down Expand Up @@ -1894,11 +1900,11 @@ def _similarity_search_with_relevance_scores(

@classmethod
def _build_payloads(
cls,
cls: Type[Qdrant],
texts: Iterable[str],
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
Expand All @@ -1919,38 +1925,40 @@ def _build_payloads(

@classmethod
def _document_from_scored_point(
cls,
cls: Type[Qdrant],
scored_point: Any,
content_payload_key: list,
metadata_payload_key: list,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> Document:
payload = scored_point.payload
return Qdrant._document_from_payload(
cls=cls,
payload=payload,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
)

@classmethod
def _document_from_scored_point_grpc(
cls,
cls: Type[Qdrant],
scored_point: Any,
content_payload_key: list,
metadata_payload_key: list,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> Document:
from qdrant_client.conversions.conversion import grpc_to_payload

payload = grpc_to_payload(scored_point.payload)
return Qdrant._document_from_payload(
cls=cls,
payload=payload,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
)

@classmethod
def _document_from_payload(
cls, payload: Any, content_payload_key: list, metadata_payload_key: list
cls: Type[Qdrant],
payload: Any,
content_payload_key: list[str],
metadata_payload_key: list[str],
) -> Document:
if len(content_payload_key) == 1:
content = payload.get(
Expand Down

0 comments on commit 0b0b8fe

Please sign in to comment.