diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a75c5e82..a07fa3ab 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -253,6 +253,9 @@ async def execute_tools_and_side_effects( hooks=hooks, context_wrapper=context_wrapper, config=run_config, + original_input=original_input, + pre_step_items=pre_step_items, + new_step_items=new_step_items, ), cls.execute_computer_actions( agent=agent, @@ -539,12 +542,25 @@ async def execute_function_tool_calls( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], config: RunConfig, + original_input: str | list[TResponseInputItem], + pre_step_items: list[RunItem], + new_step_items: list[RunItem], ) -> list[FunctionToolResult]: async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: with function_span(func_tool.name) as span_fn: - tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id) + # Build conversation history from original input and all items generated so far + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( + original_input + ) + pre_items = [item.to_input_item() for item in pre_step_items] + new_items = [item.to_input_item() for item in new_step_items] + conversation_history = original_items + pre_items + new_items + + tool_context = ToolContext.from_agent_context( + context_wrapper, tool_call.call_id, conversation_history=conversation_history + ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index c4329b8a..495a7da5 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field, fields from typing import Any +from .items import TResponseInputItem from .run_context import RunContextWrapper, TContext @@ -15,15 +16,34 @@ class ToolContext(RunContextWrapper[TContext]): tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) """The ID of the tool call.""" + conversation_history: list[TResponseInputItem] = field(default_factory=list) + """The conversation history available at the time this tool was called. + + This includes the original input and all items generated during the agent run + up to the point when this tool was invoked. + """ + @classmethod def from_agent_context( - cls, context: RunContextWrapper[TContext], tool_call_id: str + cls, + context: RunContextWrapper[TContext], + tool_call_id: str, + conversation_history: list[TResponseInputItem] | None = None, ) -> "ToolContext": """ Create a ToolContext from a RunContextWrapper. + + Args: + context: The run context wrapper + tool_call_id: The ID of the tool call + conversation_history: The conversation history available at tool invocation time """ # Grab the names of the RunContextWrapper's init=True fields base_values: dict[str, Any] = { f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } - return cls(tool_call_id=tool_call_id, **base_values) + return cls( + tool_call_id=tool_call_id, + conversation_history=list(conversation_history or []), + **base_values, + ) diff --git a/tests/test_tool_context_conversation_history.py b/tests/test_tool_context_conversation_history.py new file mode 100644 index 00000000..898bfea1 --- /dev/null +++ b/tests/test_tool_context_conversation_history.py @@ -0,0 +1,205 @@ +"""Tests for conversation_history functionality in ToolContext.""" + +from __future__ import annotations + +from typing import cast + +import pytest +from openai.types.responses import ResponseFunctionToolCall, ResponseOutputMessage +from openai.types.responses.response_input_item_param import FunctionCallOutput + +from agents import ( + Agent, + MessageOutputItem, + RunContextWrapper, + RunItem, + ToolCallItem, + ToolCallOutputItem, + Usage, +) +from agents.items import ItemHelpers +from agents.tool_context import ToolContext + +from .test_responses import ( + get_function_tool_call, + get_text_input_item, + get_text_message, +) + + +def test_tool_context_has_conversation_history_field(): + """Test that ToolContext has a conversation_history field.""" + context = ToolContext(context=None, tool_call_id="test-id") + assert hasattr(context, "conversation_history") + assert isinstance(context.conversation_history, list) + assert len(context.conversation_history) == 0 + + +def test_tool_context_from_agent_context_default_history(): + """Test ToolContext.from_agent_context with no conversation history.""" + run_context = RunContextWrapper(context=None, usage=Usage()) + tool_context = ToolContext.from_agent_context(run_context, "test-id") + + assert tool_context.tool_call_id == "test-id" + assert tool_context.conversation_history == [] + + +def test_tool_context_from_agent_context_with_history(): + """Test ToolContext.from_agent_context with conversation history.""" + run_context = RunContextWrapper(context=None, usage=Usage()) + history = [get_text_input_item("Hello"), get_text_input_item("How are you?")] + + tool_context = ToolContext.from_agent_context( + run_context, "test-id", conversation_history=history + ) + + assert tool_context.tool_call_id == "test-id" + assert tool_context.conversation_history == history + assert len(tool_context.conversation_history) == 2 + + +@pytest.mark.asyncio +async def test_conversation_history_in_tool_execution(): + """Test that conversation history is properly passed to tools during execution.""" + + # Create a dummy agent for the items + dummy_agent = Agent[None](name="dummy") + + # Test that we can build conversation history manually + original_input = "What's the weather like?" + pre_step_items: list[RunItem] = [ + MessageOutputItem( + agent=dummy_agent, + raw_item=cast( + ResponseOutputMessage, get_text_message("I'll check the weather for you.") + ), + ) + ] + new_step_items: list[RunItem] = [ + ToolCallItem( + agent=dummy_agent, + raw_item=cast(ResponseFunctionToolCall, get_function_tool_call("test_tool", "")), + ) + ] + + # Test that we can build conversation history manually + original_items = ItemHelpers.input_to_new_input_list(original_input) + pre_items = [item.to_input_item() for item in pre_step_items] + new_items = [item.to_input_item() for item in new_step_items] + expected_history = original_items + pre_items + new_items + + assert len(expected_history) >= 1 # Should have at least the original input + + +@pytest.mark.asyncio +async def test_conversation_history_empty_for_first_turn(): + """Test that conversation history works correctly for the first turn.""" + + # Create a dummy agent for the items + dummy_agent = Agent[None](name="dummy") + + # Simulate first turn - only original input, no pre_step_items + original_input = "Hello" + pre_step_items: list[RunItem] = [] + new_step_items: list[RunItem] = [ + ToolCallItem( + agent=dummy_agent, + raw_item=cast(ResponseFunctionToolCall, get_function_tool_call("first_turn_tool", "")), + ) + ] + + # Build conversation history as it would be built in the actual execution + original_items = ItemHelpers.input_to_new_input_list(original_input) + pre_items = [item.to_input_item() for item in pre_step_items] + new_items = [item.to_input_item() for item in new_step_items] + conversation_history = original_items + pre_items + new_items + + # Should have at least the original input + assert len(conversation_history) >= 1 + assert len(original_items) == 1 # Original input becomes one item + + +@pytest.mark.asyncio +async def test_conversation_history_multi_turn(): + """Test conversation history accumulates correctly across multiple turns.""" + + # Create a dummy agent for the items + dummy_agent = Agent[None](name="dummy") + + # Simulate multiple turns with accumulated history + original_input = "Start conversation" + pre_step_items: list[RunItem] = [ + MessageOutputItem( + agent=dummy_agent, + raw_item=cast(ResponseOutputMessage, get_text_message("Response to start")), + ), + ToolCallItem( + agent=dummy_agent, + raw_item=cast(ResponseFunctionToolCall, get_function_tool_call("multi_turn_tool", "")), + ), + ToolCallOutputItem( + agent=dummy_agent, + raw_item=cast( + FunctionCallOutput, + { + "type": "function_call_output", + "call_id": "call-1", + "output": "Previous tool output", + }, + ), + output="Previous tool output", + ), + MessageOutputItem( + agent=dummy_agent, + raw_item=cast(ResponseOutputMessage, get_text_message("Continuing conversation")), + ), + ] + new_step_items: list[RunItem] = [ + ToolCallItem( + agent=dummy_agent, + raw_item=cast(ResponseFunctionToolCall, get_function_tool_call("multi_turn_tool", "")), + ) + ] + + # Build conversation history + original_items = ItemHelpers.input_to_new_input_list(original_input) + pre_items = [item.to_input_item() for item in pre_step_items] + new_items = [item.to_input_item() for item in new_step_items] + conversation_history = original_items + pre_items + new_items + + # Should contain: original input + all previous messages and tool calls + current tool call + assert len(conversation_history) >= 5 # At least 5 items in this conversation + + +def test_conversation_history_immutable(): + """Test that conversation_history cannot be modified after creation.""" + run_context = RunContextWrapper(context=None, usage=Usage()) + history = [get_text_input_item("Original message")] + + tool_context = ToolContext.from_agent_context( + run_context, "test-id", conversation_history=history + ) + + # Modifying the original list should not affect the tool context + history.append(get_text_input_item("Should not appear")) + + assert len(tool_context.conversation_history) == 1 + + # The conversation_history should be a new list, not a reference + tool_context.conversation_history.append(get_text_input_item("Direct modification")) + + # Create a new tool context to verify it's not affected + new_tool_context = ToolContext.from_agent_context( + run_context, "test-id-2", conversation_history=[get_text_input_item("Original message")] + ) + assert len(new_tool_context.conversation_history) == 1 + + +def test_conversation_history_with_none(): + """Test that passing None for conversation_history results in empty list.""" + run_context = RunContextWrapper(context=None, usage=Usage()) + + tool_context = ToolContext.from_agent_context(run_context, "test-id", conversation_history=None) + + assert tool_context.conversation_history == [] + assert isinstance(tool_context.conversation_history, list)