feat: supported Amazon Titan Text Embeddings V2 #112

Merged · 3 commits · Jun 14, 2024
2 changes: 1 addition & 1 deletion .env.example
@@ -7,5 +7,5 @@ DIAL_URL=<dial core url>

# Misc env vars for the server
LOG_LEVEL=INFO # Default in prod is INFO. Use DEBUG for dev.
WEB_CONCURRENCY=1 # Number of unicorn workers
WEB_CONCURRENCY=1 # Number of uvicorn workers
TEST_SERVER_URL=http://0.0.0.0:5001
6 changes: 6 additions & 0 deletions README.md
@@ -29,6 +29,12 @@ The models that support `/truncate_prompt` do also support `max_prompt_tokens` r

Certain models do not support precise tokenization because their tokenization algorithm is not known. Instead, an approximate tokenization algorithm is used: it conservatively counts every byte of a string's UTF-8 encoding as a single token.

The following models support `SERVER_URL/openai/deployments/DEPLOYMENT_NAME/embeddings` endpoint:

|Model|Deployment name|Modality|
|---|---|---|
|Amazon Titan Text Embeddings V2|amazon.titan-embed-text-v2:0|text-to-embedding|
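
For example, a minimal request might look like this (a sketch only: `SERVER_URL` and any auth headers depend on your deployment and are not defined by this change; the port placeholder reuses `TEST_SERVER_URL` from `.env.example`):

```python
import requests

SERVER_URL = "http://0.0.0.0:5001"  # placeholder

response = requests.post(
    f"{SERVER_URL}/openai/deployments/amazon.titan-embed-text-v2:0/embeddings",
    json={"input": ["fish", "ball"], "dimensions": 256},
)
response.raise_for_status()
vectors = [item["embedding"] for item in response.json()["data"]]
print(len(vectors), len(vectors[0]))  # 2 vectors, 256 dimensions each
```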

## Developer environment

This project uses [Python>=3.11](https://www.python.org/downloads/) and [Poetry>=1.6.1](https://python-poetry.org/) as a dependency manager.
50 changes: 46 additions & 4 deletions aidial_adapter_bedrock/app.py
@@ -1,11 +1,28 @@
import json
from typing import Optional

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import Body, Header, Path

from aidial_adapter_bedrock.chat_completion import BedrockChatCompletion
from aidial_adapter_bedrock.deployments import BedrockDeployment
from aidial_adapter_bedrock.dial_api.response import ModelObject, ModelsResponse
from aidial_adapter_bedrock.deployments import (
ChatCompletionDeployment,
EmbeddingsDeployment,
)
from aidial_adapter_bedrock.dial_api.request import (
EmbeddingsRequest,
EmbeddingsType,
)
from aidial_adapter_bedrock.dial_api.response import (
ModelObject,
ModelsResponse,
make_embeddings_response,
)
from aidial_adapter_bedrock.llm.model.adapter import get_embeddings_model
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.utils.env import get_aws_default_region
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.log_config import configure_loggers

AWS_DEFAULT_REGION = get_aws_default_region()
@@ -28,13 +45,38 @@ async def models():
return ModelsResponse(
data=[
ModelObject(id=deployment.deployment_id)
for deployment in BedrockDeployment
for deployment in ChatCompletionDeployment
]
)


for deployment in BedrockDeployment:
for deployment in ChatCompletionDeployment:
app.add_chat_completion(
deployment.deployment_id,
BedrockChatCompletion(region=AWS_DEFAULT_REGION),
)


@app.post("/openai/deployments/{deployment}/embeddings")
@dial_exception_decorator
async def embeddings(
embeddings_type: EmbeddingsType = Header(
alias="X-DIAL-Type", default=EmbeddingsType.SYMMETRIC
),
embeddings_instruction: Optional[str] = Header(
alias="X-DIAL-Instruction", default=None
),
deployment: EmbeddingsDeployment = Path(...),
query: dict = Body(..., example=EmbeddingsRequest.example()),
):
log.debug(f"query: {json.dumps(query)}")

model = await get_embeddings_model(
deployment=deployment, region=AWS_DEFAULT_REGION
)

response = await model.embeddings(
query, embeddings_instruction, embeddings_type
)

return make_embeddings_response(deployment, response)
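
The new route also honors two optional DIAL-specific headers: `X-DIAL-Type` (one of the `EmbeddingsType` values, defaulting to `symmetric`) and `X-DIAL-Instruction`. A hedged client-side sketch (`SERVER_URL` is a placeholder; Titan V2 itself accepts only symmetric embeddings, as the adapter further below enforces):

```python
import requests

SERVER_URL = "http://0.0.0.0:5001"  # placeholder

# "document" or "query" would be rejected by the Titan adapter's
# validate_parameters, which only allows EmbeddingsType.SYMMETRIC.
response = requests.post(
    f"{SERVER_URL}/openai/deployments/amazon.titan-embed-text-v2:0/embeddings",
    headers={"X-DIAL-Type": "symmetric"},
    json={"input": "fish"},
)
print(response.json()["usage"])
```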
4 changes: 2 additions & 2 deletions aidial_adapter_bedrock/bedrock.py
@@ -45,7 +45,7 @@ async def ainvoke_non_streaming(self, model: str, args: dict) -> dict:
body: StreamingBody = response["body"]
body_dict = json.loads(await make_async(lambda: body.read()))

log.debug(f"response['body']: {body_dict}")
log.debug(f"response['body']: {json.dumps(body_dict)}")

return body_dict

@@ -65,7 +65,7 @@ async def ainvoke_streaming(
chunk = event.get("chunk")
if chunk:
chunk_dict = json.loads(chunk.get("bytes").decode())
log.debug(f"chunk: {chunk_dict}")
log.debug(f"chunk: {json.dumps(chunk_dict)}")
yield chunk_dict


6 changes: 4 additions & 2 deletions aidial_adapter_bedrock/chat_completion.py
@@ -23,7 +23,7 @@
)
from typing_extensions import override

from aidial_adapter_bedrock.deployments import BedrockDeployment
from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
@@ -44,7 +44,9 @@ def __init__(self, region: str):
async def _get_model(
self, request: FromRequestDeploymentMixin
) -> ChatCompletionAdapter:
deployment = BedrockDeployment.from_deployment_id(request.deployment_id)
deployment = ChatCompletionDeployment.from_deployment_id(
request.deployment_id
)
return await get_bedrock_adapter(
region=self.region,
deployment=deployment,
26 changes: 22 additions & 4 deletions aidial_adapter_bedrock/deployments.py
@@ -1,7 +1,7 @@
from enum import Enum


class BedrockDeployment(str, Enum):
class ChatCompletionDeployment(str, Enum):
AMAZON_TITAN_TG1_LARGE = "amazon.titan-tg1-large"
AI21_J2_GRANDE_INSTRUCT = "ai21.j2-grande-instruct"
AI21_J2_JUMBO_INSTRUCT = "ai21.j2-jumbo-instruct"
@@ -30,11 +30,29 @@ def model_id(self) -> str:
"""Id of the model in the Bedrock service."""

# Redirect Stability model without version to the earliest non-deprecated version (V1)
if self == BedrockDeployment.STABILITY_STABLE_DIFFUSION_XL:
return BedrockDeployment.STABILITY_STABLE_DIFFUSION_XL_V1.model_id
if self == ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL:
return (
ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL_V1.model_id
)

return self.value

@classmethod
def from_deployment_id(cls, deployment_id: str) -> "BedrockDeployment":
def from_deployment_id(
cls, deployment_id: str
) -> "ChatCompletionDeployment":
return cls(deployment_id)


class EmbeddingsDeployment(str, Enum):
AMAZON_TITAN_EMBED_TEXT_2 = "amazon.titan-embed-text-v2:0"

@property
def deployment_id(self) -> str:
"""Deployment id under which the model is served by the adapter."""
return self.value

@property
def model_id(self) -> str:
"""Id of the model in the Bedrock service."""
return self.value
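
A small sketch of how these enums resolve ids, following directly from the definitions above:

```python
from aidial_adapter_bedrock.deployments import (
    ChatCompletionDeployment,
    EmbeddingsDeployment,
)

# The unversioned Stability deployment is redirected to the V1 model id:
sdxl = ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL
assert (
    sdxl.model_id
    == ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL_V1.model_id
)

# For the Titan embeddings model, deployment id and Bedrock model id coincide:
titan = EmbeddingsDeployment("amazon.titan-embed-text-v2:0")
assert titan.deployment_id == titan.model_id == "amazon.titan-embed-text-v2:0"
```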
21 changes: 20 additions & 1 deletion aidial_adapter_bedrock/dial_api/request.py
@@ -1,4 +1,5 @@
from typing import List, Optional
from enum import Enum
from typing import List, Literal, Optional

from aidial_sdk.chat_completion.request import ChatCompletionRequest
from pydantic import BaseModel
@@ -7,6 +8,7 @@
ToolsConfig,
validate_messages,
)
from aidial_adapter_bedrock.utils.pydantic import ExtraAllowModel


class ModelParameters(BaseModel):
@@ -44,3 +46,20 @@ def create(cls, request: ChatCompletionRequest) -> "ModelParameters":

def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
return self.copy(update={"stop": [*self.stop, *stop]})


class EmbeddingsType(str, Enum):
SYMMETRIC = "symmetric"
DOCUMENT = "document"
QUERY = "query"


class EmbeddingsRequest(ExtraAllowModel):
input: str | List[str]
user: Optional[str] = None
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None

@staticmethod
def example() -> "EmbeddingsRequest":
return EmbeddingsRequest(input=["fish", "ball"])
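
Because `EmbeddingsRequest` extends `ExtraAllowModel`, unknown fields are kept rather than rejected; the Titan adapter later forwards them to Bedrock via `get_extra_fields()` (see `amazon_titan_text.py` in this PR). A sketch, assuming `get_extra_fields()` returns exactly the unmodeled fields:

```python
from aidial_adapter_bedrock.dial_api.request import EmbeddingsRequest

request = EmbeddingsRequest.parse_obj(
    {"input": ["fish", "ball"], "dimensions": 256, "normalize": True}
)
assert request.encoding_format == "float"  # default
assert request.get_extra_fields() == {"normalize": True}  # Titan-only flag
```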
43 changes: 42 additions & 1 deletion aidial_adapter_bedrock/dial_api/response.py
@@ -1,7 +1,9 @@
from typing import List, Literal
from typing import List, Literal, Tuple, TypedDict

from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


class ModelObject(BaseModel):
object: Literal["model"] = "model"
@@ -11,3 +13,42 @@ class ModelObject(BaseModel):
class ModelsResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ModelObject]


class EmbeddingsDict(TypedDict):
index: int
object: Literal["embedding"]
embedding: List[float]


class EmbeddingsTokenUsageDict(TypedDict):
prompt_tokens: int
total_tokens: int


class EmbeddingsResponseDict(TypedDict):
object: Literal["list"]
model: str
data: List[EmbeddingsDict]
usage: EmbeddingsTokenUsageDict


def make_embeddings_response(
model_id: str, resp: Tuple[List[List[float]], TokenUsage]
) -> EmbeddingsResponseDict:
vectors, usage = resp

data: List[EmbeddingsDict] = [
{"index": idx, "object": "embedding", "embedding": vec}
for idx, vec in enumerate(vectors)
]

return {
"object": "list",
"model": model_id,
"data": data,
"usage": {
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
},
}
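
A minimal sketch of the resulting payload; the `TokenUsage` mutation mirrors how the Titan adapter accumulates prompt tokens per input:

```python
from aidial_adapter_bedrock.dial_api.response import make_embeddings_response
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage

usage = TokenUsage()
usage.prompt_tokens += 2

response = make_embeddings_response(
    "amazon.titan-embed-text-v2:0", ([[0.1, 0.2], [0.3, 0.4]], usage)
)
assert response["object"] == "list"
assert response["data"][1] == {
    "index": 1,
    "object": "embedding",
    "embedding": [0.3, 0.4],
}
```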
103 changes: 103 additions & 0 deletions aidial_adapter_bedrock/embeddings/amazon_titan_text.py
@@ -0,0 +1,103 @@
from typing import Iterable, List, Literal, Optional, Self, Tuple

from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
from aidial_adapter_bedrock.dial_api.request import (
EmbeddingsRequest,
EmbeddingsType,
)
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.embeddings.embeddings_adapter import (
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.llm.errors import ValidationError
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


def validate_parameters(
encoding_format: Literal["float", "base64"],
embedding_type: EmbeddingsType,
embedding_instruction: Optional[str],
supported_embedding_types: List[EmbeddingsType],
) -> None:
if encoding_format == "base64":
raise ValidationError("Base64 encoding format is not supported")

if embedding_instruction is not None:
raise ValidationError("Instruction prompt is not supported")

assert (
len(supported_embedding_types) != 0
), "The embedding model doesn't support any embedding types"

if embedding_type not in supported_embedding_types:
allowed = ", ".join([e.value for e in supported_embedding_types])
raise ValidationError(
f"Embedding types other than {allowed} are not supported"
)


def create_requests(request: EmbeddingsRequest) -> Iterable[dict]:
inputs: List[str] = (
[request.input] if isinstance(request.input, str) else request.input
)

# This includes all Titan-specific request parameters missing
# from the OpenAI Embeddings request, e.g. "normalize" boolean flag.
extra_body = request.get_extra_fields()

# NOTE: Amazon Titan doesn't support batched inputs
for input in inputs:
yield remove_nones(
{
"inputText": input,
"dimensions": request.dimensions,
**extra_body,
}
)


class AmazonResponse(BaseModel):
inputTextTokenCount: int
embedding: List[float]


class AmazonTitanTextEmbeddings(EmbeddingsAdapter):
model: str
client: Bedrock

@classmethod
def create(cls, client: Bedrock, model: str) -> Self:
return cls(client=client, model=model)

async def embeddings(
self,
request_body: dict,
embedding_instruction: Optional[str],
embedding_type: EmbeddingsType,
) -> Tuple[List[List[float]], TokenUsage]:
request = EmbeddingsRequest.parse_obj(request_body)

validate_parameters(
request.encoding_format,
embedding_type,
embedding_instruction,
[EmbeddingsType.SYMMETRIC],
)

embeddings: List[List[float]] = []
usage = TokenUsage()

for sub_request in create_requests(request):
log.debug(f"request: {sub_request}")

response_dict = await self.client.ainvoke_non_streaming(
self.model, sub_request
)
response = AmazonResponse.parse_obj(response_dict)
embeddings.append(response.embedding)
usage.prompt_tokens += response.inputTextTokenCount

return embeddings, usage
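
Since Titan exposes no batch API, `create_requests` fans a batched OpenAI-style request out into one Bedrock payload per input. A sketch of that behavior (assuming no extra fields are set, so `get_extra_fields()` contributes nothing):

```python
from aidial_adapter_bedrock.dial_api.request import EmbeddingsRequest
from aidial_adapter_bedrock.embeddings.amazon_titan_text import create_requests

payloads = list(
    create_requests(EmbeddingsRequest(input=["fish", "ball"], dimensions=256))
)
assert payloads == [
    {"inputText": "fish", "dimensions": 256},
    {"inputText": "ball", "dimensions": 256},
]
```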
21 changes: 21 additions & 0 deletions aidial_adapter_bedrock/embeddings/embeddings_adapter.py
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

from pydantic import BaseModel

from aidial_adapter_bedrock.dial_api.request import EmbeddingsType
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage


class EmbeddingsAdapter(ABC, BaseModel):
class Config:
arbitrary_types_allowed = True

@abstractmethod
async def embeddings(
self,
request_body: dict,
embedding_instruction: Optional[str],
embedding_type: EmbeddingsType,
) -> Tuple[List[List[float]], TokenUsage]:
pass