[BUG]Type errors in embading function #1169 #1517

Merged: beggers merged 9 commits into chroma-core:main from DevMadhav13:type_errors_in_embading_function_1169 on Jan 11, 2024. The diff below shows the changes from 4 of the 9 commits.

Commits:
- a9e2ead (DevMadhav13): Fixed type errors in chromadb/utils/embedding_functions.py #1169
- 0aa707f (DevMadhav13): Merge branch 'type_errors_in_embading_function_1169' of https://githu…
- 8b74c54 (DevMadhav13): Merge branch 'chroma-core:main' into type_errors_in_embading_function…
- fe14101 (DevMadhav13): Code formating - unindenting
- e5163f2 (DevMadhav13): Spell correction
- 4f427e1 (DevMadhav13): Update chromadb/utils/embedding_functions.py
- b403179 (DevMadhav13): Update chromadb/utils/embedding_functions.py
- 09a450b (DevMadhav13): Update chromadb/utils/embedding_functions.py
- 7938dba (beggers): Merge branch 'main' into type_errors_in_embading_function_1169
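For readers skimming the diff: the recurring change throughout is the replacement of blanket `# type: ignore` comments with explicit `typing.cast(...)` calls, which tell the type checker the intended type instead of silencing it for the whole line. A minimal standalone sketch of the pattern (illustrative only, not Chroma's code; the `Embeddings` alias and `untyped_encode` helper here are simplified stand-ins):

```python
from typing import Any, List, cast

# Simplified stand-in for chromadb's Embeddings type alias.
Embeddings = List[List[float]]

def untyped_encode(texts: List[str]) -> Any:
    # Stand-in for a third-party call whose return type is Any.
    return [[0.0, 0.0, 0.0] for _ in texts]

def embed(texts: List[str]) -> Embeddings:
    # Instead of `return untyped_encode(texts)  # type: ignore`,
    # assert the expected type to the checker explicitly.
    # cast() performs no runtime check or conversion.
    return cast(Embeddings, untyped_encode(texts))

print(embed(["hello", "world"]))  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

The advantage over `# type: ignore` is that the cast documents the expected type and leaves the rest of the line type-checked.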
Diff: chromadb/utils/embedding_functions.py
```diff
@@ -36,7 +36,7 @@ class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
     models: Dict[str, Any] = {}

     # If you have a beefier machine, try "gtr-t5-large".
-    # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
+    # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
     def __init__(
         self,
         model_name: str = "all-MiniLM-L6-v2",
```

(The removed and added comment lines are textually identical here; the change is presumably whitespace-only.)
```diff
@@ -55,11 +55,11 @@ def __init__(
         self._normalize_embeddings = normalize_embeddings

     def __call__(self, input: Documents) -> Embeddings:
-        return self._model.encode(  # type: ignore
+        return cast(Embeddings, self._model.encode(
             list(input),
             convert_to_numpy=True,
             normalize_embeddings=self._normalize_embeddings,
-        ).tolist()
+        ).tolist())


 class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]):
```
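Why a cast is needed at all here: in numpy's type stubs, `ndarray.tolist()` is annotated as returning `Any`, so returning it from a function annotated `-> Embeddings` gives mypy nothing to verify. A hedged sketch of the same situation without the sentence-transformers dependency (`encode_batch` is a stand-in for `SentenceTransformer.encode(..., convert_to_numpy=True)`):

```python
from typing import List, cast
import numpy as np

Embeddings = List[List[float]]  # simplified stand-in for the real alias

def encode_batch(texts: List[str]) -> np.ndarray:
    # Stand-in for a model call that returns a (n_texts, dim) array.
    return np.zeros((len(texts), 4), dtype=np.float32)

def embed(texts: List[str]) -> Embeddings:
    # ndarray.tolist() is typed as Any in numpy's stubs, hence the cast.
    return cast(Embeddings, encode_batch(texts).tolist())

print(embed(["a", "b"]))  # [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```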
```diff
@@ -73,7 +73,7 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
         self._model = SentenceModel(model_name_or_path=model_name)

     def __call__(self, input: Documents) -> Embeddings:
-        return self._model.encode(list(input), convert_to_numpy=True).tolist()  # type: ignore # noqa E501
+        return cast(Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist())  # noqa E501


 class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]):
```
```diff
@@ -168,10 +168,11 @@ def __call__(self, input: Documents) -> Embeddings:
             # Sort resulting embeddings by index
             sorted_embeddings = sorted(
                 embeddings, key=lambda e: e.index
-            )  # type: ignore
+            )

             # Return just the embeddings
-            return [result.embedding for result in sorted_embeddings]
+            emb = [result.embedding for result in sorted_embeddings]
+            return cast(Embeddings, emb)
         else:
             if self._api_type == "azure":
                 embeddings = self._client.create(
```
```diff
@@ -185,10 +186,11 @@ def __call__(self, input: Documents) -> Embeddings:
             # Sort resulting embeddings by index
             sorted_embeddings = sorted(
                 embeddings, key=lambda e: e["index"]
-            )  # type: ignore
+            )

             # Return just the embeddings
-            return [result["embedding"] for result in sorted_embeddings]
+            emb1 = [result["embedding"] for result in sorted_embeddings]
+            return cast(Embeddings, emb1)


 class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
```
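One property of these changes worth keeping in mind when reviewing: `typing.cast` returns its second argument unchanged at runtime, so the new intermediate variables (`emb`, `emb1`) add no overhead and exist only to keep the cast readable. A quick standalone demonstration mirroring the sort-then-extract shape of the two hunks above:

```python
from typing import List, cast

# Fake API rows in the shape the OpenAI branch works with.
rows = [{"index": 1, "embedding": [0.2]}, {"index": 0, "embedding": [0.1]}]

# Sort by index, then pull out just the embeddings.
sorted_rows = sorted(rows, key=lambda e: e["index"])
emb = [row["embedding"] for row in sorted_rows]

result = cast(List[List[float]], emb)
assert result is emb  # cast() is an identity function at runtime
print(result)  # [[0.1], [0.2]]
```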
```diff
@@ -249,9 +251,9 @@ def __call__(self, input: Documents) -> Embeddings:
         >>> embeddings = hugging_face(texts)
         """
         # Call HuggingFace Embedding API for each document
-        return self._session.post(  # type: ignore
+        return cast(Embeddings, self._session.post(
             self._api_url, json={"inputs": input, "options": {"wait_for_model": True}}
-        ).json()
+        ).json())


 class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
```
```diff
@@ -291,7 +293,7 @@ def __call__(self, input: Documents) -> Embeddings:
         >>> embeddings = jina_ai_fn(input)
         """
         # Call Jina AI Embedding API
-        resp = self._session.post(  # type: ignore
+        resp = self._session.post(
             self._api_url, json={"input": input, "model": self._model_name}
         ).json()
         if "data" not in resp:
```
```diff
@@ -300,10 +302,11 @@ def __call__(self, input: Documents) -> Embeddings:
         embeddings = resp["data"]

         # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])

         # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        embading = [result["embedding"] for result in sorted_embeddings]
+        return cast(Embeddings, embading)


 class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
```

Review comment on the new `embading` variable: nit: embedding, not embading.
```diff
@@ -326,11 +329,11 @@ def __init__(

     def __call__(self, input: Documents) -> Embeddings:
         if self._instruction is None:
-            return self._model.encode(input).tolist()  # type: ignore
+            return cast(Embeddings, self._model.encode(input).tolist())

         texts_with_instructions = [[self._instruction, text] for text in input]
-        # type: ignore
-        return self._model.encode(texts_with_instructions).tolist()
+
+        return cast(Embeddings, self._model.encode(texts_with_instructions).tolist())


 # In order to remove dependencies on sentence-transformers, which in turn depends on
```
```diff
@@ -405,16 +408,15 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:

     # Use pytorches default epsilon for division by zero
     # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
-    def _normalize(self, v: npt.NDArray) -> npt.NDArray:  # type: ignore
+    def _normalize(self, v: npt.NDArray) -> npt.NDArray:
         norm = np.linalg.norm(v, axis=1)
         norm[norm == 0] = 1e-12
-        return v / norm[:, np.newaxis]  # type: ignore
+        return cast(npt.NDArray, v / norm[:, np.newaxis])

-    # type: ignore
     def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray:
         # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values
-        self.tokenizer = cast(self.Tokenizer, self.tokenizer)  # type: ignore
-        self.model = cast(self.ort.InferenceSession, self.model)  # type: ignore
+        self.tokenizer = cast(self.Tokenizer, self.tokenizer)
+        self.model = cast(self.ort.InferenceSession, self.model)
         all_embeddings = []
         for i in range(0, len(documents), batch_size):
             batch = documents[i : i + batch_size]
```
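The `_normalize` hunk is the one place where the cast target is an array type rather than `Embeddings`: dividing an `NDArray` by a broadcast slice of its norms comes back loosely typed from numpy's stubs, so the result is cast back to `npt.NDArray`. A self-contained sketch of the same row-wise L2 normalization:

```python
from typing import cast
import numpy as np
import numpy.typing as npt

def normalize(v: npt.NDArray) -> npt.NDArray:
    # Row-wise L2 norms; zeros are replaced with a tiny epsilon to
    # avoid division by zero (matching torch.nn.functional.normalize).
    norm = np.linalg.norm(v, axis=1)
    norm[norm == 0] = 1e-12
    return cast(npt.NDArray, v / norm[:, np.newaxis])

x = np.array([[3.0, 4.0], [0.0, 0.0]])
print(normalize(x))  # first row becomes [0.6, 0.8]; the zero row stays zero
```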
```diff
@@ -719,9 +721,9 @@ def __call__(self, input: Documents) -> Embeddings:
         >>> embeddings = hugging_face(texts)
         """
         # Call HuggingFace Embedding Server API for each document
-        return self._session.post(  # type: ignore
+        return cast (Embeddings,self._session.post(
             self._api_url, json={"inputs": input}
-        ).json()
+        ).json())


 # List of all classes in this module
```
Review comment: nit: we added an extra space here, can we remove?
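The extra space appears to refer to `return cast (Embeddings,self._session.post(` in the last hunk. A space between `cast` and `(` is legal Python, but pycodestyle/flake8 flags it as E211 (whitespace before '('), so the fix is purely cosmetic:

```python
from typing import List, cast

Embeddings = List[List[float]]
data = [[1.0, 2.0]]

emb = cast (Embeddings, data)  # runs fine, but flake8 reports E211
emb = cast(Embeddings, data)   # preferred spelling
```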