@@ -206,37 +206,6 @@ def list_tools(self) -> ListToolsResponse:
206
206
except ValidationError as e :
207
207
raise InferenceGatewayValidationError (f"Response validation failed: { e } " )
208
208
209
def _parse_sse_chunk(self, chunk: bytes) -> SSEvent:
    """Parse an SSE message chunk into structured event data.

    The chunk is decoded as UTF-8 and scanned line by line. An ``event:``
    line sets the event type (the last one wins). Every ``data:`` line
    contributes one payload line; multiple payload lines are joined with
    a newline instead of overwriting each other. Lines that carry neither
    prefix (blank lines, comments, other fields) are ignored.

    Args:
        chunk: Raw SSE message chunk in bytes format

    Returns:
        SSEvent: Parsed SSE message with event type and data fields.
            ``event`` and ``data`` are None when the chunk has no such
            line; ``retry`` is always None.

    Raises:
        TypeError: If chunk is not a bytes object
        InferenceGatewayValidationError: If chunk is not valid UTF-8
    """
    if not isinstance(chunk, bytes):
        raise TypeError(f"Expected bytes, got {type(chunk)}")

    # Only the decode can raise UnicodeDecodeError, so keep the try minimal
    # and chain the cause for easier debugging.
    try:
        decoded = chunk.decode("utf-8")
    except UnicodeDecodeError as e:
        raise InferenceGatewayValidationError(
            f"Invalid UTF-8 encoding in SSE chunk: {chunk!r}"
        ) from e

    event_type = None
    data_parts = []

    for raw_line in decoded.split("\n"):
        # strip() also removes the "\r" of CRLF-terminated lines
        line = raw_line.strip()
        if line.startswith("event:"):
            event_type = line.removeprefix("event:").strip()
        elif line.startswith("data:"):
            data_parts.append(line.removeprefix("data:").strip())

    # A single data line yields exactly what it did before; several data
    # lines are concatenated rather than the last one silently winning.
    data = "\n".join(data_parts) if data_parts else None

    return SSEvent(event=event_type, data=data, retry=None)
239
-
240
209
def _parse_json_line (self , line : bytes ) -> Dict [str , Any ]:
241
210
"""Parse a single JSON line into a dictionary.
242
211
@@ -325,9 +294,8 @@ def create_chat_completion_stream(
325
294
provider : Optional [Union [Provider , str ]] = None ,
326
295
max_tokens : Optional [int ] = None ,
327
296
tools : Optional [List [ChatCompletionTool ]] = None ,
328
- use_sse : bool = True ,
329
297
** kwargs : Any ,
330
- ) -> Generator [Union [ Dict [ str , Any ], SSEvent ] , None , None ]:
298
+ ) -> Generator [SSEvent , None , None ]:
331
299
"""Stream a chat completion.
332
300
333
301
Args:
@@ -336,11 +304,10 @@ def create_chat_completion_stream(
336
304
provider: Optional provider specification
337
305
max_tokens: Maximum number of tokens to generate
338
306
tools: List of tools the model may call (using ChatCompletionTool models)
339
- use_sse: Whether to use Server-Sent Events format
340
307
**kwargs: Additional parameters to pass to the API
341
308
342
309
Yields:
343
- Union[Dict[str, Any], SSEvent] : Stream chunks
310
+ SSEvent: Stream chunks in SSEvent format
344
311
345
312
Raises:
346
313
InferenceGatewayAPIError: If the API request fails
@@ -377,7 +344,7 @@ def create_chat_completion_stream(
377
344
response .raise_for_status ()
378
345
except httpx .HTTPStatusError as e :
379
346
raise InferenceGatewayAPIError (f"Request failed: { str (e )} " )
380
- yield from self ._process_stream_response (response , use_sse )
347
+ yield from self ._process_stream_response (response )
381
348
else :
382
349
requests_response = self .session .post (
383
350
url , params = params , json = request .model_dump (exclude_none = True ), stream = True
@@ -386,49 +353,45 @@ def create_chat_completion_stream(
386
353
requests_response .raise_for_status ()
387
354
except (requests .exceptions .HTTPError , Exception ) as e :
388
355
raise InferenceGatewayAPIError (f"Request failed: { str (e )} " )
389
- yield from self ._process_stream_response (requests_response , use_sse )
356
+ yield from self ._process_stream_response (requests_response )
390
357
391
358
except ValidationError as e :
392
359
raise InferenceGatewayValidationError (f"Request validation failed: { e } " )
393
360
394
361
def _process_stream_response(
    self, response: Union[requests.Response, httpx.Response]
) -> Generator[SSEvent, None, None]:
    """Process streaming response data in SSEvent format.

    Iterates over ``response.iter_lines()`` and converts each payload line
    into an ``SSEvent``:

    * ``event:`` lines set the event type for the next ``data:`` line.
    * ``data:`` lines are yielded as events whose ``data`` is the raw text
      after the prefix; the event type defaults to ``"content-delta"``.
    * ``data: [DONE]`` markers, blank lines and SSE comment lines
      (starting with ``:``) produce no event.
    * Any other line is treated as a bare JSON line: it is parsed with
      ``_parse_json_line`` and re-serialized, or passed through as raw
      text if parsing fails.

    Args:
        response: Streaming HTTP response whose ``iter_lines()`` yields
            either ``str`` or ``bytes`` lines.

    Yields:
        SSEvent: One event per data/JSON line of the stream.
    """
    current_event = None

    for line in response.iter_lines():
        if not line:
            # A blank line terminates an SSE event; drop any pending event
            # name so it cannot leak onto the next event's data.
            current_event = None
            continue

        line_bytes = line.encode("utf-8") if isinstance(line, str) else line
        stripped = line_bytes.strip()

        # Whitespace-only lines and SSE comments / keep-alives carry no payload.
        if not stripped or line_bytes.startswith(b":"):
            continue

        if stripped in (b"data: [DONE]", b"data:[DONE]"):
            continue

        field, sep, value = line_bytes.partition(b":")
        if sep and field in (b"event", b"data"):
            # The SSE format allows one optional space after the colon.
            if value.startswith(b" "):
                value = value[1:]

            if field == b"event":
                current_event = value.decode("utf-8").strip()
                continue

            yield SSEvent(
                event=current_event or "content-delta",
                data=value.decode("utf-8"),
            )
            current_event = None
            continue

        # Not an SSE field line: fall back to newline-delimited JSON.
        try:
            parsed_data = self._parse_json_line(line_bytes)
        except Exception:
            # Deliberate best-effort: surface unparseable lines as raw text
            # instead of aborting the whole stream.
            payload = line_bytes.decode("utf-8")
        else:
            payload = json.dumps(parsed_data)
        yield SSEvent(event="content-delta", data=payload)
432
395
433
396
def proxy_request (
434
397
self ,
0 commit comments