diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index ec5fc05e3ee..118e5b78e1f 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -214,7 +214,11 @@ def __call__(self, input: Documents) -> Embeddings: class CohereEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__(self, api_key: str, model_name: str = "large"): + def __init__( + self, + api_key: str, + model_name: str = "large", + input_type: str = "search_document"): try: import cohere except ImportError: @@ -224,13 +228,14 @@ def __init__(self, api_key: str, model_name: str = "large"): self._client = cohere.Client(api_key) self._model_name = model_name + self._input_type = input_type def __call__(self, input: Documents) -> Embeddings: # Call Cohere Embedding API for each document. return [ embeddings for embeddings in self._client.embed( - texts=input, model=self._model_name, input_type="search_document" + texts=input, model=self._model_name, input_type=self._input_type ) ] @@ -724,6 +729,7 @@ def __init__( self, session: "boto3.Session", # noqa: F821 # Quote for forward reference model_name: str = "amazon.titan-embed-text-v1", + model_params: dict = {}, **kwargs: Any, ): """Initialize AmazonBedrockEmbeddingFunction. 
@@ -742,27 +748,62 @@ def __init__( """ self._model_name = model_name + self._model_provider = self._model_name.split('.')[0] + + if self._model_provider == "cohere": + self._input_type = model_params.get('input_type', "search_document") + self._truncate = model_params.get('truncate', "NONE") + + self._bedrock_client = session.client( + service_name="bedrock", + **kwargs + ) + + self._model_details = self._bedrock_client.get_foundation_model(modelIdentifier=self._model_name)['modelDetails'] + if "EMBEDDING" not in self._model_details['outputModalities']: + raise ValueError(f"{self._model_name} doesn't have embedding modality output!") - self._client = session.client( + self._bedrock_runtime_client = session.client( service_name="bedrock-runtime", **kwargs, ) + + def call_model(self, body) -> dict: + body = json.dumps(body) + response = self._bedrock_runtime_client.invoke_model( + body=body, + modelId=self._model_name, + accept="application/json", + contentType="application/json", + ) + return response def __call__(self, input: Documents) -> Embeddings: - accept = "application/json" - content_type = "application/json" - embeddings = [] - for text in input: - input_body = {"inputText": text} - body = json.dumps(input_body) - response = self._client.invoke_model( - body=body, - modelId=self._model_name, - accept=accept, - contentType=content_type, - ) - embedding = json.load(response.get("body")).get("embedding") - embeddings.append(embedding) + if self._model_provider == "amazon": + embeddings = [] + for text in input: + input_body = { + "inputText": text + } + response = self.call_model(input_body) + embedding = json.load(response.get("body")).get("embedding") + embeddings.append(embedding) + elif self._model_provider == "cohere": + # See Amazon Bedrock User Guide > Cohere Embed models for more information + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html + if len(input) > 128: + raise ValueError(f"Input texts exceeds max size (Got: {len(input)}, Expected: <=128)") + if not all(len(text) <= 2048 for text in input): + raise ValueError(f"Input contains texts exceeding max length (2048)") + input_body = { + "texts": input, + "input_type": self._input_type, + "truncate": self._truncate + } + response = self.call_model(input_body) + embeddings = json.load(response.get("body")).get("embeddings") + else: + raise NotImplementedError(f"Model {self._model_name} is not supported!") return embeddings