Add Claude 3 model #2432

Merged
merged 4 commits on Mar 5, 2024
Changes from 2 commits
4 changes: 2 additions & 2 deletions setup.cfg
@@ -41,7 +41,7 @@ install_requires=

# Basic Scenarios
datasets~=2.15
-pyarrow>=11.0.0,  # Pinned transitive dependency for datasets; workaround for #1026
+pyarrow>=11.0.0  # Pinned transitive dependency for datasets; workaround for #1026
pyarrow-hotfix~=0.6 # Hotfix for CVE-2023-47248

# Basic metrics
@@ -118,7 +118,7 @@ amazon =
botocore~=1.31.57

anthropic =
-anthropic~=0.2.5
+anthropic~=0.17
websocket-client~=1.3.2 # For legacy stanford-online-all-v4-s3

mistral =
124 changes: 117 additions & 7 deletions src/helm/clients/anthropic_client.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict, cast
import json
import requests
import time
@@ -20,16 +20,28 @@
TokenizationRequest,
TokenizationRequestResult,
)
from helm.proxy.retry import NonRetriableException
from helm.tokenizers.tokenizer import Tokenizer
-from .client import CachingClient, truncate_sequence
+from helm.clients.client import CachingClient, truncate_sequence, truncate_and_tokenize_response_text

try:
import anthropic
from anthropic import Anthropic
from anthropic.types import MessageParam
import websocket
except ModuleNotFoundError as e:
handle_module_not_found_error(e, ["anthropic"])


class AnthropicCompletionRequest(TypedDict):
prompt: str
stop_sequences: List[str]
model: str
max_tokens_to_sample: int
temperature: float
top_p: float
top_k: int
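
For context (not part of the diff): a minimal sketch of how this TypedDict feeds the legacy Completions API in the new `anthropic~=0.17` SDK, which is what `_send_request` below now calls. The API key, model name, and prompt are placeholders; `-1` disables top-p/top-k sampling in this API.

```python
# Minimal sketch, assuming anthropic~=0.17 is installed.
from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic

client = Anthropic(api_key="sk-...")  # placeholder key
raw_request: AnthropicCompletionRequest = {
    "prompt": f"{HUMAN_PROMPT} What is HELM?{AI_PROMPT}",
    "stop_sequences": [HUMAN_PROMPT],
    "model": "claude-2.1",  # placeholder Completions-API model
    "max_tokens_to_sample": 128,
    "temperature": 0.0,
    "top_p": -1,  # -1 = disabled
    "top_k": -1,  # -1 = disabled
}
result = client.completions.create(**raw_request).model_dump()
print(result["completion"])
```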


class AnthropicClient(CachingClient):
"""
Client for the Anthropic models (https://arxiv.org/abs/2204.05862).
@@ -63,12 +75,12 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer_name
self.api_key: Optional[str] = api_key
-self._client = anthropic.Client(api_key) if api_key else None
+self.client = Anthropic(api_key=api_key)

-def _send_request(self, raw_request: Dict[str, Any]) -> Dict[str, Any]:
+def _send_request(self, raw_request: AnthropicCompletionRequest) -> Dict[str, Any]:
if self.api_key is None:
raise Exception("API key is not set. Please set it in the HELM config file.")
-result = self._client.completion(**raw_request)
+result = self.client.completions.create(**raw_request).model_dump()
assert "error" not in result, f"Request failed with error: {result['error']}"
return result

@@ -103,7 +115,7 @@ def make_request(self, request: Request) -> RequestResult:
if request.max_tokens == 0 and not request.echo_prompt:
raise ValueError("echo_prompt must be True when max_tokens=0.")

-raw_request = {
+raw_request: AnthropicCompletionRequest = {
"prompt": request.prompt,
"stop_sequences": request.stop_sequences,
"model": request.model_engine,
@@ -190,6 +202,104 @@ def do_it():
)


class AnthropicMessagesRequest(TypedDict, total=False):
messages: List[MessageParam]
model: str
stop_sequences: List[str]
system: str
max_tokens: int
temperature: float
top_k: int
top_p: float


class AnthropicMessagesRequestError(NonRetriableException):
pass


class AnthropicMessagesResponseError(Exception):
pass


class AnthropicMessagesClient(CachingClient):
# Source: https://docs.anthropic.com/claude/docs/models-overview
MAX_OUTPUT_TOKENS = 4096

def __init__(
self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
):
super().__init__(cache_config=cache_config)
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer_name
self.client = Anthropic(api_key=api_key)
self.api_key: Optional[str] = api_key

def make_request(self, request: Request) -> RequestResult:

if request.max_tokens > AnthropicMessagesClient.MAX_OUTPUT_TOKENS:
raise AnthropicMessagesRequestError(
f"Request.max_tokens must be <= {AnthropicMessagesClient.MAX_OUTPUT_TOKENS}"
)

messages: List[MessageParam] = []
system_message: Optional[MessageParam] = None
if request.messages and request.prompt:
raise AnthropicMessagesRequestError("Exactly one of Request.messages and Request.prompt should be set")
if request.messages:
messages = cast(List[MessageParam], request.messages)
if messages[0]["role"] == "system":
system_message = messages.pop(0)
Contributor:
Can we clone to avoid mutating request?

Collaborator (Author):
Good catch... changed to `messages = messages[1:]` instead, which does not mutate.
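
For illustration, a minimal standalone sketch of the two approaches (hypothetical messages):

```python
# Illustration only, with hypothetical messages.
messages = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
]

# Mutating: pop(0) removes the element from the caller's list in place.
# system_message = messages.pop(0)

# Non-mutating form now used: indexing + slicing leaves the original list intact.
system_message = messages[0]
messages = messages[1:]
```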

else:
messages = [{"role": "user", "content": request.prompt}]

raw_request: AnthropicMessagesRequest = {
"messages": messages,
"model": request.model_engine,
"stop_sequences": request.stop_sequences,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k_per_token,
}
if system_message is not None:
raw_request["system"] = cast(str, system_message["content"])
completions: List[Sequence] = []

# `num_completions` is not supported, so instead make `num_completions` separate requests.
for completion_index in range(request.num_completions):

def do_it() -> Dict[str, Any]:
result = self.client.messages.create(**raw_request).model_dump()
if "content" not in result or not result["content"]:
raise AnthropicMessagesResponseError(f"Anthropic response has empty content: {result}")
elif "text" not in result["content"][0]:
raise AnthropicMessagesResponseError(f"Anthropic response has non-text content: {result}")
return result

cache_key = CachingClient.make_cache_key(
{
"completion_index": completion_index,
**raw_request,
},
request,
)

response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
completion = truncate_and_tokenize_response_text(
response["content"][0]["text"], request, self.tokenizer, self.tokenizer_name, original_finish_reason=""
)
completions.append(completion)

return RequestResult(
success=True,
cached=cached,
request_time=response["request_time"],
request_datetime=response["request_datetime"],
completions=completions,
embedding=[],
)
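
For context (not part of the diff): a minimal sketch of the underlying Messages API call that this `make_request` wraps, assuming `anthropic~=0.17`; the API key and message content are placeholders.

```python
from anthropic import Anthropic

client = Anthropic(api_key="sk-...")  # placeholder key

result = client.messages.create(
    model="claude-3-opus-20240229",
    system="You are a terse assistant.",  # optional; sent separately from messages
    messages=[{"role": "user", "content": "Name one HELM scenario."}],
    max_tokens=256,  # required by the Messages API; capped at 4096 by this client
    temperature=0.0,
).model_dump()  # responses are pydantic models; model_dump() yields a plain dict

print(result["content"][0]["text"])  # same parsing as do_it() above
```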


class AnthropicRequestError(Exception):
pass

14 changes: 14 additions & 0 deletions src/helm/config/model_deployments.yaml
@@ -195,6 +195,20 @@ model_deployments:
client_spec:
class_name: "helm.clients.anthropic_client.AnthropicClient"

- name: anthropic/claude-3-sonnet-20240229
model_name: anthropic/claude-3-sonnet-20240229
tokenizer_name: anthropic/claude
max_sequence_length: 200000
client_spec:
class_name: "helm.clients.anthropic_client.AnthropicMessagesClient"

- name: anthropic/claude-3-opus-20240229
model_name: anthropic/claude-3-opus-20240229
tokenizer_name: anthropic/claude
max_sequence_length: 200000
client_spec:
class_name: "helm.clients.anthropic_client.AnthropicMessagesClient"
Contributor:
What about haiku?

Collaborator (Author):
No Haiku because "Haiku will be available soon." (source)


- name: anthropic/stanford-online-all-v4-s3
deprecated: true # Closed model, not accessible via API
model_name: anthropic/stanford-online-all-v4-s3
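As an aside, the `class_name` strings in these deployment entries are resolved dynamically at run time. A hypothetical sketch of that kind of resolution (not HELM's actual factory code):

```python
import importlib

def resolve_class(class_name: str) -> type:
    """Hypothetical helper: map a dotted class_name string to a class object."""
    module_name, _, attr = class_name.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

client_cls = resolve_class("helm.clients.anthropic_client.AnthropicMessagesClient")
```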
16 changes: 16 additions & 0 deletions src/helm/config/model_metadata.yaml
@@ -229,6 +229,22 @@ models:
release_date: 2023-11-21
tags: [ANTHROPIC_CLAUDE_2_MODEL_TAG, TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, ABLATION_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

- name: anthropic/claude-3-sonnet-20240229
display_name: Claude 3 Sonnet (20240229)
description: TBD
Contributor:
Do we have an issue to get these filled? We need them before we actually deploy...

Collaborator (Author):
Filed #2435

creator_organization_name: Anthropic
access: limited
release_date: 2024-03-04
tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

- name: anthropic/claude-3-opus-20240229
display_name: Claude 3 Opus (20240229)
description: TBD
creator_organization_name: Anthropic
access: limited
release_date: 2024-03-04
tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

# DEPRECATED: Please do not use.
- name: anthropic/stanford-online-all-v4-s3
display_name: Anthropic-LM v4-s3 (52B)
2 changes: 1 addition & 1 deletion src/helm/tokenizers/anthropic_tokenizer.py
@@ -23,7 +23,7 @@ def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
with AnthropicTokenizer.LOCK:
self._tokenizer: PreTrainedTokenizerBase = PreTrainedTokenizerFast(
-tokenizer_object=anthropic.get_tokenizer()
+tokenizer_object=anthropic.Anthropic().get_tokenizer()
)

def _tokenize_do_it(self, request: Dict[str, Any]) -> Dict[str, Any]:
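
For context (not part of the diff): in `anthropic~=0.17` the tokenizer moved from a module-level helper to a method on the client instance, which is all this one-line change tracks. A minimal sketch, assuming `anthropic~=0.17` and `transformers` are installed:

```python
import anthropic
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=anthropic.Anthropic().get_tokenizer()  # instance method in ~0.17
)
print(tokenizer.tokenize("Claude 3 arrives"))
```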