refactor: Enhance streaming chat completion with structured SSEvent handling and validation #3

Merged · 1 commit · May 26, 2025
44 changes: 30 additions & 14 deletions README.md
@@ -124,25 +124,41 @@ print(response.choices[0].message.content)
#### Streaming Completion

```python
-# Using Server-Sent Events (SSE)
+from inference_gateway.models import CreateChatCompletionStreamResponse
+from pydantic import ValidationError
+import json
+
+# Streaming returns SSEvent objects
 for chunk in client.create_chat_completion_stream(
     model="ollama/llama2",
     messages=[
         Message(role="user", content="Tell me a story")
     ],
-    use_sse=True
 ):
-    print(chunk.data, end="", flush=True)
-
-# Using JSON lines
-for chunk in client.create_chat_completion_stream(
-    model="anthropic/claude-3",
-    messages=[
-        Message(role="user", content="Explain AI safety")
-    ],
-    use_sse=False
-):
-    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
+    if chunk.data:
+        try:
+            # Parse the raw JSON data
+            data = json.loads(chunk.data)
+
+            # Unmarshal to structured model for type safety
+            try:
+                structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                # Use the structured model for better type safety and IDE support
+                if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                    choice = structured_chunk.choices[0]
+                    if hasattr(choice.delta, 'content') and choice.delta.content:
+                        print(choice.delta.content, end="", flush=True)
+
+            except ValidationError:
+                # Fallback to manual parsing for non-standard chunks
+                if "choices" in data and len(data["choices"]) > 0:
+                    delta = data["choices"][0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        print(delta["content"], end="", flush=True)
+
+        except json.JSONDecodeError:
+            pass
```
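Since the parse, validate, and fallback steps now repeat across the README and both examples, they could be factored into a small helper. A minimal sketch under the same imports; the `iter_content_deltas` helper is illustrative only, not part of the SDK:

```python
import json
from typing import Iterable, Iterator

from pydantic import ValidationError

from inference_gateway.models import CreateChatCompletionStreamResponse, SSEvent


def iter_content_deltas(stream: Iterable[SSEvent]) -> Iterator[str]:
    """Yield text deltas from SSEvent chunks, preferring validated models."""
    for chunk in stream:
        if not chunk.data:
            continue
        try:
            data = json.loads(chunk.data)
        except json.JSONDecodeError:
            continue  # skip non-JSON keep-alives
        try:
            parsed = CreateChatCompletionStreamResponse.model_validate(data)
            if parsed.choices:
                content = getattr(parsed.choices[0].delta, "content", None)
                if content:
                    yield content
        except ValidationError:
            # Fallback to manual parsing for non-standard chunks
            choices = data.get("choices") or []
            if choices:
                content = choices[0].get("delta", {}).get("content")
                if content:
                    yield content
```

Call sites would then reduce to `for text in iter_content_deltas(stream): print(text, end="", flush=True)`.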

### Proxy Requests
28 changes: 22 additions & 6 deletions examples/chat/README.md
@@ -45,7 +45,8 @@ print(response.choices[0].message.content)

```python
from inference_gateway import InferenceGatewayClient, Message
-from inference_gateway.models import SSEvent
+from inference_gateway.models import SSEvent, CreateChatCompletionStreamResponse
+from pydantic import ValidationError
import json

client = InferenceGatewayClient("http://localhost:8080/v1")
@@ -60,13 +61,28 @@
stream = client.create_chat_completion_stream(
)

 for chunk in stream:
-    if isinstance(chunk, SSEvent) and chunk.data:
+    if chunk.data:
         try:
             # Parse the raw JSON data
             data = json.loads(chunk.data)
-            if "choices" in data and len(data["choices"]) > 0:
-                delta = data["choices"][0].get("delta", {})
-                if "content" in delta and delta["content"]:
-                    print(delta["content"], end="", flush=True)
+
+            # Unmarshal to structured model for type safety
+            try:
+                structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                # Use the structured model for better type safety and IDE support
+                if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                    choice = structured_chunk.choices[0]
+                    if hasattr(choice.delta, 'content') and choice.delta.content:
+                        print(choice.delta.content, end="", flush=True)
+
+            except ValidationError:
+                # Fallback to manual parsing for non-standard chunks
+                if "choices" in data and len(data["choices"]) > 0:
+                    delta = data["choices"][0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        print(delta["content"], end="", flush=True)
+
         except json.JSONDecodeError:
             pass
```
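If the example should also return the complete reply once streaming finishes, the deltas can be accumulated as they arrive. A small sketch extending the loop above (`stream` and the imports are the same as in the example):

```python
# Accumulate streamed deltas into the full reply while printing live.
parts = []

for chunk in stream:
    if not chunk.data:
        continue
    try:
        data = json.loads(chunk.data)
    except json.JSONDecodeError:
        continue
    for choice in data.get("choices", []):
        content = choice.get("delta", {}).get("content")
        if content:
            parts.append(content)
            print(content, end="", flush=True)

full_reply = "".join(parts)
```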
41 changes: 28 additions & 13 deletions examples/chat/main.py
@@ -1,9 +1,11 @@
import json
import os

+from pydantic import ValidationError

from inference_gateway import InferenceGatewayClient, Message
from inference_gateway.client import InferenceGatewayAPIError, InferenceGatewayError
-from inference_gateway.models import SSEvent
+from inference_gateway.models import CreateChatCompletionStreamResponse, SSEvent


def main() -> None:
@@ -57,23 +59,36 @@ def main() -> None:
)

     for chunk in stream:
-        if isinstance(chunk, SSEvent):
-            # Handle Server-Sent Events format
-            if chunk.data:
-                try:
-                    data = json.loads(chunk.data)
+        # All chunks are now SSEvent objects
+        if chunk.data:
+            try:
+                # Parse the raw JSON data
+                data = json.loads(chunk.data)
+
+                # Try to unmarshal to structured model for type safety
+                try:
+                    structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                    # Use the structured model for better type safety and IDE support
+                    if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                        choice = structured_chunk.choices[0]
+                        if hasattr(choice.delta, "content") and choice.delta.content:
+                            print(choice.delta.content, end="", flush=True)
+
+                        # Optionally show other information
+                        if choice.finish_reason and choice.finish_reason != "null":
+                            print(f"\n[Finished: {choice.finish_reason}]")
+
+                except ValidationError:
+                    # Fallback to manual parsing for non-standard chunks
                     if "choices" in data and len(data["choices"]) > 0:
                         delta = data["choices"][0].get("delta", {})
                         if "content" in delta and delta["content"]:
                             print(delta["content"], end="", flush=True)
-                except json.JSONDecodeError:
-                    pass
-        elif isinstance(chunk, dict):
-            # Handle JSON format
-            if "choices" in chunk and len(chunk["choices"]) > 0:
-                delta = chunk["choices"][0].get("delta", {})
-                if "content" in delta and delta["content"]:
-                    print(delta["content"], end="", flush=True)
+
+            except json.JSONDecodeError:
+                # Handle non-JSON SSE data
+                print(f"[Non-JSON chunk: {chunk.data}]", end="", flush=True)

     print("\n")

111 changes: 37 additions & 74 deletions inference_gateway/client.py
@@ -206,37 +206,6 @@ def list_tools(self) -> ListToolsResponse:
except ValidationError as e:
raise InferenceGatewayValidationError(f"Response validation failed: {e}")

-    def _parse_sse_chunk(self, chunk: bytes) -> SSEvent:
-        """Parse an SSE message chunk into structured event data.
-
-        Args:
-            chunk: Raw SSE message chunk in bytes format
-
-        Returns:
-            SSEvent: Parsed SSE message with event type and data fields
-
-        Raises:
-            InferenceGatewayValidationError: If chunk format or content is invalid
-        """
-        if not isinstance(chunk, bytes):
-            raise TypeError(f"Expected bytes, got {type(chunk)}")
-
-        try:
-            decoded = chunk.decode("utf-8")
-            event_type = None
-            data = None
-
-            for line in (l.strip() for l in decoded.split("\n") if l.strip()):
-                if line.startswith("event:"):
-                    event_type = line.removeprefix("event:").strip()
-                elif line.startswith("data:"):
-                    data = line.removeprefix("data:").strip()
-
-            return SSEvent(event=event_type, data=data, retry=None)
-
-        except UnicodeDecodeError as e:
-            raise InferenceGatewayValidationError(f"Invalid UTF-8 encoding in SSE chunk: {chunk!r}")

def _parse_json_line(self, line: bytes) -> Dict[str, Any]:
"""Parse a single JSON line into a dictionary.

@@ -325,9 +294,8 @@ def create_chat_completion_stream(
provider: Optional[Union[Provider, str]] = None,
max_tokens: Optional[int] = None,
tools: Optional[List[ChatCompletionTool]] = None,
-use_sse: bool = True,
**kwargs: Any,
-) -> Generator[Union[Dict[str, Any], SSEvent], None, None]:
+) -> Generator[SSEvent, None, None]:
"""Stream a chat completion.

Args:
Expand All @@ -336,11 +304,10 @@ def create_chat_completion_stream(
provider: Optional provider specification
max_tokens: Maximum number of tokens to generate
tools: List of tools the model may call (using ChatCompletionTool models)
-use_sse: Whether to use Server-Sent Events format
**kwargs: Additional parameters to pass to the API

Yields:
-Union[Dict[str, Any], SSEvent]: Stream chunks
+SSEvent: Stream chunks in SSEvent format

Raises:
InferenceGatewayAPIError: If the API request fails
@@ -377,7 +344,7 @@ def create_chat_completion_stream(
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise InferenceGatewayAPIError(f"Request failed: {str(e)}")
-yield from self._process_stream_response(response, use_sse)
+yield from self._process_stream_response(response)
else:
requests_response = self.session.post(
url, params=params, json=request.model_dump(exclude_none=True), stream=True
@@ -386,49 +353,45 @@
requests_response.raise_for_status()
except (requests.exceptions.HTTPError, Exception) as e:
raise InferenceGatewayAPIError(f"Request failed: {str(e)}")
-yield from self._process_stream_response(requests_response, use_sse)
+yield from self._process_stream_response(requests_response)

except ValidationError as e:
raise InferenceGatewayValidationError(f"Request validation failed: {e}")

     def _process_stream_response(
-        self, response: Union[requests.Response, httpx.Response], use_sse: bool
-    ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]:
-        """Process streaming response data."""
-        if use_sse:
-            buffer: List[bytes] = []
-
-            for line in response.iter_lines():
-                if not line:
-                    if buffer:
-                        chunk = b"\n".join(buffer)
-                        yield self._parse_sse_chunk(chunk)
-                        buffer = []
-                    continue
-
-                if isinstance(line, str):
-                    line_bytes = line.encode("utf-8")
-                else:
-                    line_bytes = line
-                buffer.append(line_bytes)
-        else:
-            for line in response.iter_lines():
-                if not line:
-                    continue
-
-                if isinstance(line, str):
-                    line_bytes = line.encode("utf-8")
-                else:
-                    line_bytes = line
-
-                if line_bytes.strip() == b"data: [DONE]":
-                    continue
-                if line_bytes.startswith(b"data: "):
-                    json_str = line_bytes[6:].decode("utf-8")
-                    data = json.loads(json_str)
-                    yield data
-                else:
-                    yield self._parse_json_line(line_bytes)
+        self, response: Union[requests.Response, httpx.Response]
+    ) -> Generator[SSEvent, None, None]:
+        """Process streaming response data in SSEvent format."""
+        current_event = None
+
+        for line in response.iter_lines():
+            if not line:
+                continue
+
+            if isinstance(line, str):
+                line_bytes = line.encode("utf-8")
+            else:
+                line_bytes = line
+
+            if line_bytes.strip() == b"data: [DONE]":
+                continue
+
+            if line_bytes.startswith(b"event: "):
+                current_event = line_bytes[7:].decode("utf-8").strip()
+                continue
+            elif line_bytes.startswith(b"data: "):
+                json_str = line_bytes[6:].decode("utf-8")
+                event_type = current_event if current_event else "content-delta"
+                yield SSEvent(event=event_type, data=json_str)
+                current_event = None
+            elif line_bytes.strip() == b"":
+                continue
+            else:
+                try:
+                    parsed_data = self._parse_json_line(line_bytes)
+                    yield SSEvent(event="content-delta", data=json.dumps(parsed_data))
+                except Exception:
+                    yield SSEvent(event="content-delta", data=line_bytes.decode("utf-8"))

def proxy_request(
self,
2 changes: 1 addition & 1 deletion inference_gateway/models.py

Some generated files are not rendered by default.
