
[BUG] Type errors in embedding function #1169 #1517

Merged
chromadb/utils/embedding_functions.py: 43 changes (23 additions & 20 deletions)
@@ -73,11 +73,11 @@ def __init__(
         self._normalize_embeddings = normalize_embeddings

     def __call__(self, input: Documents) -> Embeddings:
-        return self._model.encode(  # type: ignore
+        return cast(Embeddings, self._model.encode(
             list(input),
             convert_to_numpy=True,
             normalize_embeddings=self._normalize_embeddings,
-        ).tolist()
+        ).tolist())


 class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
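Editorial note on the pattern used throughout this PR: typing.cast is a no-op at runtime that only tells the type checker which type to assume, so it can replace the blanket "# type: ignore" comments without changing behavior. A minimal sketch of the idea (mine, not part of the diff; Embeddings is a simplified stand-in for chromadb's alias):

from typing import List, cast

Embeddings = List[List[float]]  # simplified stand-in for chromadb's alias

def encode_untyped(texts: List[str]):  # no return annotation: checker sees Any
    return [[0.0, 0.0, 0.0] for _ in texts]

def embed(texts: List[str]) -> Embeddings:
    # Same value in, same value out; only the checker's view of the type changes.
    return cast(Embeddings, encode_untyped(texts))

print(embed(["hello"]))  # [[0.0, 0.0, 0.0]]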
@@ -91,7 +91,7 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
         self._model = SentenceModel(model_name_or_path=model_name)

     def __call__(self, input: Documents) -> Embeddings:
-        return self._model.encode(list(input), convert_to_numpy=True).tolist()  # type: ignore # noqa E501
+        return cast(Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist())  # noqa E501


 class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -186,10 +186,10 @@ def __call__(self, input: Documents) -> Embeddings:
             # Sort resulting embeddings by index
             sorted_embeddings = sorted(
                 embeddings, key=lambda e: e.index
-            )  # type: ignore
+            )

             # Return just the embeddings
-            return [result.embedding for result in sorted_embeddings]
+            return cast(Embeddings, [result.embedding for result in sorted_embeddings])
         else:
             if self._api_type == "azure":
                 embeddings = self._client.create(
@@ -203,10 +203,12 @@ def __call__(self, input: Documents) -> Embeddings:
             # Sort resulting embeddings by index
             sorted_embeddings = sorted(
                 embeddings, key=lambda e: e["index"]
-            )  # type: ignore
+            )

             # Return just the embeddings
-            return [result["embedding"] for result in sorted_embeddings]
+            return cast(
+                Embeddings, [result["embedding"] for result in sorted_embeddings]
+            )


 class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
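The sort-then-cast sequence in the two OpenAI hunks above exists because the embeddings endpoint may return items out of order; the "index" field ties each item back to its input position. A toy illustration (hypothetical response items, not real API output):

from typing import List, cast

Embeddings = List[List[float]]  # simplified stand-in

resp = [
    {"index": 1, "embedding": [0.4, 0.5]},
    {"index": 0, "embedding": [0.1, 0.2]},
]
sorted_embeddings = sorted(resp, key=lambda e: e["index"])
result = cast(Embeddings, [e["embedding"] for e in sorted_embeddings])
print(result)  # [[0.1, 0.2], [0.4, 0.5]]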
@@ -267,9 +269,9 @@ def __call__(self, input: Documents) -> Embeddings:
         >>> embeddings = hugging_face(texts)
         """
         # Call HuggingFace Embedding API for each document
-        return self._session.post(  # type: ignore
+        return cast(Embeddings, self._session.post(
             self._api_url, json={"inputs": input, "options": {"wait_for_model": True}}
-        ).json()
+        ).json())


 class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -309,7 +311,7 @@ def __call__(self, input: Documents) -> Embeddings:
         >>> embeddings = jina_ai_fn(input)
         """
         # Call Jina AI Embedding API
-        resp = self._session.post(  # type: ignore
+        resp = self._session.post(
             self._api_url, json={"input": input, "model": self._model_name}
         ).json()
         if "data" not in resp:
@@ -318,10 +320,10 @@ def __call__(self, input: Documents) -> Embeddings:
         embeddings = resp["data"]

         # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

         # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        return cast(Embeddings, [result["embedding"] for result in sorted_embeddings])


 class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
@@ -344,11 +346,11 @@ def __init__(

     def __call__(self, input: Documents) -> Embeddings:
         if self._instruction is None:
-            return self._model.encode(input).tolist()  # type: ignore
+            return cast(Embeddings, self._model.encode(input).tolist())

         texts_with_instructions = [[self._instruction, text] for text in input]
-        # type: ignore
-        return self._model.encode(texts_with_instructions).tolist()
+
+        return cast(Embeddings, self._model.encode(texts_with_instructions).tolist())


 # In order to remove dependencies on sentence-transformers, which in turn depends on
@@ -434,14 +436,15 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:

     # Use pytorches default epsilon for division by zero
     # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
-    def _normalize(self, v: npt.NDArray) -> npt.NDArray:  # type: ignore
+    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
         norm = np.linalg.norm(v, axis=1)
         norm[norm == 0] = 1e-12
-        return v / norm[:, np.newaxis]  # type: ignore
+        return cast(npt.NDArray, v / norm[:, np.newaxis])

-    # type: ignore
     def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
+        # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values
+        self.tokenizer = cast(self.Tokenizer, self.tokenizer)
+        self.model = cast(self.ort.InferenceSession, self.model)
         all_embeddings = []
         for i in range(0, len(documents), batch_size):
             batch = documents[i : i + batch_size]
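Two distinct casts appear in the hunk above. The cast in _normalize is needed because NumPy arithmetic comes back typed as Any, while the casts on self.tokenizer and self.model narrow Optional attributes that an init helper fills in later. A self-contained sketch of both patterns (mine; the class and names are illustrative, not chromadb's):

from typing import Optional, cast

import numpy as np
import numpy.typing as npt

class LateInit:
    def __init__(self) -> None:
        self.model: Optional[str] = None  # populated later by _init_model

    def _init_model(self) -> None:
        self.model = "loaded"

    def forward(self) -> str:
        # The checker cannot prove _init_model already ran, hence the cast.
        self.model = cast(str, self.model)
        return self.model

def normalize(v: npt.NDArray) -> npt.NDArray:
    norm = np.linalg.norm(v, axis=1)
    norm[norm == 0] = 1e-12  # PyTorch's default epsilon, per the comment above
    # The division result is typed Any, so cast it back to an NDArray.
    return cast(npt.NDArray, v / norm[:, np.newaxis])

print(normalize(np.array([[3.0, 4.0]])))  # [[0.6 0.8]]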
@@ -795,9 +798,9 @@ def __call__(self, input: Documents) -> Embeddings:
         >>> embeddings = hugging_face(texts)
         """
         # Call HuggingFace Embedding Server API for each document
-        return self._session.post(  # type: ignore
+        return cast(Embeddings, self._session.post(
             self._api_url, json={"inputs": input}
-        ).json()
+        ).json())


 # List of all classes in this module