diff --git a/.gitattributes b/.gitattributes index 839c730..5316de4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,3 @@ .devcontainer/** linguist-vendored=true + +inference_gateway/models.py linguist-generated=true diff --git a/.gitignore b/.gitignore index c481c18..cf4f4b3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist .coverage node_modules/ .mypy_cache/ +**/.env diff --git a/README.md b/README.md index 7189a81..7645e81 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,9 @@ - [Error Handling](#error-handling) - [Advanced Usage](#advanced-usage) - [Using Tools](#using-tools) + - [Listing Available MCP Tools](#listing-available-mcp-tools) - [Custom HTTP Configuration](#custom-http-configuration) + - [Examples](#examples) - [License](#license) A modern Python SDK for interacting with the [Inference Gateway](https://github.com/edenreich/inference-gateway), providing a unified interface to multiple AI providers. @@ -41,17 +43,17 @@ pip install inference-gateway ### Basic Usage ```python -from inference_gateway import InferenceGatewayClient, Message, MessageRole +from inference_gateway import InferenceGatewayClient, Message # Initialize client -client = InferenceGatewayClient("http://localhost:8080") +client = InferenceGatewayClient("http://localhost:8080/v1") # Simple chat completion response = client.create_chat_completion( model="openai/gpt-4", messages=[ - Message(role=MessageRole.SYSTEM, content="You are a helpful assistant"), - Message(role=MessageRole.USER, content="Hello!") + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello!") ] ) @@ -70,18 +72,18 @@ print(response.choices[0].message.content) from inference_gateway import InferenceGatewayClient # Basic configuration -client = InferenceGatewayClient("http://localhost:8080") +client = InferenceGatewayClient("http://localhost:8080/v1") # With authentication client = InferenceGatewayClient( - "http://localhost:8080", + "http://localhost:8080/v1", token="your-api-token", timeout=60.0 # Custom timeout ) # Using httpx instead of requests client = InferenceGatewayClient( - "http://localhost:8080", + "http://localhost:8080/v1", use_httpx=True ) ``` @@ -105,13 +107,13 @@ print("OpenAI models:", openai_models) #### Standard Completion ```python -from inference_gateway import Message, MessageRole +from inference_gateway import Message response = client.create_chat_completion( model="openai/gpt-4", messages=[ - Message(role=MessageRole.SYSTEM, content="You are a helpful assistant"), - Message(role=MessageRole.USER, content="Explain quantum computing") + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Explain quantum computing") ], max_tokens=500 ) @@ -126,7 +128,7 @@ print(response.choices[0].message.content) for chunk in client.create_chat_completion_stream( model="ollama/llama2", messages=[ - Message(role=MessageRole.USER, content="Tell me a story") + Message(role="user", content="Tell me a story") ], use_sse=True ): @@ -136,7 +138,7 @@ for chunk in client.create_chat_completion_stream( for chunk in client.create_chat_completion_stream( model="anthropic/claude-3", messages=[ - Message(role=MessageRole.USER, content="Explain AI safety") + Message(role="user", content="Explain AI safety") ], use_sse=False ): @@ -186,43 +188,96 @@ except InferenceGatewayError as e: ### Using Tools ```python -# List available MCP tools works when MCP_ENABLE and MCP_EXPOSE are set on the gateway -tools = client.list_tools() -print("Available 
tools:", tools) +# Define a weather tool using type-safe Pydantic models +from inference_gateway.models import ChatCompletionTool, FunctionObject, FunctionParameters + +weather_tool = ChatCompletionTool( + type="function", + function=FunctionObject( + name="get_current_weather", + description="Get the current weather in a given location", + parameters=FunctionParameters( + type="object", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use" + } + }, + required=["location"] + ) + ) +) -# Use tools in chat completion works when MCP_ENABLE and MCP_EXPOSE are set to false on the gateway +# Using tools in a chat completion response = client.create_chat_completion( model="openai/gpt-4", - messages=[...], - tools=[ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": {...} - } - } - ] + messages=[ + Message(role="system", content="You are a helpful assistant with access to weather information"), + Message(role="user", content="What is the weather like in New York?") + ], + tools=[weather_tool] # Pass the tool definition ) + +print(response.choices[0].message.content) + +# Check if the model made a tool call +if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + print(f"Tool called: {tool_call.function.name}") + print(f"Arguments: {tool_call.function.arguments}") ``` +### Listing Available MCP Tools + +```python +# List available MCP tools (requires MCP_ENABLE and MCP_EXPOSE to be set on the gateway) +tools = client.list_tools() +print("Available tools:", tools) +``` + +**Server-Side Tool Management** + +The SDK currently supports listing available MCP tools, which is particularly useful for UI applications that need to display connected tools to users. The key advantage is that tools are managed server-side: + +- **Automatic Tool Injection**: Tools are automatically inferred and injected into requests by the Inference Gateway server +- **Simplified Client Code**: No need to manually manage or configure tools in your client application +- **Transparent Tool Calls**: During streaming chat completions with configured MCP servers, tool calls appear in the response stream - no special handling required except optionally displaying them to users + +This architecture allows you to focus on LLM interactions while the gateway handles all tool management complexities behind the scenes. + ### Custom HTTP Configuration ```python # With custom headers client = InferenceGatewayClient( - "http://localhost:8080", + "http://localhost:8080/v1", headers={"X-Custom-Header": "value"} ) # With proxy settings client = InferenceGatewayClient( - "http://localhost:8080", + "http://localhost:8080/v1", proxies={"http": "http://proxy.example.com"} ) ``` +## Examples + +For comprehensive examples demonstrating various use cases, see the [examples](examples/) directory: + +- [List LLMs](examples/list/) - How to list available models +- [Chat](examples/chat/) - Basic and advanced chat completion examples +- [Tools](examples/tools/) - Working with function tools +- [MCP](examples/mcp/) - Model Context Protocol integration examples + +Each example includes a detailed README with setup instructions and explanations. + ## License This SDK is distributed under the MIT License, see [LICENSE](LICENSE) for more information. 
diff --git a/Taskfile.yml b/Taskfile.yml index 07438e3..e1865ea 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -43,7 +43,7 @@ tasks: --output inference_gateway/models.py --output-model-type pydantic_v2.BaseModel --enum-field-as-literal all - --target-python-version {{.PYTHON_VERSION}} + --target-python-version 3.12 --use-schema-description --use-generic-container-types --use-standard-collections @@ -58,7 +58,8 @@ tasks: --strict-nullable --allow-population-by-field-name --snake-case-field - --strip-default-none + --use-default + --use-default-kwarg --use-title-as-name - echo "✅ Models generated successfully" - task: format @@ -67,17 +68,17 @@ tasks: desc: Format code with black and isort cmds: - echo "Formatting code..." - - black inference_gateway/ tests/ - - isort inference_gateway/ tests/ + - black inference_gateway/ tests/ examples/ + - isort inference_gateway/ tests/ examples/ - echo "✅ Code formatted" lint: desc: Run all linting checks cmds: - echo "Running linting checks..." - - black --check inference_gateway/ tests/ - - isort --check-only inference_gateway/ tests/ - - mypy inference_gateway/ + - black --check inference_gateway/ tests/ examples/ + - isort --check-only inference_gateway/ tests/ examples/ + - mypy inference_gateway/ examples/ - echo "✅ All linting checks passed" test: @@ -122,6 +123,29 @@ tasks: - python -m build - echo "✅ Package built successfully" + install-global: + desc: Build and install the package globally for testing + deps: + - build + cmds: + - echo "Installing package globally..." + - pip uninstall -y inference-gateway || true + - pip install dist/*.whl --force-reinstall + - echo "✅ Package installed globally successfully" + + install-global-dev: + desc: Build and install the package globally for testing (skip tests) + deps: + - clean + - format + cmds: + - echo "Building package (skipping tests)..." + - python -m build + - echo "Installing package globally..." 
+      - pip uninstall -y inference-gateway || true
+      - pip install dist/*.whl --force-reinstall
+      - echo "✅ Package installed globally successfully"
+
   docs:serve:
     desc: Serve documentation locally (placeholder for future docs)
     cmds:
diff --git a/examples/.env.example b/examples/.env.example
new file mode 100644
index 0000000..c70fb63
--- /dev/null
+++ b/examples/.env.example
@@ -0,0 +1,48 @@
+
+# General settings
+ENVIRONMENT=production
+ENABLE_TELEMETRY=false
+ENABLE_AUTH=false
+# Model Context Protocol (MCP)
+MCP_ENABLE=false
+MCP_EXPOSE=false
+MCP_SERVERS=
+MCP_CLIENT_TIMEOUT=5s
+MCP_DIAL_TIMEOUT=3s
+MCP_TLS_HANDSHAKE_TIMEOUT=3s
+MCP_RESPONSE_HEADER_TIMEOUT=3s
+MCP_EXPECT_CONTINUE_TIMEOUT=1s
+MCP_REQUEST_TIMEOUT=5s
+# OpenID Connect
+OIDC_ISSUER_URL=http://keycloak:8080/realms/inference-gateway-realm
+OIDC_CLIENT_ID=inference-gateway-client
+OIDC_CLIENT_SECRET=
+# Server settings
+SERVER_HOST=0.0.0.0
+SERVER_PORT=8080
+SERVER_READ_TIMEOUT=30s
+SERVER_WRITE_TIMEOUT=30s
+SERVER_IDLE_TIMEOUT=120s
+SERVER_TLS_CERT_PATH=
+SERVER_TLS_KEY_PATH=
+# Client settings
+CLIENT_TIMEOUT=30s
+CLIENT_MAX_IDLE_CONNS=20
+CLIENT_MAX_IDLE_CONNS_PER_HOST=20
+CLIENT_IDLE_CONN_TIMEOUT=30s
+CLIENT_TLS_MIN_VERSION=TLS12
+# Providers
+ANTHROPIC_API_URL=https://api.anthropic.com/v1
+ANTHROPIC_API_KEY=
+CLOUDFLARE_API_URL=https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai
+CLOUDFLARE_API_KEY=
+COHERE_API_URL=https://api.cohere.ai
+COHERE_API_KEY=
+GROQ_API_URL=https://api.groq.com/openai/v1
+GROQ_API_KEY=
+OLLAMA_API_URL=http://ollama:8080/v1
+OLLAMA_API_KEY=
+OPENAI_API_URL=https://api.openai.com/v1
+OPENAI_API_KEY=
+DEEPSEEK_API_URL=https://api.deepseek.com
+DEEPSEEK_API_KEY=
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..5b22f35
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,26 @@
+# Examples
+
+Before starting with the examples, ensure you have the inference-gateway up and running:
+
+1. Copy the `.env.example` file to `.env` and set your provider API key.
+
+2. Set your preferred Large Language Model (LLM) provider for the examples:
+
+```sh
+export LLM_NAME=groq/meta-llama/llama-4-scout-17b-16e-instruct
+```
+
+3. Run the Docker container:
+
+```sh
+docker run --rm -it -p 8080:8080 --env-file .env -e LLM_NAME ghcr.io/inference-gateway/inference-gateway:0.7.1
+```
+
+It is recommended to set the environment variable `ENVIRONMENT=development` in your `.env` file to enable debug mode.
+
+The following examples demonstrate how to use the Inference Gateway SDK for various tasks:
+
+- [List LLMs](list/README.md)
+- [Chat](chat/README.md)
+- [Tools](tools/README.md)
+- [MCP](mcp/README.md)
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000..a5c80fb
--- /dev/null
+++ b/examples/__init__.py
@@ -0,0 +1 @@
+# Examples package
diff --git a/examples/chat/README.md b/examples/chat/README.md
new file mode 100644
index 0000000..fa0afcc
--- /dev/null
+++ b/examples/chat/README.md
@@ -0,0 +1,101 @@
+# Chat Completions Example
+
+This example demonstrates how to use the Inference Gateway Python SDK for chat completions with both standard HTTP requests and streaming responses.
+
+## Features
+
+- **Standard Chat Completion**: Traditional request-response pattern with complete messages
+- **Streaming Chat Completion**: Real-time streaming responses for better user experience
+
+## Usage
+
+1. **Set up environment**:
+
+   ```bash
+   export LLM_NAME="groq/meta-llama/llama-4-scout-17b-16e-instruct"
+   ```
+
+2. 
**Run the example**: + ```bash + python main.py + ``` + +## Code Examples + +### Standard Chat Completion + +```python +from inference_gateway import InferenceGatewayClient, Message + +client = InferenceGatewayClient("http://localhost:8080/v1") + +response = client.create_chat_completion( + model="groq/meta-llama/llama-4-scout-17b-16e-instruct", + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello! Please introduce yourself briefly."), + ], + max_tokens=100, +) + +print(response.choices[0].message.content) +``` + +### Streaming Chat Completion + +```python +from inference_gateway import InferenceGatewayClient, Message +from inference_gateway.models import SSEvent +import json + +client = InferenceGatewayClient("http://localhost:8080/v1") + +stream = client.create_chat_completion_stream( + model="groq/meta-llama/llama-4-scout-17b-16e-instruct", + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Tell me a short story."), + ], + max_tokens=200, +) + +for chunk in stream: + if isinstance(chunk, SSEvent) and chunk.data: + try: + data = json.loads(chunk.data) + if "choices" in data and len(data["choices"]) > 0: + delta = data["choices"][0].get("delta", {}) + if "content" in delta and delta["content"]: + print(delta["content"], end="", flush=True) + except json.JSONDecodeError: + pass +``` + +## Error Handling + +The SDK provides specific exception types: + +- `InferenceGatewayAPIError`: API-related errors (4xx, 5xx responses) +- `InferenceGatewayError`: SDK-related errors (network, parsing, etc.) + +```python +from inference_gateway.client import InferenceGatewayAPIError, InferenceGatewayError + +try: + response = client.create_chat_completion(...) +except (InferenceGatewayAPIError, InferenceGatewayError) as e: + print(f"Error: {e}") +``` + +## Dependencies + +- `inference_gateway`: The Python SDK for Inference Gateway +- Standard library: `json`, `os` + +## Configuration + +The example uses the `LLM_NAME` environment variable to specify the model. Supported models include: + +- OpenAI models: `openai/gpt-4`, `openai/gpt-3.5-turbo` +- Groq models: `groq/meta-llama/llama-4-scout-17b-16e-instruct` +- Other providers as configured in your Inference Gateway instance diff --git a/examples/chat/__init__.py b/examples/chat/__init__.py new file mode 100644 index 0000000..78b8091 --- /dev/null +++ b/examples/chat/__init__.py @@ -0,0 +1 @@ +# Chat example package diff --git a/examples/chat/main.py b/examples/chat/main.py new file mode 100644 index 0000000..e6a33a7 --- /dev/null +++ b/examples/chat/main.py @@ -0,0 +1,87 @@ +import json +import os + +from inference_gateway import InferenceGatewayClient, Message +from inference_gateway.client import InferenceGatewayAPIError, InferenceGatewayError +from inference_gateway.models import SSEvent + + +def main() -> None: + """ + Simple demo of standard and streaming chat completions using the Inference Gateway Python SDK. + """ + # Initialize client + client = InferenceGatewayClient("http://localhost:8080/v1") + + # Use environment variable with default model + LLM_NAME = os.getenv("LLM_NAME", "openai/gpt-4") + print(f"Using model: {LLM_NAME}") + print("=" * 50) + + # Example 1: Standard Chat Completion + print("\n1. 
Standard Chat Completion:") + print("-" * 30) + + try: + response = client.create_chat_completion( + model=LLM_NAME, + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello! Please introduce yourself briefly."), + ], + max_tokens=100, + ) + + print(f"Response: {response.choices[0].message.content}") + if response.usage: + print(f"Usage: {response.usage.total_tokens} tokens") + + except (InferenceGatewayAPIError, InferenceGatewayError) as e: + print(f"Error: {e}") + return + + # Example 2: Streaming Chat Completion + print("\n\n2. Streaming Chat Completion:") + print("-" * 30) + + try: + print("Assistant: ", end="", flush=True) + + stream = client.create_chat_completion_stream( + model=LLM_NAME, + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Tell me a short story about a robot."), + ], + max_tokens=200, + ) + + for chunk in stream: + if isinstance(chunk, SSEvent): + # Handle Server-Sent Events format + if chunk.data: + try: + data = json.loads(chunk.data) + if "choices" in data and len(data["choices"]) > 0: + delta = data["choices"][0].get("delta", {}) + if "content" in delta and delta["content"]: + print(delta["content"], end="", flush=True) + except json.JSONDecodeError: + pass + elif isinstance(chunk, dict): + # Handle JSON format + if "choices" in chunk and len(chunk["choices"]) > 0: + delta = chunk["choices"][0].get("delta", {}) + if "content" in delta and delta["content"]: + print(delta["content"], end="", flush=True) + + print("\n") + + except (InferenceGatewayAPIError, InferenceGatewayError) as e: + print(f"\nStreaming Error: {e}") + + print("\n" + "=" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/list/README.md b/examples/list/README.md new file mode 100644 index 0000000..c733c06 --- /dev/null +++ b/examples/list/README.md @@ -0,0 +1,19 @@ +# List LLMs Example + +This example demonstrates how to use the Inference Gateway Python SDK to list available language models. It shows how to: + +- Connect to the Inference Gateway +- List all available models +- Filter models by provider (e.g., OpenAI, Groq, etc.) 
+
+Run the example with:
+
+```bash
+python main.py
+```
+
+You can specify a different model provider by setting the `LLM_NAME` environment variable:
+
+```bash
+LLM_NAME=openai/gpt-4 python main.py
+```
diff --git a/examples/list/__init__.py b/examples/list/__init__.py
new file mode 100644
index 0000000..99c5600
--- /dev/null
+++ b/examples/list/__init__.py
@@ -0,0 +1 @@
+# List example package
diff --git a/examples/list/main.py b/examples/list/main.py
new file mode 100644
index 0000000..9cd8305
--- /dev/null
+++ b/examples/list/main.py
@@ -0,0 +1,21 @@
+import os
+
+from inference_gateway import InferenceGatewayClient
+
+# Initialize client
+client = InferenceGatewayClient("http://localhost:8080/v1")
+
+# Use environment variable with default model
+LLM_NAME = os.getenv("LLM_NAME", "groq/llama3-8b-8192")
+
+PROVIDER = LLM_NAME.split("/")[0]
+
+print(f"Using provider: {PROVIDER}")
+
+# List all available models
+models = client.list_models()
+print("All models:", models)
+
+# Filter by provider
+provider_models = client.list_models(provider=PROVIDER)
+print(f"Provider {PROVIDER} models:", provider_models)
diff --git a/examples/mcp/README.md b/examples/mcp/README.md
new file mode 100644
index 0000000..b466946
--- /dev/null
+++ b/examples/mcp/README.md
@@ -0,0 +1,16 @@
+# Model Context Protocol (MCP) Example
+
+This example demonstrates how to use the Inference Gateway Python SDK to interact with Model Context Protocol (MCP) tools. It shows how to:
+
+- Connect to the Inference Gateway with MCP enabled
+- List available MCP tools exposed by the gateway
+
+**Prerequisites:** The Inference Gateway must be configured with `MCP_ENABLE=true` and `MCP_EXPOSE=true` environment variables.
+
+Run the example with:
+
+```bash
+python main.py
+```
+
+Learn more about MCP at [modelcontextprotocol.io](https://modelcontextprotocol.io)
diff --git a/examples/mcp/__init__.py b/examples/mcp/__init__.py
new file mode 100644
index 0000000..838c11a
--- /dev/null
+++ b/examples/mcp/__init__.py
@@ -0,0 +1 @@
+# MCP example package
diff --git a/examples/mcp/main.py b/examples/mcp/main.py
new file mode 100644
index 0000000..63336a7
--- /dev/null
+++ b/examples/mcp/main.py
@@ -0,0 +1,8 @@
+from inference_gateway import InferenceGatewayClient
+
+# Initialize client
+client = InferenceGatewayClient("http://localhost:8080/v1")
+
+# List available MCP tools (requires MCP_ENABLE and MCP_EXPOSE to be set on the gateway)
+tools = client.list_tools()
+print("Available tools:", tools)
diff --git a/examples/tools/README.md b/examples/tools/README.md
new file mode 100644
index 0000000..ae08e35
--- /dev/null
+++ b/examples/tools/README.md
@@ -0,0 +1,91 @@
+# Tools Example
+
+This example demonstrates how to use function calling (tools) with the Inference Gateway Python SDK.
+
+## Overview
+
+The example shows:
+
+- How to define tools using the standard OpenAI function calling format
+- Making chat completions with tools enabled
+- Handling tool calls from the model
+- Simulating tool execution
+- Continuing conversations with tool results
+
+## Tools Defined
+
+1. **Weather Tool** - Get current weather for a location
+2. 
**Calculator Tool** - Perform basic mathematical calculations + +## Running the Example + +```bash +# Set the model (optional, defaults to openai/gpt-4) +export LLM_NAME="openai/gpt-4" + +# Run the example +python examples/tools/main.py +``` + +## Tool Format + +Tools are defined using type-safe Pydantic models from the SDK: + +```python +from inference_gateway.models import ChatCompletionTool, FunctionObject, FunctionParameters + +weather_tool = ChatCompletionTool( + type="function", + function=FunctionObject( + name="get_current_weather", + description="Get the current weather in a given location", + parameters=FunctionParameters( + type="object", + properties={ + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use" + } + }, + required=["location"] + ) + ) +) +``` + +This approach provides: + +- **Type Safety**: Full type checking at development time +- **IDE Support**: Better autocomplete and error detection +- **Validation**: Automatic validation of tool definitions +- **Documentation**: Self-documenting code with proper types + +## Test Cases + +The example runs through several test cases: + +1. Weather query: "What's the weather like in San Francisco?" +2. Math calculation: "Calculate 15 \* 8 + 32" +3. Multiple tools: "What's the weather in London in Fahrenheit and also calculate 100 / 4?" +4. No tools needed: "Hello! How are you doing today?" + +## Key Concepts + +- **Type-Safe Tool Definition**: Use Pydantic models for tool definitions to ensure type safety +- **Function Parameters**: Define proper JSON schema for function parameters with validation +- **Tool Calls**: Handle when the model decides to call a function +- **Tool Results**: Provide function results back to continue the conversation +- **Multiple Tools**: Models can call multiple tools in a single response +- **Conversation Flow**: Maintain proper message history with tool calls and results +- **Type Safety**: Full type checking for tool calls and responses + +## Notes + +- This example simulates tool execution - in production you'd integrate with real APIs +- The calculator uses `eval()` for simplicity - use a proper math parser in production +- Tool calls are optional - the model will only use them when appropriate diff --git a/examples/tools/__init__.py b/examples/tools/__init__.py new file mode 100644 index 0000000..d56d8cb --- /dev/null +++ b/examples/tools/__init__.py @@ -0,0 +1 @@ +# Tools example package diff --git a/examples/tools/main.py b/examples/tools/main.py new file mode 100644 index 0000000..b005128 --- /dev/null +++ b/examples/tools/main.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +Example demonstrating tool use with the Inference Gateway Python SDK. + +This example shows how to: +1. Define tools/functions using type-safe Pydantic models +2. Make chat completions with tool calling enabled +3. Handle tool calls from the model +4. 
Simulate tool execution and continue the conversation
+"""
+
+import json
+import os
+from typing import Any, Dict
+
+from inference_gateway import InferenceGatewayClient, Message
+from inference_gateway.models import (
+    ChatCompletionMessageToolCall,
+    ChatCompletionTool,
+    FunctionObject,
+    FunctionParameters,
+)
+
+# Initialize client
+client = InferenceGatewayClient("http://localhost:8080/v1")
+
+# Use environment variable with default model
+LLM_NAME = os.getenv("LLM_NAME", "openai/gpt-4")
+
+print(f"Using model: {LLM_NAME}")
+
+# Define weather tool using type-safe Pydantic models
+weather_tool = ChatCompletionTool(
+    type="function",
+    function=FunctionObject(
+        name="get_current_weather",
+        description="Get the current weather in a given location",
+        parameters=FunctionParameters(
+            type="object",
+            properties={
+                "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA",
+                },
+                "unit": {
+                    "type": "string",
+                    "enum": ["celsius", "fahrenheit"],
+                    "description": "The temperature unit to use",
+                },
+            },
+            required=["location"],
+        ),
+    ),
+)
+
+# Define calculator tool using type-safe Pydantic models
+calculator_tool = ChatCompletionTool(
+    type="function",
+    function=FunctionObject(
+        name="calculate",
+        description="Perform basic mathematical calculations",
+        parameters=FunctionParameters(
+            type="object",
+            properties={
+                "expression": {
+                    "type": "string",
+                    "description": "Mathematical expression to evaluate, e.g. '2 + 3 * 4'",
+                }
+            },
+            required=["expression"],
+        ),
+    ),
+)
+
+
+def simulate_weather_api(location: str, unit: str = "celsius") -> Dict[str, Any]:
+    """Simulate a weather API call."""
+    # This would normally call a real weather API
+    weather_data = {
+        "location": location,
+        "temperature": 22 if unit == "celsius" else 72,
+        "unit": unit,
+        "condition": "partly cloudy",
+        "humidity": 65,
+        "wind_speed": 10,
+    }
+    return weather_data
+
+
+def simulate_calculator(expression: str) -> Dict[str, Any]:
+    """Simulate a calculator function."""
+    try:
+        # WARNING: eval is not safe for untrusted input; used here for demo purposes only
+        # In production, use a proper math parser instead of eval
+        result = eval(expression)
+        return {"expression": expression, "result": result}
+    except Exception as e:
+        return {"expression": expression, "error": str(e)}
+
+
+def execute_tool_call(tool_call: ChatCompletionMessageToolCall) -> str:
+    """Execute a tool call and return the result as a JSON string."""
+    function_name = tool_call.function.name
+    arguments = json.loads(tool_call.function.arguments)
+
+    if function_name == "get_current_weather":
+        result = simulate_weather_api(
+            location=arguments["location"], unit=arguments.get("unit", "celsius")
+        )
+        return json.dumps(result)
+    elif function_name == "calculate":
+        result = simulate_calculator(arguments["expression"])
+        return json.dumps(result)
+    else:
+        return json.dumps({"error": f"Unknown function: {function_name}"})
+
+
+def main() -> None:
+    """Main example demonstrating tool use."""
+
+    # Available tools
+    tools = [weather_tool, calculator_tool]
+
+    # Test cases for different scenarios
+    test_cases = [
+        "What's the weather like in San Francisco?",
+        "Calculate 15 * 8 + 32",
+        "What's the weather in London in Fahrenheit and also calculate 100 / 4?",
+        "Hello! 
How are you doing today?", # No tool use expected + ] + + for i, user_message in enumerate(test_cases, 1): + print(f"\n{'='*60}") + print(f"Test Case {i}: {user_message}") + print("=" * 60) + + # Start conversation + messages = [ + Message( + role="system", + content="You are a helpful assistant with access to weather information and a calculator. Use the tools when appropriate to help answer user questions.", + ), + Message(role="user", content=user_message), + ] + + # Make initial request with tools + response = client.create_chat_completion(model=LLM_NAME, messages=messages, tools=tools) + + assistant_message = response.choices[0].message + print(f"Assistant: {assistant_message.content or '(Making tool calls...)'}") + + # Check if the model made tool calls + if assistant_message.tool_calls: + print(f"\nTool calls made: {len(assistant_message.tool_calls)}") + + # Add assistant message to conversation + messages.append( + Message( + role="assistant", + content=assistant_message.content, + tool_calls=[ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + for tool_call in assistant_message.tool_calls + ], + ) + ) + + # Execute each tool call and add results + for tool_call in assistant_message.tool_calls: + print(f"\nExecuting: {tool_call.function.name}") + print(f"Arguments: {tool_call.function.arguments}") + + # Execute the tool call + tool_result = execute_tool_call(tool_call) + print(f"Result: {tool_result}") + + # Add tool result to conversation + messages.append( + Message(role="tool", tool_call_id=tool_call.id, content=tool_result) + ) + + # Get final response after tool execution + final_response = client.create_chat_completion( + model=LLM_NAME, messages=messages, tools=tools + ) + + print(f"\nFinal Assistant Response: {final_response.choices[0].message.content}") + else: + print("No tool calls were made.") + + +if __name__ == "__main__": + main() diff --git a/inference_gateway/client.py b/inference_gateway/client.py index d2c7101..df3cf88 100644 --- a/inference_gateway/client.py +++ b/inference_gateway/client.py @@ -12,6 +12,7 @@ from pydantic import ValidationError from inference_gateway.models import ( + ChatCompletionTool, CreateChatCompletionRequest, CreateChatCompletionResponse, ListModelsResponse, @@ -57,11 +58,11 @@ class InferenceGatewayClient: Example: ```python # Basic usage - client = InferenceGatewayClient("https://api.example.com") + client = InferenceGatewayClient("https://api.example.com/v1") # With authentication client = InferenceGatewayClient( - "https://api.example.com", + "https://api.example.com/v1", token="your-api-token" ) @@ -87,7 +88,7 @@ def __init__( """Initialize the client with base URL and optional auth token. 
Args: - base_url: The base URL of the Inference Gateway API + base_url: The base URL of the Inference Gateway API (should include /v1) token: Optional authentication token timeout: Request timeout in seconds (default: 30.0) use_httpx: Whether to use httpx instead of requests (default: False) @@ -174,7 +175,7 @@ def list_models(self, provider: Optional[Union[Provider, str]] = None) -> ListMo InferenceGatewayAPIError: If the API request fails InferenceGatewayValidationError: If response validation fails """ - url = f"{self.base_url}/v1/models" + url = f"{self.base_url}/models" params = {} if provider: @@ -264,7 +265,7 @@ def create_chat_completion( provider: Optional[Union[Provider, str]] = None, max_tokens: Optional[int] = None, stream: bool = False, - tools: Optional[List[Dict[str, Any]]] = None, + tools: Optional[List[ChatCompletionTool]] = None, **kwargs: Any, ) -> CreateChatCompletionResponse: """Generate a chat completion. @@ -275,7 +276,7 @@ def create_chat_completion( provider: Optional provider specification max_tokens: Maximum number of tokens to generate stream: Whether to stream the response - tools: List of tools the model may call + tools: List of tools the model may call (using ChatCompletionTool models) **kwargs: Additional parameters to pass to the API Returns: @@ -285,7 +286,7 @@ def create_chat_completion( InferenceGatewayAPIError: If the API request fails InferenceGatewayValidationError: If request/response validation fails """ - url = f"{self.base_url}/v1/chat/completions" + url = f"{self.base_url}/chat/completions" params = {} if provider: @@ -302,7 +303,7 @@ def create_chat_completion( if max_tokens is not None: request_data["max_tokens"] = max_tokens if tools: - request_data["tools"] = tools + request_data["tools"] = [tool.model_dump(exclude_none=True) for tool in tools] request_data.update(kwargs) @@ -323,7 +324,7 @@ def create_chat_completion_stream( messages: List[Message], provider: Optional[Union[Provider, str]] = None, max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + tools: Optional[List[ChatCompletionTool]] = None, use_sse: bool = True, **kwargs: Any, ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]: @@ -334,7 +335,7 @@ def create_chat_completion_stream( messages: List of messages for the conversation provider: Optional provider specification max_tokens: Maximum number of tokens to generate - tools: List of tools the model may call + tools: List of tools the model may call (using ChatCompletionTool models) use_sse: Whether to use Server-Sent Events format **kwargs: Additional parameters to pass to the API @@ -345,7 +346,7 @@ def create_chat_completion_stream( InferenceGatewayAPIError: If the API request fails InferenceGatewayValidationError: If request validation fails """ - url = f"{self.base_url}/v1/chat/completions" + url = f"{self.base_url}/chat/completions" params = {} if provider: @@ -362,7 +363,7 @@ def create_chat_completion_stream( if max_tokens is not None: request_data["max_tokens"] = max_tokens if tools: - request_data["tools"] = tools + request_data["tools"] = [tool.model_dump(exclude_none=True) for tool in tools] request_data.update(kwargs) diff --git a/inference_gateway/models.py b/inference_gateway/models.py index 932edeb..11fa56a 100644 --- a/inference_gateway/models.py +++ b/inference_gateway/models.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2025-05-26T15:46:03+00:00 +# timestamp: 2025-05-26T19:01:03+00:00 from __future__ import annotations @@ -18,12 
+18,6 @@ class Provider( ) root: Literal["ollama", "groq", "openai", "cloudflare", "cohere", "anthropic", "deepseek"] - def __eq__(self, other: Any) -> bool: - """Allow comparison with strings.""" - if isinstance(other, str): - return self.root == other - return super().__eq__(other) - class ProviderSpecificResponse(BaseModel): """ @@ -94,9 +88,9 @@ class SSEvent(BaseModel): "message-end", "stream-end", ] - ] - data: Optional[str] - retry: Optional[int] + ] = None + data: Optional[str] = None + retry: Optional[int] = None class Endpoints(BaseModel): @@ -111,7 +105,7 @@ class Error(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - error: Optional[str] + error: Optional[str] = None class MessageRole(RootModel[Literal["system", "user", "assistant", "tool"]]): @@ -123,12 +117,6 @@ class MessageRole(RootModel[Literal["system", "user", "assistant", "tool"]]): Role of the message sender """ - def __eq__(self, other: Any) -> bool: - """Allow comparison with strings.""" - if isinstance(other, str): - return self.root == other - return super().__eq__(other) - class Model(BaseModel): """ @@ -153,9 +141,9 @@ class ListModelsResponse(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - provider: Optional[Provider] + provider: Optional[Provider] = None object: str - data: Sequence[Model] + data: Sequence[Model] = [] class MCPTool(BaseModel): @@ -191,7 +179,7 @@ class MCPTool(BaseModel): } ] ), - ] + ] = None """ JSON schema for the tool's input parameters """ @@ -227,15 +215,15 @@ class CompletionUsage(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - completion_tokens: int + completion_tokens: int = 0 """ Number of tokens in the generated completion. """ - prompt_tokens: int + prompt_tokens: int = 0 """ Number of tokens in the prompt. """ - total_tokens: int + total_tokens: int = 0 """ Total number of tokens used in the request (prompt + completion). """ @@ -291,11 +279,11 @@ class Function(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - name: Optional[str] + name: Optional[str] = None """ The name of the function to call. """ - arguments: Optional[str] + arguments: Optional[str] = None """ The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. """ @@ -306,15 +294,15 @@ class ChatCompletionMessageToolCallChunk(BaseModel): populate_by_name=True, ) index: int - id: Optional[str] + id: Optional[str] = None """ The ID of the tool call. """ - type: Optional[str] + type: Optional[str] = None """ The type of the tool. Currently, only `function` is supported. """ - function: Optional[Function] + function: Optional[Function] = None class TopLogprob(BaseModel): @@ -414,7 +402,7 @@ class ListToolsResponse(BaseModel): """ Always "list" """ - data: Sequence[MCPTool] + data: Sequence[MCPTool] = [] """ Array of available MCP tools """ @@ -424,7 +412,7 @@ class FunctionObject(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - description: Optional[str] + description: Optional[str] = None """ A description of what the function does, used by the model to choose when and how to call the function. """ @@ -432,7 +420,7 @@ class FunctionObject(BaseModel): """ The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. 
""" - parameters: Optional[FunctionParameters] + parameters: Optional[FunctionParameters] = None strict: bool = False """ Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling). @@ -549,7 +537,7 @@ class CreateChatCompletionResponse(BaseModel): """ The object type, which is always `chat.completion`. """ - usage: Optional[CompletionUsage] + usage: Optional[CompletionUsage] = None class ChatCompletionStreamResponseDelta(BaseModel): @@ -564,17 +552,17 @@ class ChatCompletionStreamResponseDelta(BaseModel): """ The contents of the chunk message. """ - reasoning_content: Optional[str] + reasoning_content: Optional[str] = None """ The reasoning content of the chunk message. """ - reasoning: Optional[str] + reasoning: Optional[str] = None """ The reasoning of the chunk message. Same as reasoning_content. """ - tool_calls: Optional[Sequence[ChatCompletionMessageToolCallChunk]] + tool_calls: Optional[Sequence[ChatCompletionMessageToolCallChunk]] = None role: MessageRole - refusal: Optional[str] + refusal: Optional[str] = None """ The refusal message generated by the model. """ @@ -585,7 +573,7 @@ class ChatCompletionStreamChoice(BaseModel): populate_by_name=True, ) delta: ChatCompletionStreamResponseDelta - logprobs: Optional[Logprobs] + logprobs: Optional[Logprobs] = None """ Log probability information for the choice. """ @@ -624,7 +612,7 @@ class CreateChatCompletionStreamResponse(BaseModel): """ The model to generate the completion. """ - system_fingerprint: Optional[str] + system_fingerprint: Optional[str] = None """ This fingerprint represents the backend configuration that the model runs with. Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. @@ -634,8 +622,8 @@ class CreateChatCompletionStreamResponse(BaseModel): """ The object type, which is always `chat.completion.chunk`. """ - usage: Optional[CompletionUsage] - reasoning_format: Optional[str] + usage: Optional[CompletionUsage] = None + reasoning_format: Optional[str] = None """ The format of the reasoning content. Can be `raw` or `parsed`. When specified as raw some reasoning models will output tags. When specified as parsed the model will output the reasoning under reasoning_content. 
diff --git a/tests/test_client.py b/tests/test_client.py index 706bc8e..ce3fce6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -25,7 +25,7 @@ @pytest.fixture def client(): """Create a test client instance""" - return InferenceGatewayClient("http://test-api") + return InferenceGatewayClient("http://test-api/v1") @pytest.fixture @@ -55,7 +55,7 @@ def mock_response(): def test_params(): """Fixture providing test parameters""" return { - "api_url": "http://test-api", + "api_url": "http://test-api/v1", "provider": "openai", "model": "gpt-4", "message": Message(role="user", content="Hello"), @@ -64,11 +64,11 @@ def test_params(): def test_client_initialization(): """Test client initialization with and without token""" - client = InferenceGatewayClient("http://test-api") - assert client.base_url == "http://test-api" + client = InferenceGatewayClient("http://test-api/v1") + assert client.base_url == "http://test-api/v1" assert "Authorization" not in client.session.headers - client_with_token = InferenceGatewayClient("http://test-api", token="test-token") + client_with_token = InferenceGatewayClient("http://test-api/v1", token="test-token") assert "Authorization" in client_with_token.session.headers assert client_with_token.session.headers["Authorization"] == "Bearer test-token" @@ -83,7 +83,7 @@ def test_list_models(mock_request, client, mock_response): "GET", "http://test-api/v1/models", params={}, timeout=30.0 ) assert isinstance(response, ListModelsResponse) - assert response.provider == "openai" + assert response.provider.root == "openai" assert response.object == "list" assert len(response.data) == 1 assert response.data[0].id == "gpt-4" @@ -123,7 +123,7 @@ def test_list_models_with_provider(mock_request, client): "GET", "http://test-api/v1/models", params={"provider": "openai"}, timeout=30.0 ) assert isinstance(response, ListModelsResponse) - assert response.provider == "openai" + assert response.provider.root == "openai" assert response.object == "list" assert len(response.data) == 2 assert response.data[0].id == "gpt-4" @@ -191,25 +191,27 @@ def test_create_chat_completion(mock_request, client): @patch("requests.Session.request") -def test_health_check(mock_request, client): +def test_health_check(mock_request): """Test health check endpoint""" + health_client = InferenceGatewayClient("http://test-api") + mock_response = Mock() mock_response.status_code = 200 mock_response.raise_for_status.return_value = None mock_request.return_value = mock_response - assert client.health_check() is True + assert health_client.health_check() is True mock_request.assert_called_once_with("GET", "http://test-api/health", timeout=30.0) mock_response.status_code = 500 mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Server error") - assert client.health_check() is False + assert health_client.health_check() is False def test_message_model(): """Test Message model creation and serialization""" message = Message(role="user", content="Hello!") - assert message.role == "user" + assert message.role.root == "user" assert message.content == "Hello!" 
message_dict = message.model_dump() @@ -369,8 +371,9 @@ def test_create_chat_completion_stream_error(mock_request, client, test_params, @patch("requests.Session.request") -def test_proxy_request(mock_request, client): +def test_proxy_request(mock_request): """Test proxy request to provider""" + proxy_client = InferenceGatewayClient("http://test-api") mock_resp = Mock() mock_resp.status_code = 200 @@ -378,12 +381,15 @@ def test_proxy_request(mock_request, client): mock_resp.raise_for_status.return_value = None mock_request.return_value = mock_resp - response = client.proxy_request( + response = proxy_client.proxy_request( provider="openai", path="completions", method="POST", json_data={"prompt": "Hello"} ) mock_request.assert_called_once_with( - "POST", "http://test-api/proxy/openai/completions", json={"prompt": "Hello"}, timeout=30.0 + "POST", + "http://test-api/proxy/openai/completions", + json={"prompt": "Hello"}, + timeout=30.0, ) assert response == {"response": "test"} @@ -408,8 +414,8 @@ def test_exception_hierarchy(): def test_context_manager(): """Test client as context manager""" - with InferenceGatewayClient("http://test-api") as client: - assert client.base_url == "http://test-api" + with InferenceGatewayClient("http://test-api/v1") as client: + assert client.base_url == "http://test-api/v1" assert client.session is not None @@ -434,7 +440,7 @@ def test_client_with_custom_timeout(mock_request): } mock_request.return_value = mock_response - client = InferenceGatewayClient("http://test-api", timeout=30) + client = InferenceGatewayClient("http://test-api/v1", timeout=30) client.list_models() mock_request.assert_called_once_with("GET", "http://test-api/v1/models", params={}, timeout=30)
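For reference, the typed `tools` parameter introduced in `client.py` serializes back to the plain OpenAI-style dict the gateway expects via `model_dump(exclude_none=True)`. A quick sketch of that round trip, using a trimmed-down weather tool (field layout taken from the examples above):

```python
from inference_gateway.models import ChatCompletionTool, FunctionObject, FunctionParameters

weather_tool = ChatCompletionTool(
    type="function",
    function=FunctionObject(
        name="get_current_weather",
        description="Get the current weather in a given location",
        parameters=FunctionParameters(
            type="object",
            properties={"location": {"type": "string"}},
            required=["location"],
        ),
    ),
)

# Mirrors what create_chat_completion now does internally before sending the request:
payload = weather_tool.model_dump(exclude_none=True)
assert payload["type"] == "function"
assert payload["function"]["name"] == "get_current_weather"
```

Passing `ChatCompletionTool` instances therefore changes nothing on the wire; it only adds validation and IDE support on the client side.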