
Commit cfcf147

refactor: Enhance streaming chat completion with structured SSEvent handling and validation (#3)

Simplify the code: Server-Sent Events is the default streaming format for any provider that follows the OpenAI API, which is why I removed the use_sse flag. In previous versions of Ollama there was no OpenAI-compatible API and the response was raw JSON instead of data: <JSON>.

Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent 3248626 commit cfcf147

File tree

6 files changed: +196 -155 lines changed
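To illustrate the distinction the commit message draws, here is a minimal sketch of the two wire formats; the payload fields are illustrative, not captured from a real server:

```python
# OpenAI-compatible SSE stream: each chunk arrives prefixed with "data: ",
# terminated by a "data: [DONE]" sentinel.
sse_stream = [
    'data: {"choices":[{"delta":{"content":"Hel"}}]}',
    'data: {"choices":[{"delta":{"content":"lo"}}]}',
    "data: [DONE]",
]

# Legacy raw-JSON streaming (pre OpenAI-compatible Ollama): one bare JSON
# object per line, with no "data: " prefix and no sentinel.
raw_json_stream = [
    '{"choices":[{"delta":{"content":"Hel"}}]}',
    '{"choices":[{"delta":{"content":"lo"}}]}',
]
```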

README.md

Lines changed: 30 additions & 14 deletions
@@ -124,25 +124,41 @@ print(response.choices[0].message.content)
 #### Streaming Completion
 
 ```python
-# Using Server-Sent Events (SSE)
+from inference_gateway.models import CreateChatCompletionStreamResponse
+from pydantic import ValidationError
+import json
+
+# Streaming returns SSEvent objects
 for chunk in client.create_chat_completion_stream(
     model="ollama/llama2",
     messages=[
         Message(role="user", content="Tell me a story")
-    ],
-    use_sse=True
-):
-    print(chunk.data, end="", flush=True)
-
-# Using JSON lines
-for chunk in client.create_chat_completion_stream(
-    model="anthropic/claude-3",
-    messages=[
-        Message(role="user", content="Explain AI safety")
-    ],
-    use_sse=False
+    ]
 ):
-    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
+    if chunk.data:
+        try:
+            # Parse the raw JSON data
+            data = json.loads(chunk.data)
+
+            # Unmarshal to structured model for type safety
+            try:
+                structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                # Use the structured model for better type safety and IDE support
+                if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                    choice = structured_chunk.choices[0]
+                    if hasattr(choice.delta, 'content') and choice.delta.content:
+                        print(choice.delta.content, end="", flush=True)
+
+            except ValidationError:
+                # Fallback to manual parsing for non-standard chunks
+                if "choices" in data and len(data["choices"]) > 0:
+                    delta = data["choices"][0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        print(delta["content"], end="", flush=True)
+
+        except json.JSONDecodeError:
+            pass
 ```
 
 ### Proxy Requests
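To make the README's two parse paths concrete, a hedged sketch follows; the payloads are invented, and whether a given chunk passes model_validate depends on the model's required fields:

```python
import json

from pydantic import ValidationError

from inference_gateway.models import CreateChatCompletionStreamResponse

# Invented payloads: the first mimics a full OpenAI-style chunk, the second
# is deliberately sparse so it may only be usable via the dict fallback.
samples = [
    '{"id":"chatcmpl-1","object":"chat.completion.chunk","created":0,'
    '"model":"ollama/llama2","choices":[{"index":0,"delta":{"content":"Hi"}}]}',
    '{"choices":[{"delta":{"content":"Hi (fallback)"}}]}',
]

for raw in samples:
    data = json.loads(raw)
    try:
        chunk = CreateChatCompletionStreamResponse.model_validate(data)
        print("structured:", chunk.choices[0].delta.content)
    except ValidationError:
        print("fallback:", data["choices"][0]["delta"]["content"])
```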

examples/chat/README.md

Lines changed: 22 additions & 6 deletions
@@ -45,7 +45,8 @@ print(response.choices[0].message.content)
 
 ```python
 from inference_gateway import InferenceGatewayClient, Message
-from inference_gateway.models import SSEvent
+from inference_gateway.models import SSEvent, CreateChatCompletionStreamResponse
+from pydantic import ValidationError
 import json
 
 client = InferenceGatewayClient("http://localhost:8080/v1")
@@ -60,13 +61,28 @@ stream = client.create_chat_completion_stream(
 )
 
 for chunk in stream:
-    if isinstance(chunk, SSEvent) and chunk.data:
+    if chunk.data:
         try:
+            # Parse the raw JSON data
             data = json.loads(chunk.data)
-            if "choices" in data and len(data["choices"]) > 0:
-                delta = data["choices"][0].get("delta", {})
-                if "content" in delta and delta["content"]:
-                    print(delta["content"], end="", flush=True)
+
+            # Unmarshal to structured model for type safety
+            try:
+                structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                # Use the structured model for better type safety and IDE support
+                if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                    choice = structured_chunk.choices[0]
+                    if hasattr(choice.delta, 'content') and choice.delta.content:
+                        print(choice.delta.content, end="", flush=True)
+
+            except ValidationError:
+                # Fallback to manual parsing for non-standard chunks
+                if "choices" in data and len(data["choices"]) > 0:
+                    delta = data["choices"][0].get("delta", {})
+                    if "content" in delta and delta["content"]:
+                        print(delta["content"], end="", flush=True)
+
         except json.JSONDecodeError:
             pass
 ```

examples/chat/main.py

Lines changed: 28 additions & 13 deletions
@@ -1,9 +1,11 @@
 import json
 import os
 
+from pydantic import ValidationError
+
 from inference_gateway import InferenceGatewayClient, Message
 from inference_gateway.client import InferenceGatewayAPIError, InferenceGatewayError
-from inference_gateway.models import SSEvent
+from inference_gateway.models import CreateChatCompletionStreamResponse, SSEvent
 
 
 def main() -> None:
@@ -57,23 +59,36 @@ def main() -> None:
     )
 
     for chunk in stream:
-        if isinstance(chunk, SSEvent):
-            # Handle Server-Sent Events format
-            if chunk.data:
+        # All chunks are now SSEvent objects
+        if chunk.data:
+            try:
+                # Parse the raw JSON data
+                data = json.loads(chunk.data)
+
+                # Try to unmarshal to structured model for type safety
                 try:
-                    data = json.loads(chunk.data)
+                    structured_chunk = CreateChatCompletionStreamResponse.model_validate(data)
+
+                    # Use the structured model for better type safety and IDE support
+                    if structured_chunk.choices and len(structured_chunk.choices) > 0:
+                        choice = structured_chunk.choices[0]
+                        if hasattr(choice.delta, "content") and choice.delta.content:
+                            print(choice.delta.content, end="", flush=True)
+
+                        # Optionally show other information
+                        if choice.finish_reason and choice.finish_reason != "null":
+                            print(f"\n[Finished: {choice.finish_reason}]")
+
+                except ValidationError:
+                    # Fallback to manual parsing for non-standard chunks
                     if "choices" in data and len(data["choices"]) > 0:
                         delta = data["choices"][0].get("delta", {})
                         if "content" in delta and delta["content"]:
                             print(delta["content"], end="", flush=True)
-                except json.JSONDecodeError:
-                    pass
-        elif isinstance(chunk, dict):
-            # Handle JSON format
-            if "choices" in chunk and len(chunk["choices"]) > 0:
-                delta = chunk["choices"][0].get("delta", {})
-                if "content" in delta and delta["content"]:
-                    print(delta["content"], end="", flush=True)
+
+            except json.JSONDecodeError:
+                # Handle non-JSON SSE data
+                print(f"[Non-JSON chunk: {chunk.data}]", end="", flush=True)
 
     print("\n")
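For intuition on the example's new except json.JSONDecodeError branch, a tiny sketch with an invented non-JSON payload (SSE comment/keep-alive lines are one way such data can appear):

```python
import json

from inference_gateway.models import SSEvent

# Invented chunk whose data field is not JSON
chunk = SSEvent(event="content-delta", data=": keep-alive")

try:
    json.loads(chunk.data)
except json.JSONDecodeError:
    # Mirrors the example's fallback: surface the raw data rather than drop it
    print(f"[Non-JSON chunk: {chunk.data}]", end="", flush=True)
```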

inference_gateway/client.py

Lines changed: 37 additions & 74 deletions
@@ -206,37 +206,6 @@ def list_tools(self) -> ListToolsResponse:
         except ValidationError as e:
             raise InferenceGatewayValidationError(f"Response validation failed: {e}")
 
-    def _parse_sse_chunk(self, chunk: bytes) -> SSEvent:
-        """Parse an SSE message chunk into structured event data.
-
-        Args:
-            chunk: Raw SSE message chunk in bytes format
-
-        Returns:
-            SSEvent: Parsed SSE message with event type and data fields
-
-        Raises:
-            InferenceGatewayValidationError: If chunk format or content is invalid
-        """
-        if not isinstance(chunk, bytes):
-            raise TypeError(f"Expected bytes, got {type(chunk)}")
-
-        try:
-            decoded = chunk.decode("utf-8")
-            event_type = None
-            data = None
-
-            for line in (l.strip() for l in decoded.split("\n") if l.strip()):
-                if line.startswith("event:"):
-                    event_type = line.removeprefix("event:").strip()
-                elif line.startswith("data:"):
-                    data = line.removeprefix("data:").strip()
-
-            return SSEvent(event=event_type, data=data, retry=None)
-
-        except UnicodeDecodeError as e:
-            raise InferenceGatewayValidationError(f"Invalid UTF-8 encoding in SSE chunk: {chunk!r}")
-
     def _parse_json_line(self, line: bytes) -> Dict[str, Any]:
         """Parse a single JSON line into a dictionary.
 
@@ -325,9 +294,8 @@ def create_chat_completion_stream(
         provider: Optional[Union[Provider, str]] = None,
         max_tokens: Optional[int] = None,
         tools: Optional[List[ChatCompletionTool]] = None,
-        use_sse: bool = True,
         **kwargs: Any,
-    ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]:
+    ) -> Generator[SSEvent, None, None]:
         """Stream a chat completion.
 
         Args:
@@ -336,11 +304,10 @@ def create_chat_completion_stream(
             provider: Optional provider specification
            max_tokens: Maximum number of tokens to generate
             tools: List of tools the model may call (using ChatCompletionTool models)
-            use_sse: Whether to use Server-Sent Events format
             **kwargs: Additional parameters to pass to the API
 
         Yields:
-            Union[Dict[str, Any], SSEvent]: Stream chunks
+            SSEvent: Stream chunks in SSEvent format
 
         Raises:
             InferenceGatewayAPIError: If the API request fails
@@ -377,7 +344,7 @@ def create_chat_completion_stream(
                     response.raise_for_status()
                 except httpx.HTTPStatusError as e:
                     raise InferenceGatewayAPIError(f"Request failed: {str(e)}")
-                yield from self._process_stream_response(response, use_sse)
+                yield from self._process_stream_response(response)
             else:
                 requests_response = self.session.post(
                     url, params=params, json=request.model_dump(exclude_none=True), stream=True
@@ -386,49 +353,45 @@ def create_chat_completion_stream(
                     requests_response.raise_for_status()
                 except (requests.exceptions.HTTPError, Exception) as e:
                     raise InferenceGatewayAPIError(f"Request failed: {str(e)}")
-                yield from self._process_stream_response(requests_response, use_sse)
+                yield from self._process_stream_response(requests_response)
 
         except ValidationError as e:
            raise InferenceGatewayValidationError(f"Request validation failed: {e}")
 
     def _process_stream_response(
-        self, response: Union[requests.Response, httpx.Response], use_sse: bool
-    ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]:
-        """Process streaming response data."""
-        if use_sse:
-            buffer: List[bytes] = []
-
-            for line in response.iter_lines():
-                if not line:
-                    if buffer:
-                        chunk = b"\n".join(buffer)
-                        yield self._parse_sse_chunk(chunk)
-                        buffer = []
-                    continue
-
-                if isinstance(line, str):
-                    line_bytes = line.encode("utf-8")
-                else:
-                    line_bytes = line
-                buffer.append(line_bytes)
-        else:
-            for line in response.iter_lines():
-                if not line:
-                    continue
-
-                if isinstance(line, str):
-                    line_bytes = line.encode("utf-8")
-                else:
-                    line_bytes = line
-
-                if line_bytes.strip() == b"data: [DONE]":
-                    continue
-                if line_bytes.startswith(b"data: "):
-                    json_str = line_bytes[6:].decode("utf-8")
-                    data = json.loads(json_str)
-                    yield data
-                else:
-                    yield self._parse_json_line(line_bytes)
+        self, response: Union[requests.Response, httpx.Response]
+    ) -> Generator[SSEvent, None, None]:
+        """Process streaming response data in SSEvent format."""
+        current_event = None
+
+        for line in response.iter_lines():
+            if not line:
+                continue
+
+            if isinstance(line, str):
+                line_bytes = line.encode("utf-8")
+            else:
+                line_bytes = line
+
+            if line_bytes.strip() == b"data: [DONE]":
+                continue
+
+            if line_bytes.startswith(b"event: "):
+                current_event = line_bytes[7:].decode("utf-8").strip()
+                continue
+            elif line_bytes.startswith(b"data: "):
+                json_str = line_bytes[6:].decode("utf-8")
+                event_type = current_event if current_event else "content-delta"
+                yield SSEvent(event=event_type, data=json_str)
+                current_event = None
+            elif line_bytes.strip() == b"":
+                continue
+            else:
+                try:
+                    parsed_data = self._parse_json_line(line_bytes)
+                    yield SSEvent(event="content-delta", data=json.dumps(parsed_data))
+                except Exception:
+                    yield SSEvent(event="content-delta", data=line_bytes.decode("utf-8"))
 
     def proxy_request(
         self,
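To see the new mapping end to end, a minimal sketch that feeds fake SSE lines through the same logic as _process_stream_response (FakeResponse is hypothetical; the real method receives requests/httpx responses):

```python
import json

from inference_gateway.models import SSEvent


class FakeResponse:
    """Hypothetical stand-in exposing the iter_lines() surface the parser uses."""

    def iter_lines(self):
        yield b"event: content-delta"
        yield b'data: {"choices":[{"delta":{"content":"Hi"}}]}'
        yield b"data: [DONE]"  # sentinel: skipped, never yielded


# Same shape as the new _process_stream_response loop: "event:" lines set the
# event type, each "data:" line becomes one SSEvent, and [DONE] is dropped.
current_event = None
for line in FakeResponse().iter_lines():
    if line.strip() == b"data: [DONE]":
        continue
    if line.startswith(b"event: "):
        current_event = line[7:].decode("utf-8").strip()
        continue
    if line.startswith(b"data: "):
        event = SSEvent(event=current_event or "content-delta", data=line[6:].decode("utf-8"))
        current_event = None
        print(event.event, json.loads(event.data))
```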

inference_gateway/models.py

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
