feat: llama model tool calling (#965)
Wendong-Fan authored Sep 15, 2024
1 parent 6e94015 commit b8e1f5c
Showing 8 changed files with 392 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
@@ -27,7 +27,7 @@ body:
     attributes:
       label: What version of camel are you using?
       description: Run command `python3 -c 'print(__import__("camel").__version__)'` in your shell and paste the output here.
-      placeholder: E.g., 0.2.1
+      placeholder: E.g., 0.2.1a
     validations:
       required: true
2 changes: 1 addition & 1 deletion README.md
@@ -119,7 +119,7 @@ conda create --name camel python=3.10
 conda activate camel
 # Clone github repo
-git clone -b v0.2.1 https://github.com/camel-ai/camel.git
+git clone -b v0.2.1a https://github.com/camel-ai/camel.git
 # Change directory into project directory
 cd camel
2 changes: 1 addition & 1 deletion camel/__init__.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
 
-__version__ = '0.2.1'
+__version__ = '0.2.1a'
 
 __all__ = [
     '__version__',
303 changes: 244 additions & 59 deletions camel/agents/chat_agent.py
@@ -15,6 +15,8 @@
 
 import json
 import logging
+import re
+import uuid
 from collections import defaultdict
 from typing import (
     TYPE_CHECKING,
@@ -28,6 +30,7 @@
 )
 
 from openai.types.chat import ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
 from pydantic import BaseModel
 
 from camel.agents.base import BaseAgent
@@ -190,7 +193,7 @@ def __init__(
                 tool.get_openai_tool_schema() for tool in all_tools
             ]
             self.model_backend.model_config_dict['tools'] = tool_schema_list
-
+            self.tool_schema_list = tool_schema_list
         self.model_config_dict = self.model_backend.model_config_dict
 
         self.model_token_limit = token_limit or self.model_backend.token_limit
@@ -206,6 +209,56 @@ def __init__(
         self.response_terminators = response_terminators or []
         self.init_messages()
 
+    # ruff: noqa: E501
+    def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
+        tool_prompts = []
+
+        for tool in tool_schema_list:
+            tool_info = tool['function']
+            tool_name = tool_info['name']
+            tool_description = tool_info['description']
+            tool_json = json.dumps(tool_info, indent=4)
+
+            prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
+            tool_prompts.append(prompt)
+
+        tool_prompt_str = "\n".join(tool_prompts)
+
+        final_prompt = f'''
+# Tool prompt
+TOOL_PROMPT = f"""
+You have access to the following functions:
+{tool_prompt_str}
+If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+<function=example_function_name>{{"example_name": "example_value"}}</function>
+Reminder:
+- Function calls MUST follow the specified format, start with <function= and end with </function>
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+"""
+'''
+        return final_prompt
+
+    def _parse_tool_response(self, response: str):
+        function_regex = r"<function=(\w+)>(.*?)</function>"
+        match = re.search(function_regex, response)
+
+        if match:
+            function_name, args_string = match.groups()
+            try:
+                args = json.loads(args_string)
+                return {"function": function_name, "arguments": args}
+            except json.JSONDecodeError as error:
+                print(f"Error parsing function arguments: {error}")
+                return None
+        return None
+
     def reset(self):
         r"""Resets the :obj:`ChatAgent` to its initial state and returns the
         stored messages.
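To see what these two helpers do end to end, here is a minimal standalone sketch of the same logic (stdlib only; the `get_weather` schema and the model reply are invented for illustration and are not part of this commit):

import json
import re

# A made-up OpenAI-style tool schema, for illustration only.
tool_schema = {
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
}

# Build one prompt fragment the same way _generate_tool_prompt does.
info = tool_schema["function"]
prompt = (
    f"Use the function '{info['name']}' to '{info['description']}':\n"
    f"{json.dumps(info, indent=4)}\n"
)
print(prompt)

# A llama-style reply in the <function=...> format the prompt requests.
reply = '<function=get_weather>{"city": "Tokyo"}</function>'

# Parse it the same way _parse_tool_response does.
match = re.search(r"<function=(\w+)>(.*?)</function>", reply)
if match:
    name, args_string = match.groups()
    args = json.loads(args_string)
    print({"function": name, "arguments": args})
    # -> {'function': 'get_weather', 'arguments': {'city': 'Tokyo'}}

Note that re.search only captures the first <function=...> block, which matches the prompt's "only call one function at a time" instruction.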
@@ -367,89 +420,221 @@ def step(
             a boolean indicating whether the chat session has terminated,
             and information about the chat session.
         """
-        self.update_memory(input_message, OpenAIBackendRole.USER)
-
-        tool_call_records: List[FunctionCallingRecord] = []
-        while True:
-            # Check if token has exceeded
-            try:
-                openai_messages, num_tokens = self.memory.get_context()
-            except RuntimeError as e:
-                return self._step_token_exceed(
-                    e.args[1], tool_call_records, "max_tokens_exceeded"
-                )
-
-            (
-                response,
-                output_messages,
-                finish_reasons,
-                usage_dict,
-                response_id,
-            ) = self._step_model_response(openai_messages, num_tokens)
-
-            # If the model response is not a function call, meaning the model
-            # has generated a message response, break the loop
-            if (
-                not self.is_tools_added()
-                or not isinstance(response, ChatCompletion)
-                or response.choices[0].message.tool_calls is None
-            ):
-                break
-
-            # Check for external tool call
-            tool_call_request = response.choices[0].message.tool_calls[0]
-            if tool_call_request.function.name in self.external_tool_names:
-                # if model calls an external tool, directly return the request
-                info = self._step_get_info(
-                    output_messages,
-                    finish_reasons,
-                    usage_dict,
-                    response_id,
-                    tool_call_records,
-                    num_tokens,
-                    tool_call_request,
-                )
-                return ChatAgentResponse(
-                    msgs=output_messages, terminated=self.terminated, info=info
-                )
-
-            # Normal function calling
-            tool_call_records.append(self._step_tool_call_and_update(response))
-
-        if output_schema is not None and self.model_type.supports_tool_calling:
-            (
-                output_messages,
-                finish_reasons,
-                usage_dict,
-                response_id,
-                tool_call,
-                num_tokens,
-            ) = self._structure_output_with_function(output_schema)
-            tool_call_records.append(tool_call)
-
-        info = self._step_get_info(
-            output_messages,
-            finish_reasons,
-            usage_dict,
-            response_id,
-            tool_call_records,
-            num_tokens,
-        )
-
-        if len(output_messages) == 1:
-            # Auto record if the output result is a single message
-            self.record_message(output_messages[0])
-        else:
-            logger.warning(
-                "Multiple messages returned in `step()`, message won't be "
-                "recorded automatically. Please call `record_message()` to "
-                "record the selected message manually."
-            )
-
-        return ChatAgentResponse(
-            msgs=output_messages, terminated=self.terminated, info=info
-        )
+        if (
+            isinstance(self.model_type, ModelType)
+            and "lama" in self.model_type.value
+            or isinstance(self.model_type, str)
+            and "lama" in self.model_type
+        ):
+            if self.model_backend.model_config_dict['tools']:
+                tool_prompt = self._generate_tool_prompt(self.tool_schema_list)
+
+                tool_sys_msg = BaseMessage.make_assistant_message(
+                    role_name="Assistant",
+                    content=tool_prompt,
+                )
+
+                self.update_memory(tool_sys_msg, OpenAIBackendRole.SYSTEM)
+
+            self.update_memory(input_message, OpenAIBackendRole.USER)
+
+            tool_call_records: List[FunctionCallingRecord] = []
+            while True:
+                # Check if token has exceeded
+                try:
+                    openai_messages, num_tokens = self.memory.get_context()
+                except RuntimeError as e:
+                    return self._step_token_exceed(
+                        e.args[1], tool_call_records, "max_tokens_exceeded"
+                    )
+
+                (
+                    response,
+                    output_messages,
+                    finish_reasons,
+                    usage_dict,
+                    response_id,
+                ) = self._step_model_response(openai_messages, num_tokens)
+                # If the model response is not a function call, meaning the
+                # model has generated a message response, break the loop
+                if (
+                    not self.is_tools_added()
+                    or not isinstance(response, ChatCompletion)
+                    or "</function>" not in response.choices[0].message.content  # type: ignore[operator]
+                ):
+                    break
+
+                parsed_content = self._parse_tool_response(
+                    response.choices[0].message.content  # type: ignore[arg-type]
+                )
+
+                response.choices[0].message.tool_calls = [
+                    ChatCompletionMessageToolCall(
+                        id=str(uuid.uuid4()),
+                        function=Function(
+                            arguments=str(parsed_content["arguments"]).replace(
+                                "'", '"'
+                            ),
+                            name=str(parsed_content["function"]),
+                        ),
+                        type="function",
+                    )
+                ]
+
+                # Check for external tool call
+                tool_call_request = response.choices[0].message.tool_calls[0]
+                if tool_call_request.function.name in self.external_tool_names:
+                    # if model calls an external tool, directly return the
+                    # request
+                    info = self._step_get_info(
+                        output_messages,
+                        finish_reasons,
+                        usage_dict,
+                        response_id,
+                        tool_call_records,
+                        num_tokens,
+                        tool_call_request,
+                    )
+                    return ChatAgentResponse(
+                        msgs=output_messages,
+                        terminated=self.terminated,
+                        info=info,
+                    )
+
+                # Normal function calling
+                tool_call_records.append(
+                    self._step_tool_call_and_update(response)
+                )
+
+            if (
+                output_schema is not None
+                and self.model_type.supports_tool_calling
+            ):
+                (
+                    output_messages,
+                    finish_reasons,
+                    usage_dict,
+                    response_id,
+                    tool_call,
+                    num_tokens,
+                ) = self._structure_output_with_function(output_schema)
+                tool_call_records.append(tool_call)
+
+            info = self._step_get_info(
+                output_messages,
+                finish_reasons,
+                usage_dict,
+                response_id,
+                tool_call_records,
+                num_tokens,
+            )
+
+            if len(output_messages) == 1:
+                # Auto record if the output result is a single message
+                self.record_message(output_messages[0])
+            else:
+                logger.warning(
+                    "Multiple messages returned in `step()`, message won't be "
+                    "recorded automatically. Please call `record_message()` "
+                    "to record the selected message manually."
+                )
+
+            return ChatAgentResponse(
+                msgs=output_messages, terminated=self.terminated, info=info
+            )
+
+        else:
+            self.update_memory(input_message, OpenAIBackendRole.USER)
+
+            tool_call_records: List[FunctionCallingRecord] = []  # type: ignore[no-redef]
+            while True:
+                # Check if token has exceeded
+                try:
+                    openai_messages, num_tokens = self.memory.get_context()
+                except RuntimeError as e:
+                    return self._step_token_exceed(
+                        e.args[1], tool_call_records, "max_tokens_exceeded"
+                    )
+
+                (
+                    response,
+                    output_messages,
+                    finish_reasons,
+                    usage_dict,
+                    response_id,
+                ) = self._step_model_response(openai_messages, num_tokens)
+                # If the model response is not a function call, meaning the
+                # model has generated a message response, break the loop
+                if (
+                    not self.is_tools_added()
+                    or not isinstance(response, ChatCompletion)
+                    or response.choices[0].message.tool_calls is None
+                ):
+                    break
+
+                # Check for external tool call
+                tool_call_request = response.choices[0].message.tool_calls[0]
+
+                if tool_call_request.function.name in self.external_tool_names:
+                    # if model calls an external tool, directly return the
+                    # request
+                    info = self._step_get_info(
+                        output_messages,
+                        finish_reasons,
+                        usage_dict,
+                        response_id,
+                        tool_call_records,
+                        num_tokens,
+                        tool_call_request,
+                    )
+                    return ChatAgentResponse(
+                        msgs=output_messages,
+                        terminated=self.terminated,
+                        info=info,
+                    )
+
+                # Normal function calling
+                tool_call_records.append(
+                    self._step_tool_call_and_update(response)
+                )
+
+            if (
+                output_schema is not None
+                and self.model_type.supports_tool_calling
+            ):
+                (
+                    output_messages,
+                    finish_reasons,
+                    usage_dict,
+                    response_id,
+                    tool_call,
+                    num_tokens,
+                ) = self._structure_output_with_function(output_schema)
+                tool_call_records.append(tool_call)
+
+            info = self._step_get_info(
+                output_messages,
+                finish_reasons,
+                usage_dict,
+                response_id,
+                tool_call_records,
+                num_tokens,
+            )
+
+            if len(output_messages) == 1:
+                # Auto record if the output result is a single message
+                self.record_message(output_messages[0])
+            else:
+                logger.warning(
+                    "Multiple messages returned in `step()`, message won't be "
+                    "recorded automatically. Please call `record_message()` "
+                    "to record the selected message manually."
+                )
+
+            return ChatAgentResponse(
+                msgs=output_messages, terminated=self.terminated, info=info
+            )
 
     async def step_async(
         self,
         input_message: BaseMessage,
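One detail worth flagging in the llama branch above: `Function.arguments` is rebuilt with `str(parsed_content["arguments"]).replace("'", '"')`, i.e. by swapping quotes on the repr of a Python dict. A small stdlib-only sketch of where that heuristic holds and where it breaks (the example values are invented); `json.dumps` would be the robust alternative:

import json

ok = {"city": "Tokyo"}
swapped = str(ok).replace("'", '"')
print(swapped)               # {"city": "Tokyo"} -- valid JSON
print(json.loads(swapped))   # round-trips fine

# An apostrophe inside a value defeats the quote-swap heuristic:
# repr() double-quotes such strings, and replace() then corrupts them.
bad = {"note": "it's humid"}
broken = str(bad).replace("'", '"')
print(broken)                # {"note": "it"s humid"} -- not valid JSON
try:
    json.loads(broken)
except json.JSONDecodeError as e:
    print(f"JSONDecodeError: {e}")

# json.dumps always emits valid JSON for JSON-serializable arguments.
print(json.dumps(bad))       # {"note": "it's humid"}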
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -27,7 +27,7 @@
 project = 'CAMEL'
 copyright = '2024, CAMEL-AI.org'
 author = 'CAMEL-AI.org'
-release = '0.2.1'
+release = '0.2.1a'
 
 html_favicon = (
     'https://github.com/camel-ai/camel/master/misc/favicon.png'
(The remaining 3 of the 8 changed files are not shown here.)
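Taken together, the changes in this commit let a `ChatAgent` backed by a Llama-family model call tools through the prompt-and-parse path above. A rough, hypothetical usage sketch follows; the `ModelFactory` arguments, the `FunctionTool` wrapper, and the platform and model names are assumptions about the surrounding CAMEL API and are not shown in this diff:

# Hypothetical usage sketch; constructor details are assumed, not confirmed
# by this commit.
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import ModelFactory
from camel.toolkits import FunctionTool
from camel.types import ModelPlatformType

def add(a: int, b: int) -> int:
    r"""Adds two numbers."""
    return a + b

# Any Llama-family model should trigger the new branch in step().
model = ModelFactory.create(
    model_platform=ModelPlatformType.GROQ,    # assumed platform
    model_type="llama-3.1-70b-versatile",     # assumed model name
)

agent = ChatAgent(
    system_message=BaseMessage.make_assistant_message(
        role_name="Assistant", content="You are a helpful assistant."
    ),
    model=model,
    tools=[FunctionTool(add)],
)

# step() prepends the generated tool prompt as a system message, parses any
# <function=...> reply, and executes the matching tool.
response = agent.step(
    BaseMessage.make_user_message(role_name="User", content="What is 2 + 3?")
)
print(response.msgs[0].content)
print(response.info["tool_calls"])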
