Skip to content

Commit

Permalink
anthropic[patch]: fix tool call and tool res image_url handling (#26587)
Browse files Browse the repository at this point in the history
Co-authored-by: ccurme <chester.curme@gmail.com>
  • Loading branch information
baskaryan and ccurme committed Sep 17, 2024
1 parent c6bdd6f commit 5ced41b
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 36 deletions.
72 changes: 44 additions & 28 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,65 +194,81 @@ def _format_messages(

# populate content
content = []
for item in message.content:
if isinstance(item, str):
content.append({"type": "text", "text": item})
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
elif item["type"] == "image_url":
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, dict):
if "type" not in block:
raise ValueError("Dict content block must have a type key")
elif block["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
source = _format_image(block["image_url"]["url"])
content.append({"type": "image", "source": source})
elif item["type"] == "tool_use":
elif block["type"] == "tool_use":
# If a tool_call with the same id as a tool_use content block
# exists, the tool_call is preferred.
if isinstance(message, AIMessage) and item["id"] in [
if isinstance(message, AIMessage) and block["id"] in [
tc["id"] for tc in message.tool_calls
]:
overlapping = [
tc
for tc in message.tool_calls
if tc["id"] == item["id"]
if tc["id"] == block["id"]
]
content.extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(overlapping)
)
else:
item.pop("text", None)
content.append(item)
elif item["type"] == "text":
text = item.get("text", "")
block.pop("text", None)
content.append(block)
elif block["type"] == "text":
text = block.get("text", "")
# Only add non-empty strings for now as empty ones are not
# accepted.
# https://github.com/anthropics/anthropic-sdk-python/issues/461
if text.strip():
content.append(
{
k: v
for k, v in item.items()
for k, v in block.items()
if k in ("type", "text", "cache_control")
}
)
elif block["type"] == "tool_result":
tool_content = _format_messages(
[HumanMessage(block["content"])]
)[1][0]["content"]
content.append({**block, **{"content": tool_content}})
else:
content.append(item)
content.append(block)
else:
raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}"
f"Content blocks must be str or dict, instead was: "
f"{type(block)}"
)
elif isinstance(message, AIMessage) and message.tool_calls:
content = (
[]
if not message.content
else [{"type": "text", "text": message.content}]
)
# Note: Anthropic can't have invalid tool calls as presently defined,
# since the model already returns dicts args not JSON strings, and invalid
# tool calls are those with invalid JSON for args.
content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls)
else:
content = message.content

# Ensure all tool_calls have a tool_use content block
if isinstance(message, AIMessage) and message.tool_calls:
content = content or []
content = (
[{"type": "text", "text": message.content}]
if isinstance(content, str) and content
else content
)
tool_use_ids = [
cast(dict, block)["id"]
for block in content
if cast(dict, block)["type"] == "tool_use"
]
missing_tool_calls = [
tc for tc in message.tool_calls if tc["id"] not in tool_use_ids
]
cast(list, content).extend(
_lc_tool_calls_to_anthropic_tool_use_blocks(missing_tool_calls)
)

formatted_messages.append({"role": role, "content": content})
return system, formatted_messages

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def chat_model_params(self) -> dict:
def supports_image_inputs(self) -> bool:
return True

# Capability flag: this test configuration declares that tool messages
# containing image content are supported (always returns True).
# NOTE(review): presumably read by the shared test_image_tool_message test,
# which returns early when the flag is falsy - confirm against the base class.
@property
def supports_image_tool_message(self) -> bool:
return True

