From 5afe05e194b1f5be00a418fd0c98b7090d7ca72b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Galego?= Date: Wed, 24 Jan 2024 12:08:05 +0000 Subject: [PATCH 1/5] Refactored AmazonBedrockEmbeddingFunction to add support for Cohere Embed models --- chromadb/utils/embedding_functions.py | 31 ++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index ec5fc05e3ee..8bb3895be44 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -751,9 +751,29 @@ def __init__( def __call__(self, input: Documents) -> Embeddings: accept = "application/json" content_type = "application/json" - embeddings = [] - for text in input: - input_body = {"inputText": text} + provider = self._model_name.split('.')[0] + if provider == "amazon": + embeddings = [] + for text in input: + input_body = { + "inputText": text + } + body = json.dumps(input_body) + response = self._client.invoke_model( + body=body, + modelId=self._model_name, + accept=accept, + contentType=content_type, + ) + embedding = json.load(response.get("body")).get("embedding") + embeddings.append(embedding) + elif provider == "cohere": + # See Amazon Bedrock User Guide > Cohere Embed models for more information + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html + input_body = { + "texts": input, + "input_type": "search_document" + } body = json.dumps(input_body) response = self._client.invoke_model( body=body, @@ -761,8 +781,9 @@ def __call__(self, input: Documents) -> Embeddings: accept=accept, contentType=content_type, ) - embedding = json.load(response.get("body")).get("embedding") - embeddings.append(embedding) + embeddings = json.load(response.get("body")).get("embeddings") + else: + raise NotImplemented return embeddings From 826550d10a2bd06e90007b477f4982e98353158e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Galego?= Date: Thu, 22 Feb 2024 17:12:52 +0000 Subject: [PATCH 2/5] Added fast-fail based on model provider --- chromadb/utils/embedding_functions.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 8bb3895be44..b939b789c61 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -742,6 +742,10 @@ def __init__( """ self._model_name = model_name + self._model_provider = self._model_name.split('.')[0] + + if self._model_provider not in ["amazon", "cohere"]: + raise ValueError(f"Model {self._model_name} is not supported. You can find the full list of supported foundation models in Amazon Bedrock at https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html") self._client = session.client( service_name="bedrock-runtime", @@ -751,8 +755,7 @@ def __init__( def __call__(self, input: Documents) -> Embeddings: accept = "application/json" content_type = "application/json" - provider = self._model_name.split('.')[0] - if provider == "amazon": + if self._model_provider == "amazon": embeddings = [] for text in input: input_body = { @@ -767,7 +770,7 @@ def __call__(self, input: Documents) -> Embeddings: ) embedding = json.load(response.get("body")).get("embedding") embeddings.append(embedding) - elif provider == "cohere": + elif self._model_provider == "cohere": # See Amazon Bedrock User Guide > Cohere Embed models for more information # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html input_body = { @@ -782,8 +785,6 @@ def __call__(self, input: Documents) -> Embeddings: contentType=content_type, ) embeddings = json.load(response.get("body")).get("embeddings") - else: - raise NotImplemented return embeddings From f743449a568f1b5823a2fd57f84248854034b5e7 Mon Sep 17 00:00:00 2001 From: JGalego Date: Thu, 18 Apr 2024 18:34:31 +0100 Subject: [PATCH 3/5] Added missing checks; Refactored to avoid code duplication --- chromadb/utils/embedding_functions.py | 43 +++++++++++++++------------ 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index b939b789c61..4fe5e3278ae 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -744,47 +744,52 @@ def __init__( self._model_name = model_name self._model_provider = self._model_name.split('.')[0] - if self._model_provider not in ["amazon", "cohere"]: - raise ValueError(f"Model {self._model_name} is not supported. You can find the full list of supported foundation models in Amazon Bedrock at https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html") + self._bedrock_client = session.client( + service_name="bedrock", + **kwargs + ) + + self._model_details = self._bedrock_client.get_foundation_model(modelIdentifier=self._model_name)['modelDetails'] + assert "EMBEDDING" in self._model_details['outputModalities'], f"{self._model_name} doesn't have embedding modality output!" - self._client = session.client( + self._bedrock_runtime_client = session.client( service_name="bedrock-runtime", **kwargs, ) + + def call_model(self, body) -> dict: + body = json.dumps(body) + response = self._bedrock_runtime_client.invoke_model( + body=body, + modelId=self._model_name, + accept="application/json", + contentType="application/json", + ) + return response def __call__(self, input: Documents) -> Embeddings: - accept = "application/json" - content_type = "application/json" if self._model_provider == "amazon": embeddings = [] for text in input: input_body = { "inputText": text } - body = json.dumps(input_body) - response = self._client.invoke_model( - body=body, - modelId=self._model_name, - accept=accept, - contentType=content_type, - ) + response = self.call_model(input_body) embedding = json.load(response.get("body")).get("embedding") embeddings.append(embedding) elif self._model_provider == "cohere": # See Amazon Bedrock User Guide > Cohere Embed models for more information # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html + assert len(input) <= 128, f"Input texts exceeds max size (Got: {len(input)}, Expected: <=128)" + assert all(len(text) <= 2048 for text in input), f"Input contains texts exceeding max length (2048)" input_body = { "texts": input, "input_type": "search_document" } - body = json.dumps(input_body) - response = self._client.invoke_model( - body=body, - modelId=self._model_name, - accept=accept, - contentType=content_type, - ) + response = self.call_model(input_body) embeddings = json.load(response.get("body")).get("embeddings") + else: + raise NotImplementedError(f"Model {self._model_name} is not supported!") return embeddings From 26c2e161d26de98bec010829e38f2ab81a35c69d Mon Sep 17 00:00:00 2001 From: JGalego Date: Thu, 18 Apr 2024 22:04:49 +0100 Subject: [PATCH 4/5] Replaced asserts with if-statements --- chromadb/utils/embedding_functions.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 4fe5e3278ae..77b2f17ebd6 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -750,7 +750,8 @@ def __init__( ) self._model_details = self._bedrock_client.get_foundation_model(modelIdentifier=self._model_name)['modelDetails'] - assert "EMBEDDING" in self._model_details['outputModalities'], f"{self._model_name} doesn't have embedding modality output!" + if "EMBEDDING" not in self._model_details['outputModalities']: + raise ValueError(f"{self._model_name} doesn't have embedding modality output!") self._bedrock_runtime_client = session.client( service_name="bedrock-runtime", @@ -780,8 +781,10 @@ def __call__(self, input: Documents) -> Embeddings: elif self._model_provider == "cohere": # See Amazon Bedrock User Guide > Cohere Embed models for more information # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html - assert len(input) <= 128, f"Input texts exceeds max size (Got: {len(input)}, Expected: <=128)" - assert all(len(text) <= 2048 for text in input), f"Input contains texts exceeding max length (2048)" + if len(input) > 128: + raise ValueError(f"Input texts exceeds max size (Got: {len(input)}, Expected: <=128)") + if not all(len(text) <= 2048 for text in input): + raise ValueError(f"Input contains texts exceeding max length (2048)") input_body = { "texts": input, "input_type": "search_document" From 50dfe53df67fdd6a10ba45d567bc3e6502b0c7f1 Mon Sep 17 00:00:00 2001 From: JGalego Date: Thu, 18 Apr 2024 22:33:32 +0100 Subject: [PATCH 5/5] Added model_params arg to BedrockEF and input_type+truncate for cohere models; Added input_type arg to CohereEF --- chromadb/utils/embedding_functions.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 77b2f17ebd6..118e5b78e1f 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -214,7 +214,11 @@ def __call__(self, input: Documents) -> Embeddings: class CohereEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__(self, api_key: str, model_name: str = "large"): + def __init__( + self, + api_key: str, + model_name: str = "large", + input_type: str = "search_document"): try: import cohere except ImportError: @@ -224,13 +228,14 @@ def __init__(self, api_key: str, model_name: str = "large"): self._client = cohere.Client(api_key) self._model_name = model_name + self._input_type = input_type def __call__(self, input: Documents) -> Embeddings: # Call Cohere Embedding API for each document. return [ embeddings for embeddings in self._client.embed( - texts=input, model=self._model_name, input_type="search_document" + texts=input, model=self._model_name, input_type=self._input_type ) ] @@ -724,6 +729,7 @@ def __init__( self, session: "boto3.Session", # noqa: F821 # Quote for forward reference model_name: str = "amazon.titan-embed-text-v1", + model_params: dict = {}, **kwargs: Any, ): """Initialize AmazonBedrockEmbeddingFunction. @@ -744,6 +750,10 @@ def __init__( self._model_name = model_name self._model_provider = self._model_name.split('.')[0] + if self._model_provider == "cohere": + self._input_type = model_params.get('input_type', "search_document") + self._truncate = model_params.get('truncate', "NONE") + self._bedrock_client = session.client( service_name="bedrock", **kwargs @@ -787,7 +797,8 @@ def __call__(self, input: Documents) -> Embeddings: raise ValueError(f"Input contains texts exceeding max length (2048)") input_body = { "texts": input, - "input_type": "search_document" + "input_type": self._input_type, + "truncate": self._truncate } response = self.call_model(input_body) embeddings = json.load(response.get("body")).get("embeddings")