From 45de736729032e5035b44316662f7ccb508fb32d Mon Sep 17 00:00:00 2001
From: Eden Reich
Date: Mon, 26 May 2025 23:55:56 +0000
Subject: [PATCH] refactor: Enhance streaming chat completion with structured
 SSEvent handling and validation

Simplify the code: Server-Sent Events is the default streaming format for
any provider that follows the OpenAI-compatible API, which is why I removed
the use_sse flag. In earlier versions of Ollama there was no OpenAI-compatible
API and the response was raw JSON lines instead of SSE messages prefixed
with "data:".

Signed-off-by: Eden Reich
---
 README.md                   |  44 +++++++++----
 examples/chat/README.md     |  28 ++++++--
 examples/chat/main.py       |  41 ++++++++----
 inference_gateway/client.py | 111 +++++++++++---------------------
 inference_gateway/models.py |   2 +-
 tests/test_client.py        | 125 ++++++++++++++++++++++--------------
 6 files changed, 196 insertions(+), 155 deletions(-)

diff --git a/README.md b/README.md
index 7645e81..bf77435 100644
--- a/README.md
+++ b/README.md
@@ -124,25 +124,41 @@ print(response.choices[0].message.content)
 #### Streaming Completion
 
 ```python
-# Using Server-Sent Events (SSE)
+from inference_gateway.models import CreateChatCompletionStreamResponse
+from pydantic import ValidationError
+import json
+
+# Streaming returns SSEvent objects
 for chunk in client.create_chat_completion_stream(
     model="ollama/llama2",
     messages=[
         Message(role="user", content="Tell me a story")
-    ],
-    use_sse=True
-):
-    print(chunk.data, end="", flush=True)
-
-# Using JSON lines
-for chunk in client.create_chat_completion_stream(
-    model="anthropic/claude-3",
-    messages=[
-        Message(role="user", content="Explain AI safety")
-    ],
-    use_sse=False
+    ]
 ):
-    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
+    if chunk.data:
+        try:
+            # Parse the raw JSON data
+            data = json.loads(chunk.data)
+
+            # Unmarshal to structured model for type safety
+            try:
+                structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                # Use the structured model for better type safety and IDE support
+                if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                    choice = structured_chunk.choices[0]
+                    if hasattr(choice.delta, 'content') and choice.delta.content:
+                        print(choice.delta.content, end="", flush=True)
+
+            except ValidationError:
+                # Fallback to manual parsing for non-standard chunks
+                if "choices" in data and len(data["choices"]) > 0:
+                    delta = data["choices"][0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        print(delta["content"], end="", flush=True)
+
+        except json.JSONDecodeError:
+            pass
 ```
 
 ### Proxy Requests
diff --git a/examples/chat/README.md b/examples/chat/README.md
index fa0afcc..28bebc7 100644
--- a/examples/chat/README.md
+++ b/examples/chat/README.md
@@ -45,7 +45,8 @@ print(response.choices[0].message.content)
 
 ```python
 from inference_gateway import InferenceGatewayClient, Message
-from inference_gateway.models import SSEvent
+from inference_gateway.models import SSEvent, CreateChatCompletionStreamResponse
+from pydantic import ValidationError
 import json
 
 client = InferenceGatewayClient("http://localhost:8080/v1")
@@ -60,13 +61,28 @@ stream = client.create_chat_completion_stream(
 )
 
 for chunk in stream:
-    if isinstance(chunk, SSEvent) and chunk.data:
+    if chunk.data:
         try:
+            # Parse the raw JSON data
             data = json.loads(chunk.data)
-            if "choices" in data and len(data["choices"]) > 0:
-                delta = data["choices"][0].get("delta", {})
-                if "content" in delta and delta["content"]:
-                    print(delta["content"], end="", flush=True)
+
+            # Unmarshal to structured model for type safety
+            try:
+                structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                # Use the structured model for better type safety and IDE support
+                if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                    choice = structured_chunk.choices[0]
+                    if hasattr(choice.delta, 'content') and choice.delta.content:
+                        print(choice.delta.content, end="", flush=True)
+
+            except ValidationError:
+                # Fallback to manual parsing for non-standard chunks
+                if "choices" in data and len(data["choices"]) > 0:
+                    delta = data["choices"][0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        print(delta["content"], end="", flush=True)
+
         except json.JSONDecodeError:
             pass
 ```
diff --git a/examples/chat/main.py b/examples/chat/main.py
index e6a33a7..604f404 100644
--- a/examples/chat/main.py
+++ b/examples/chat/main.py
@@ -1,9 +1,11 @@
 import json
 import os
 
+from pydantic import ValidationError
+
 from inference_gateway import InferenceGatewayClient, Message
 from inference_gateway.client import InferenceGatewayAPIError, InferenceGatewayError
-from inference_gateway.models import SSEvent
+from inference_gateway.models import CreateChatCompletionStreamResponse, SSEvent
 
 
 def main() -> None:
@@ -57,23 +59,36 @@ def main() -> None:
         )
 
         for chunk in stream:
-            if isinstance(chunk, SSEvent):
-                # Handle Server-Sent Events format
-                if chunk.data:
+            # All chunks are now SSEvent objects
+            if chunk.data:
+                try:
+                    # Parse the raw JSON data
+                    data = json.loads(chunk.data)
+
+                    # Try to unmarshal to structured model for type safety
                     try:
-                        data = json.loads(chunk.data)
+                        structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                        # Use the structured model for better type safety and IDE support
+                        if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                            choice = structured_chunk.choices[0]
+                            if hasattr(choice.delta, "content") and choice.delta.content:
+                                print(choice.delta.content, end="", flush=True)
+
+                            # Optionally show other information
+                            if choice.finish_reason and choice.finish_reason != "null":
+                                print(f"\n[Finished: {choice.finish_reason}]")
+
+                    except ValidationError:
+                        # Fallback to manual parsing for non-standard chunks
                         if "choices" in data and len(data["choices"]) > 0:
                             delta = data["choices"][0].get("delta", {})
                             if "content" in delta and delta["content"]:
                                 print(delta["content"], end="", flush=True)
-                    except json.JSONDecodeError:
-                        pass
-            elif isinstance(chunk, dict):
-                # Handle JSON format
-                if "choices" in chunk and len(chunk["choices"]) > 0:
-                    delta = chunk["choices"][0].get("delta", {})
-                    if "content" in delta and delta["content"]:
-                        print(delta["content"], end="", flush=True)
+
+                except json.JSONDecodeError:
+                    # Handle non-JSON SSE data
+                    print(f"[Non-JSON chunk: {chunk.data}]", end="", flush=True)
 
         print("\n")
 
diff --git a/inference_gateway/client.py b/inference_gateway/client.py
index df3cf88..5170f6a 100644
--- a/inference_gateway/client.py
+++ b/inference_gateway/client.py
@@ -206,37 +206,6 @@ def list_tools(self) -> ListToolsResponse:
         except ValidationError as e:
             raise InferenceGatewayValidationError(f"Response validation failed: {e}")
 
-    def _parse_sse_chunk(self, chunk: bytes) -> SSEvent:
-        """Parse an SSE message chunk into structured event data.
-
-        Args:
-            chunk: Raw SSE message chunk in bytes format
-
-        Returns:
-            SSEvent: Parsed SSE message with event type and data fields
-
-        Raises:
-            InferenceGatewayValidationError: If chunk format or content is invalid
-        """
-        if not isinstance(chunk, bytes):
-            raise TypeError(f"Expected bytes, got {type(chunk)}")
-
-        try:
-            decoded = chunk.decode("utf-8")
-            event_type = None
-            data = None
-
-            for line in (l.strip() for l in decoded.split("\n") if l.strip()):
-                if line.startswith("event:"):
-                    event_type = line.removeprefix("event:").strip()
-                elif line.startswith("data:"):
-                    data = line.removeprefix("data:").strip()
-
-            return SSEvent(event=event_type, data=data, retry=None)
-
-        except UnicodeDecodeError as e:
-            raise InferenceGatewayValidationError(f"Invalid UTF-8 encoding in SSE chunk: {chunk!r}")
-
     def _parse_json_line(self, line: bytes) -> Dict[str, Any]:
         """Parse a single JSON line into a dictionary.
 
@@ -325,9 +294,8 @@ def create_chat_completion_stream(
         provider: Optional[Union[Provider, str]] = None,
         max_tokens: Optional[int] = None,
         tools: Optional[List[ChatCompletionTool]] = None,
-        use_sse: bool = True,
         **kwargs: Any,
-    ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]:
+    ) -> Generator[SSEvent, None, None]:
         """Stream a chat completion.
 
         Args:
@@ -336,11 +304,10 @@ def create_chat_completion_stream(
             provider: Optional provider specification
             max_tokens: Maximum number of tokens to generate
             tools: List of tools the model may call (using ChatCompletionTool models)
-            use_sse: Whether to use Server-Sent Events format
             **kwargs: Additional parameters to pass to the API
 
         Yields:
-            Union[Dict[str, Any], SSEvent]: Stream chunks
+            SSEvent: Stream chunks in SSEvent format
 
         Raises:
             InferenceGatewayAPIError: If the API request fails
@@ -377,7 +344,7 @@ def create_chat_completion_stream(
                         response.raise_for_status()
                     except httpx.HTTPStatusError as e:
                         raise InferenceGatewayAPIError(f"Request failed: {str(e)}")
-                    yield from self._process_stream_response(response, use_sse)
+                    yield from self._process_stream_response(response)
             else:
                 requests_response = self.session.post(
                     url, params=params, json=request.model_dump(exclude_none=True), stream=True
@@ -386,49 +353,45 @@ def create_chat_completion_stream(
                     requests_response.raise_for_status()
                 except (requests.exceptions.HTTPError, Exception) as e:
                     raise InferenceGatewayAPIError(f"Request failed: {str(e)}")
-                yield from self._process_stream_response(requests_response, use_sse)
+                yield from self._process_stream_response(requests_response)
 
         except ValidationError as e:
             raise InferenceGatewayValidationError(f"Request validation failed: {e}")
 
     def _process_stream_response(
-        self, response: Union[requests.Response, httpx.Response], use_sse: bool
-    ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]:
-        """Process streaming response data."""
-        if use_sse:
-            buffer: List[bytes] = []
-
-            for line in response.iter_lines():
-                if not line:
-                    if buffer:
-                        chunk = b"\n".join(buffer)
-                        yield self._parse_sse_chunk(chunk)
-                        buffer = []
-                    continue
-
-                if isinstance(line, str):
-                    line_bytes = line.encode("utf-8")
-                else:
-                    line_bytes = line
-                buffer.append(line_bytes)
-        else:
-            for line in response.iter_lines():
-                if not line:
-                    continue
-
-                if isinstance(line, str):
-                    line_bytes = line.encode("utf-8")
-                else:
-                    line_bytes = line
-
-                if line_bytes.strip() == b"data: [DONE]":
-                    continue
-                if line_bytes.startswith(b"data: "):
-                    json_str = line_bytes[6:].decode("utf-8")
-                    data = json.loads(json_str)
-                    yield data
-                else:
-                    yield self._parse_json_line(line_bytes)
+        self, response: Union[requests.Response, httpx.Response]
+    ) -> Generator[SSEvent, None, None]:
+        """Process streaming response data in SSEvent format."""
+        current_event = None
+
+        for line in response.iter_lines():
+            if not line:
+                continue
+
+            if isinstance(line, str):
+                line_bytes = line.encode("utf-8")
+            else:
+                line_bytes = line
+
+            if line_bytes.strip() == b"data: [DONE]":
+                continue
+
+            if line_bytes.startswith(b"event: "):
+                current_event = line_bytes[7:].decode("utf-8").strip()
+                continue
+            elif line_bytes.startswith(b"data: "):
+                json_str = line_bytes[6:].decode("utf-8")
+                event_type = current_event if current_event else "content-delta"
+                yield SSEvent(event=event_type, data=json_str)
+                current_event = None
+            elif line_bytes.strip() == b"":
+                continue
+            else:
+                try:
+                    parsed_data = self._parse_json_line(line_bytes)
+                    yield SSEvent(event="content-delta", data=json.dumps(parsed_data))
+                except Exception:
+                    yield SSEvent(event="content-delta", data=line_bytes.decode("utf-8"))
 
     def proxy_request(
         self,
diff --git a/inference_gateway/models.py b/inference_gateway/models.py
index 11fa56a..089542e 100644
--- a/inference_gateway/models.py
+++ b/inference_gateway/models.py
@@ -1,6 +1,6 @@
 # generated by datamodel-codegen:
 #   filename:  openapi.yaml
-#   timestamp: 2025-05-26T19:01:03+00:00
+#   timestamp: 2025-05-26T23:49:03+00:00
 
 from __future__ import annotations
 
diff --git a/tests/test_client.py b/tests/test_client.py
index ce3fce6..de8e137 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -243,44 +243,33 @@ def test_message_role_values():
         MessageRole("invalid_role")
 
 
-@pytest.mark.parametrize("use_sse,expected_format", [(True, "sse"), (False, "json")])
 @patch("requests.Session.request")
-def test_create_chat_completion_stream(mock_request, client, use_sse, expected_format):
-    """Test streaming chat completion with both raw JSON and SSE formats"""
+def test_create_chat_completion_stream(mock_request, client):
+    """Test streaming chat completion in SSEvent format"""
     mock_response = Mock()
     mock_response.status_code = 200
     mock_response.raise_for_status.return_value = None
 
-    if use_sse:
-        mock_response.iter_lines.return_value = [
-            b"event: message-start",
-            b'data: {"role":"assistant"}',
-            b"",
-            b"event: content-delta",
-            b'data: {"content":"Hello"}',
-            b"",
-            b"event: content-delta",
-            b'data: {"content":" world!"}',
-            b"",
-            b"event: message-end",
-            b'data: {"content":""}',
-            b"",
-        ]
-    else:
-        mock_response.iter_lines.return_value = [
-            b'data: {"choices":[{"delta":{"role":"assistant"}}],"model":"gpt-4"}',
-            b'data: {"choices":[{"delta":{"content":"Hello"}}],"model":"gpt-4"}',
-            b'data: {"choices":[{"delta":{"content":" world!"}}],"model":"gpt-4"}',
-            b"data: [DONE]",
-        ]
+    mock_response.iter_lines.return_value = [
+        b"event: message-start",
+        b'data: {"role":"assistant"}',
+        b"",
+        b"event: content-delta",
+        b'data: {"content":"Hello"}',
+        b"",
+        b"event: content-delta",
+        b'data: {"content":" world!"}',
+        b"",
+        b"event: message-end",
+        b'data: {"content":""}',
+        b"",
+    ]
 
     mock_request.return_value = mock_response
 
     messages = [Message(role="user", content="What's up?")]
     chunks = list(
-        client.create_chat_completion_stream(
-            model="gpt-4", messages=messages, provider="openai", use_sse=use_sse
-        )
+        client.create_chat_completion_stream(model="gpt-4", messages=messages, provider="openai")
     )
 
     mock_request.assert_called_once_with(
@@ -296,23 +285,67 @@ def test_create_chat_completion_stream(mock_request, client, use_sse, expected_f
         stream=True,
     )
 
-    if expected_format == "sse":
-        assert len(chunks) == 4
-        assert chunks[0].event == "message-start"
-        assert chunks[0].data == '{"role":"assistant"}'
-        assert chunks[1].event == "content-delta"
-        assert chunks[1].data == '{"content":"Hello"}'
-        assert chunks[2].event == "content-delta"
-        assert chunks[2].data == '{"content":" world!"}'
-        assert chunks[3].event == "message-end"
-        assert chunks[3].data == '{"content":""}'
-    else:
-        assert len(chunks) == 3
-        assert "choices" in chunks[0]
-        assert "delta" in chunks[0]["choices"][0]
-        assert chunks[0]["choices"][0]["delta"]["role"] == "assistant"
-        assert chunks[1]["choices"][0]["delta"]["content"] == "Hello"
-        assert chunks[2]["choices"][0]["delta"]["content"] == " world!"
+    assert len(chunks) == 4
+    assert chunks[0].event == "message-start"
+    assert chunks[0].data == '{"role":"assistant"}'
+    assert chunks[1].event == "content-delta"
+    assert chunks[1].data == '{"content":"Hello"}'
+    assert chunks[2].event == "content-delta"
+    assert chunks[2].data == '{"content":" world!"}'
+    assert chunks[3].event == "message-end"
+    assert chunks[3].data == '{"content":""}'
+
+
+@patch("requests.Session.request")
+def test_create_chat_completion_stream_openai_format(mock_request, client):
+    """Test streaming chat completion with OpenAI-compatible data format"""
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.raise_for_status.return_value = None
+
+    mock_response.iter_lines.return_value = [
+        b'data: {"choices":[{"delta":{"role":"assistant"}}],"model":"gpt-4"}',
+        b'data: {"choices":[{"delta":{"content":"Hello"}}],"model":"gpt-4"}',
+        b'data: {"choices":[{"delta":{"content":" world!"}}],"model":"gpt-4"}',
+        b"data: [DONE]",
+    ]
+
+    mock_request.return_value = mock_response
+
+    messages = [Message(role="user", content="What's up?")]
+    chunks = list(
+        client.create_chat_completion_stream(model="gpt-4", messages=messages, provider="openai")
+    )
+
+    mock_request.assert_called_once_with(
+        "POST",
+        "http://test-api/v1/chat/completions",
+        data=None,
+        json={
+            "model": "gpt-4",
+            "messages": [{"role": "user", "content": "What's up?"}],
+            "stream": True,
+        },
+        params={"provider": "openai"},
+        stream=True,
+    )
+
+    assert len(chunks) == 3
+
+    for chunk in chunks:
+        assert isinstance(chunk, SSEvent)
+        assert chunk.event == "content-delta"
+
+    import json
+
+    chunk_0_data = json.loads(chunks[0].data)
+    assert chunk_0_data["choices"][0]["delta"]["role"] == "assistant"
+
+    chunk_1_data = json.loads(chunks[1].data)
+    assert chunk_1_data["choices"][0]["delta"]["content"] == "Hello"
+
+    chunk_2_data = json.loads(chunks[2].data)
+    assert chunk_2_data["choices"][0]["delta"]["content"] == " world!"
 
 
 @pytest.mark.parametrize(
@@ -344,7 +377,6 @@ def test_create_chat_completion_stream_error(mock_request, client, test_params,
     mock_response.iter_lines.return_value = error_scenario["iter_lines"]
     mock_request.return_value = mock_response
 
-    use_sse = error_scenario.get("use_sse", False)
 
     with pytest.raises(InferenceGatewayError, match=error_scenario["expected_match"]):
         list(
@@ -352,7 +384,6 @@ def test_create_chat_completion_stream_error(mock_request, client, test_params,
                 model=test_params["model"],
                 messages=[test_params["message"]],
                 provider=test_params["provider"],
-                use_sse=use_sse,
             )
         )
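
For reviewers who want to try the new contract end to end, here is a minimal consumption sketch (not part of the diff). The gateway URL and model name are illustrative placeholders; the surface it relies on (InferenceGatewayClient, Message, create_chat_completion_stream, SSEvent.data) is exactly what the README changes above document.

```python
"""Sketch: collect a full reply from the SSEvent-only stream (illustrative URL/model)."""
import json

from inference_gateway import InferenceGatewayClient, Message

client = InferenceGatewayClient("http://localhost:8080/v1")

reply_parts = []  # accumulate streamed content deltas here

for chunk in client.create_chat_completion_stream(
    model="ollama/llama2",
    messages=[Message(role="user", content="Tell me a story")],
):
    if not chunk.data:  # events without a data payload
        continue
    try:
        data = json.loads(chunk.data)
    except json.JSONDecodeError:
        continue  # skip non-JSON payloads instead of failing the whole stream
    for choice in data.get("choices", []):
        content = choice.get("delta", {}).get("content")
        if content:
            reply_parts.append(content)

print("".join(reply_parts))
```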