From 914e5489b7c3a588b776e9e7161ca6c61a0d8848 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sun, 1 Jun 2025 20:12:36 +0000 Subject: [PATCH] refactor: Enhance stream processing with abort signal support and increase default timeout Signed-off-by: Eden Reich --- src/client.ts | 63 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/src/client.ts b/src/client.ts index 26f113d..661b3c7 100644 --- a/src/client.ts +++ b/src/client.ts @@ -51,13 +51,20 @@ class StreamProcessor { this.clientProvidedTools = clientProvidedTools; } - async processStream(body: ReadableStream): Promise { + async processStream( + body: ReadableStream, + abortSignal?: AbortSignal + ): Promise { const reader = body.getReader(); const decoder = new TextDecoder(); let buffer = ''; try { while (true) { + if (abortSignal?.aborted) { + throw new Error('Stream processing was aborted'); + } + const { done, value } = await reader.read(); if (done) break; @@ -73,6 +80,11 @@ class StreamProcessor { } } } catch (error) { + if (abortSignal?.aborted || (error as Error).name === 'AbortError') { + console.log('Stream processing was cancelled'); + return; + } + const apiError: SchemaError = { error: (error as Error).message || 'Unknown error', }; @@ -215,10 +227,10 @@ class StreamProcessor { } private finalizeIncompleteToolCalls(): void { - for (const [, toolCall] of this.incompleteToolCalls.entries()) { + this.incompleteToolCalls.forEach((toolCall) => { if (!toolCall.id || !toolCall.function.name) { globalThis.console.warn('Incomplete tool call detected:', toolCall); - continue; + return; } const completedToolCall = { @@ -237,15 +249,26 @@ class StreamProcessor { } this.callbacks.onMCPTool?.(completedToolCall); } catch (argError) { - globalThis.console.warn( - `Invalid MCP tool arguments for ${toolCall.function.name}:`, - argError - ); + const isIncompleteJSON = + toolCall.function.arguments && + !toolCall.function.arguments.trim().endsWith('}'); + + if (isIncompleteJSON) { + globalThis.console.warn( + `Incomplete MCP tool arguments for ${toolCall.function.name} (stream was likely interrupted):`, + toolCall.function.arguments + ); + } else { + globalThis.console.warn( + `Invalid MCP tool arguments for ${toolCall.function.name}:`, + argError + ); + } } } else { this.callbacks.onTool?.(completedToolCall); } - } + }); this.incompleteToolCalls.clear(); } @@ -280,7 +303,7 @@ export class InferenceGatewayClient { this.apiKey = options.apiKey; this.defaultHeaders = options.defaultHeaders || {}; this.defaultQuery = options.defaultQuery || {}; - this.timeout = options.timeout || 30000; + this.timeout = options.timeout || 60000; // Increased default timeout to 60 seconds this.fetchFn = options.fetch || globalThis.fetch; } @@ -404,6 +427,7 @@ export class InferenceGatewayClient { * @param request - Chat completion request (must include at least model and messages) * @param callbacks - Callbacks for handling streaming events * @param provider - Optional provider to use for this request + * @param abortSignal - Optional AbortSignal to cancel the request */ async streamChatCompletion( request: Omit< @@ -411,10 +435,15 @@ export class InferenceGatewayClient { 'stream' | 'stream_options' >, callbacks: ChatCompletionStreamCallbacks, - provider?: Provider + provider?: Provider, + abortSignal?: AbortSignal ): Promise { try { - const response = await this.initiateStreamingRequest(request, provider); + const response = await this.initiateStreamingRequest( + request, + provider, + abortSignal + ); if (!response.body) { const error: SchemaError = { @@ -440,7 +469,7 @@ export class InferenceGatewayClient { callbacks, clientProvidedTools ); - await streamProcessor.processStream(response.body); + await streamProcessor.processStream(response.body, abortSignal); } catch (error) { const apiError: SchemaError = { error: (error as Error).message || 'Unknown error occurred', @@ -458,7 +487,8 @@ export class InferenceGatewayClient { SchemaCreateChatCompletionRequest, 'stream' | 'stream_options' >, - provider?: Provider + provider?: Provider, + abortSignal?: AbortSignal ): Promise { const query: Record = {}; if (provider) { @@ -485,6 +515,11 @@ export class InferenceGatewayClient { } const controller = new AbortController(); + + const combinedSignal = abortSignal + ? AbortSignal.any([abortSignal, controller.signal]) + : controller.signal; + const timeoutId = globalThis.setTimeout( () => controller.abort(), this.timeout @@ -501,7 +536,7 @@ export class InferenceGatewayClient { include_usage: true, }, }), - signal: controller.signal, + signal: combinedSignal, }); if (!response.ok) {