[BUG] Type errors in embedding function #1169 #1517

Merged
48 changes: 25 additions & 23 deletions chromadb/utils/embedding_functions.py
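The diff replaces scattered "# type: ignore" suppressions with explicit typing.cast(...) calls, so the checker is told the intended Embeddings return type instead of being silenced. A minimal sketch of the pattern (using a simplified stand-in for chromadb's Embeddings alias, not the library's actual definitions):

    from typing import List, cast

    Embeddings = List[List[float]]  # simplified stand-in for chromadb's Embeddings type

    def encode_untyped(texts: List[str]):
        # Stands in for calls like model.encode(...).tolist(), whose result mypy sees as Any.
        return [[0.0, 0.0, 0.0] for _ in texts]

    def embed(texts: List[str]) -> Embeddings:
        # cast() is a no-op at runtime; it only informs the type checker,
        # unlike "# type: ignore", which hides the line from checking entirely.
        return cast(Embeddings, encode_untyped(texts))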
@@ -36,7 +36,7 @@ class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
models: Dict[str, Any] = {}

# If you have a beefier machine, try "gtr-t5-large".
- # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
+ # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
Contributor review comment: nit: we added an extra space here, can we remove?

def __init__(
self,
model_name: str = "all-MiniLM-L6-v2",
@@ -55,11 +55,11 @@ def __init__(
self._normalize_embeddings = normalize_embeddings

def __call__(self, input: Documents) -> Embeddings:
- return self._model.encode( # type: ignore
+ return cast(Embeddings, self._model.encode(
list(input),
convert_to_numpy=True,
normalize_embeddings=self._normalize_embeddings,
- ).tolist()
+ ).tolist())
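For reference, a typical call into this embedding function looks roughly like the following (a sketch only, assuming sentence-transformers is installed and the default all-MiniLM-L6-v2 model can be downloaded):

    from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

    # Hypothetical usage; the return value is a plain list of float lists (Embeddings).
    ef = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
    vectors = ef(["a first document", "a second document"])
    assert len(vectors) == 2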


class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -73,7 +73,7 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
self._model = SentenceModel(model_name_or_path=model_name)

def __call__(self, input: Documents) -> Embeddings:
- return self._model.encode(list(input), convert_to_numpy=True).tolist() # type: ignore # noqa E501
+ return cast(Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist()) # noqa E501


class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -168,10 +168,11 @@ def __call__(self, input: Documents) -> Embeddings:
# Sort resulting embeddings by index
sorted_embeddings = sorted(
embeddings, key=lambda e: e.index
- ) # type: ignore
+ )

# Return just the embeddings
- return [result.embedding for result in sorted_embeddings]
+ emb = [result.embedding for result in sorted_embeddings]
+ return cast(Embeddings, emb)
DevMadhav13 marked this conversation as resolved.
else:
if self._api_type == "azure":
embeddings = self._client.create(
@@ -185,10 +186,11 @@ def __call__(self, input: Documents) -> Embeddings:
# Sort resulting embeddings by index
sorted_embeddings = sorted(
embeddings, key=lambda e: e["index"]
- ) # type: ignore
+ )

# Return just the embeddings
- return [result["embedding"] for result in sorted_embeddings]
+ emb1 = [result["embedding"] for result in sorted_embeddings]
+ return cast(Embeddings, emb1)
DevMadhav13 marked this conversation as resolved.
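The sort-then-extract step above guards against the API returning embedding objects out of input order: results are re-sorted by their index field before the raw vectors are pulled out. A self-contained illustration of the same pattern with made-up response data (not the real OpenAI client):

    # Hypothetical response items, deliberately out of order.
    response_items = [
        {"index": 1, "embedding": [0.4, 0.5]},
        {"index": 0, "embedding": [0.1, 0.2]},
    ]
    sorted_items = sorted(response_items, key=lambda e: e["index"])
    vectors = [item["embedding"] for item in sorted_items]
    assert vectors == [[0.1, 0.2], [0.4, 0.5]]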


class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -249,9 +251,9 @@ def __call__(self, input: Documents) -> Embeddings:
>>> embeddings = hugging_face(texts)
"""
# Call HuggingFace Embedding API for each document
- return self._session.post( # type: ignore
+ return cast(Embeddings, self._session.post(
self._api_url, json={"inputs": input, "options": {"wait_for_model": True}}
- ).json()
+ ).json())


class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -291,7 +293,7 @@ def __call__(self, input: Documents) -> Embeddings:
>>> embeddings = jina_ai_fn(input)
"""
# Call Jina AI Embedding API
- resp = self._session.post( # type: ignore
+ resp = self._session.post(
self._api_url, json={"input": input, "model": self._model_name}
).json()
if "data" not in resp:
@@ -300,10 +302,11 @@ def __call__(self, input: Documents) -> Embeddings:
embeddings = resp["data"]

# Sort resulting embeddings by index
- sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
+ sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

# Return just the embeddings
- return [result["embedding"] for result in sorted_embeddings]
+ embading = [result["embedding"] for result in sorted_embeddings]
Contributor review comment: nit: embedding, not embading
+ return cast(Embeddings, embading)


class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -326,11 +329,11 @@ def __init__(

def __call__(self, input: Documents) -> Embeddings:
if self._instruction is None:
- return self._model.encode(input).tolist() # type: ignore
+ return cast(Embeddings, self._model.encode(input).tolist())

texts_with_instructions = [[self._instruction, text] for text in input]
- # type: ignore
- return self._model.encode(texts_with_instructions).tolist()
+
+ return cast(Embeddings, self._model.encode(texts_with_instructions).tolist())
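The list-of-pairs construction above follows the INSTRUCTOR convention of encoding [instruction, text] pairs; a rough usage sketch (assuming the InstructorEmbedding package and the hkunlp/instructor-base weights are available locally):

    from InstructorEmbedding import INSTRUCTOR

    model = INSTRUCTOR("hkunlp/instructor-base")
    pairs = [["Represent the document for retrieval:", "chroma stores embeddings"]]
    vectors = model.encode(pairs).tolist()  # one vector per [instruction, text] pair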


# In order to remove dependencies on sentence-transformers, which in turn depends on
@@ -405,16 +408,15 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:

# Use pytorches default epsilon for division by zero
# https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
- def _normalize(self, v: npt.NDArray) -> npt.NDArray: # type: ignore
+ def _normalize(self, v: npt.NDArray) -> npt.NDArray:
norm = np.linalg.norm(v, axis=1)
norm[norm == 0] = 1e-12
- return v / norm[:, np.newaxis] # type: ignore
+ return cast(npt.NDArray, v / norm[:, np.newaxis])
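The _normalize helper performs row-wise L2 normalization and substitutes a tiny epsilon for zero norms so an all-zero row does not trigger a division by zero; a standalone sketch of the same computation:

    import numpy as np

    v = np.array([[3.0, 4.0], [0.0, 0.0]])
    norm = np.linalg.norm(v, axis=1)      # per-row L2 norms: [5.0, 0.0]
    norm[norm == 0] = 1e-12               # mirror torch's default eps to avoid division by zero
    normalized = v / norm[:, np.newaxis]  # first row becomes [0.6, 0.8]; the zero row stays (near) zero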

- # type: ignore
def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
# We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values
- self.tokenizer = cast(self.Tokenizer, self.tokenizer) # type: ignore
- self.model = cast(self.ort.InferenceSession, self.model) # type: ignore
+ self.tokenizer = cast(self.Tokenizer, self.tokenizer)
+ self.model = cast(self.ort.InferenceSession, self.model)
all_embeddings = []
for i in range(0, len(documents), batch_size):
batch = documents[i : i + batch_size]
@@ -719,9 +721,9 @@ def __call__(self, input: Documents) -> Embeddings:
>>> embeddings = hugging_face(texts)
"""
# Call HuggingFace Embedding Server API for each document
- return self._session.post( # type: ignore
+ return cast (Embeddings,self._session.post(
self._api_url, json={"inputs": input}
- ).json()
+ ).json())


# List of all classes in this module