# Capability flag: this test configuration declares that Anthropic-style
# inputs are supported (always returns True).
# NOTE(review): presumably read by the shared test_anthropic_inputs test,
# which returns early when the flag is falsy - confirm against the base class.
@property
def supports_anthropic_inputs(self) -> bool:
return True
87 changes: 80 additions & 7 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,36 @@ def test_convert_to_anthropic_tool(
def test__format_messages_with_tool_calls() -> None:
system = SystemMessage("fuzz") # type: ignore[misc]
human = HumanMessage("foo") # type: ignore[misc]
ai = AIMessage( # type: ignore[misc]
"",
ai = AIMessage(
"", # with empty string
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
)
tool = ToolMessage( # type: ignore[misc]
ai2 = AIMessage(
[], # with empty list
tool_calls=[{"name": "bar", "id": "2", "args": {"baz": "buzz"}}],
)
tool = ToolMessage(
"blurb",
tool_call_id="1",
)
messages = [system, human, ai, tool]
tool_image_url = ToolMessage(
[{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,...."}}],
tool_call_id="2",
)
tool_image = ToolMessage(
[
{
"type": "image",
"source": {
"data": "....",
"type": "base64",
"media_type": "image/jpeg",
},
}
],
tool_call_id="3",
)
messages = [system, human, ai, tool, ai2, tool_image_url, tool_image]
expected = (
"fuzz",
[
Expand All @@ -401,6 +422,52 @@ def test__format_messages_with_tool_calls() -> None:
}
],
},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"name": "bar",
"id": "2",
"input": {"baz": "buzz"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"content": [
{
"type": "image",
"source": {
"data": "....",
"type": "base64",
"media_type": "image/jpeg",
},
}
],
"tool_use_id": "2",
"is_error": False,
},
{
"type": "tool_result",
"content": [
{
"type": "image",
"source": {
"data": "....",
"type": "base64",
"media_type": "image/jpeg",
},
}
],
"tool_use_id": "3",
"is_error": False,
},
],
},
],
)
actual = _format_messages(messages)
Expand Down Expand Up @@ -454,8 +521,6 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
def test__format_messages_with_list_content_and_tool_calls() -> None:
system = SystemMessage("fuzz") # type: ignore[misc]
human = HumanMessage("foo") # type: ignore[misc]
# If content and tool_calls are specified and content is a list, then content is
# preferred.
ai = AIMessage( # type: ignore[misc]
[{"type": "text", "text": "thought"}],
tool_calls=[{"name": "bar", "id": "1", "args": {"baz": "buzz"}}],
Expand All @@ -471,7 +536,15 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
{"role": "user", "content": "foo"},
{
"role": "assistant",
"content": [{"type": "text", "text": "thought"}],
"content": [
{"type": "text", "text": "thought"},
{
"type": "tool_use",
"name": "bar",
"id": "1",
"input": {"baz": "buzz"},
},
],
},
{
"role": "user",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def chat_model_class(self) -> Type[BaseChatModel]:

@property
def chat_model_params(self) -> dict:
return {"model": "gpt-4o", "stream_usage": True}
return {"model": "gpt-4o-mini", "stream_usage": True}

@property
def supports_image_inputs(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,37 @@ def test_image_inputs(self, model: BaseChatModel) -> None:
)
model.invoke([message])

def test_image_tool_message(self, model: BaseChatModel) -> None:
    """Check that the model accepts a tool result whose content is an image.

    Builds a three-message history (a human request, an AI message with a
    ``random_image`` tool call, and a tool message carrying a base64
    ``image_url`` block) and invokes the model with the ``random_image``
    tool bound. The test passes when the invocation does not raise; nothing
    about the model's response is asserted.

    Args:
        model: Chat model under test.

    Raises:
        httpx.HTTPStatusError: If the sample image cannot be downloaded.
    """
    # Return silently, rather than fail, when the capability is not declared.
    if not self.supports_image_tool_message:
        return

    image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    response = httpx.get(image_url)
    # Fail fast on a bad download: without this check an HTTP error page
    # would be base64-encoded and sent to the model labelled as a JPEG,
    # surfacing later as an opaque provider-side error.
    response.raise_for_status()
    image_data = base64.b64encode(response.content).decode("utf-8")

    messages = [
        HumanMessage("get a random image using the tool and describe the weather"),
        AIMessage(
            [],
            tool_calls=[
                {"type": "tool_call", "id": "1", "name": "random_image", "args": {}}
            ],
        ),
        # Tool result for call "1": a single image block in image_url form.
        ToolMessage(
            content=[
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
                },
            ],
            tool_call_id="1",
            name="random_image",
        ),
    ]

    # Stub tool: only its name/signature is needed so it can be bound to the
    # model; this test never calls it.
    def random_image() -> str:
        """Return a random image."""
        return ""

    model.bind_tools([random_image]).invoke(messages)

def test_anthropic_inputs(self, model: BaseChatModel) -> None:
if not self.supports_anthropic_inputs:
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def returns_usage_metadata(self) -> bool:
def supports_anthropic_inputs(self) -> bool:
return False

# Capability flag: returns False here, i.e. image content in tool messages
# is reported as unsupported.
# NOTE(review): looks like the base-class default that provider-specific
# test classes override to opt in - confirm, the enclosing class header is
# not visible in this hunk.
@property
def supports_image_tool_message(self) -> bool:
return False


class ChatModelUnitTests(ChatModelTests):
@property
Expand Down

0 comments on commit 5ced41b

Please sign in to comment.