Skip to content

refactor: Enhance stream processing with abort signal support and increase default timeout #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 49 additions & 14 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,20 @@ class StreamProcessor {
this.clientProvidedTools = clientProvidedTools;
}

async processStream(body: ReadableStream<Uint8Array>): Promise<void> {
async processStream(
body: ReadableStream<Uint8Array>,
abortSignal?: AbortSignal
): Promise<void> {
const reader = body.getReader();
const decoder = new TextDecoder();
let buffer = '';

try {
while (true) {
if (abortSignal?.aborted) {
throw new Error('Stream processing was aborted');
}

const { done, value } = await reader.read();
if (done) break;

Expand All @@ -73,6 +80,11 @@ class StreamProcessor {
}
}
} catch (error) {
if (abortSignal?.aborted || (error as Error).name === 'AbortError') {
console.log('Stream processing was cancelled');
return;
}

const apiError: SchemaError = {
error: (error as Error).message || 'Unknown error',
};
Expand Down Expand Up @@ -215,10 +227,10 @@ class StreamProcessor {
}

private finalizeIncompleteToolCalls(): void {
for (const [, toolCall] of this.incompleteToolCalls.entries()) {
this.incompleteToolCalls.forEach((toolCall) => {
if (!toolCall.id || !toolCall.function.name) {
globalThis.console.warn('Incomplete tool call detected:', toolCall);
continue;
return;
}

const completedToolCall = {
Expand All @@ -237,15 +249,26 @@ class StreamProcessor {
}
this.callbacks.onMCPTool?.(completedToolCall);
} catch (argError) {
globalThis.console.warn(
`Invalid MCP tool arguments for ${toolCall.function.name}:`,
argError
);
const isIncompleteJSON =
toolCall.function.arguments &&
!toolCall.function.arguments.trim().endsWith('}');

if (isIncompleteJSON) {
globalThis.console.warn(
`Incomplete MCP tool arguments for ${toolCall.function.name} (stream was likely interrupted):`,
toolCall.function.arguments
);
} else {
globalThis.console.warn(
`Invalid MCP tool arguments for ${toolCall.function.name}:`,
argError
);
}
}
} else {
this.callbacks.onTool?.(completedToolCall);
}
}
});
this.incompleteToolCalls.clear();
}

Expand Down Expand Up @@ -280,7 +303,7 @@ export class InferenceGatewayClient {
this.apiKey = options.apiKey;
this.defaultHeaders = options.defaultHeaders || {};
this.defaultQuery = options.defaultQuery || {};
this.timeout = options.timeout || 30000;
this.timeout = options.timeout || 60000; // Increased default timeout to 60 seconds
this.fetchFn = options.fetch || globalThis.fetch;
}

Expand Down Expand Up @@ -404,17 +427,23 @@ export class InferenceGatewayClient {
* @param request - Chat completion request (must include at least model and messages)
* @param callbacks - Callbacks for handling streaming events
* @param provider - Optional provider to use for this request
* @param abortSignal - Optional AbortSignal to cancel the request
*/
async streamChatCompletion(
request: Omit<
SchemaCreateChatCompletionRequest,
'stream' | 'stream_options'
>,
callbacks: ChatCompletionStreamCallbacks,
provider?: Provider
provider?: Provider,
abortSignal?: AbortSignal
): Promise<void> {
try {
const response = await this.initiateStreamingRequest(request, provider);
const response = await this.initiateStreamingRequest(
request,
provider,
abortSignal
);

if (!response.body) {
const error: SchemaError = {
Expand All @@ -440,7 +469,7 @@ export class InferenceGatewayClient {
callbacks,
clientProvidedTools
);
await streamProcessor.processStream(response.body);
await streamProcessor.processStream(response.body, abortSignal);
} catch (error) {
const apiError: SchemaError = {
error: (error as Error).message || 'Unknown error occurred',
Expand All @@ -458,7 +487,8 @@ export class InferenceGatewayClient {
SchemaCreateChatCompletionRequest,
'stream' | 'stream_options'
>,
provider?: Provider
provider?: Provider,
abortSignal?: AbortSignal
): Promise<Response> {
const query: Record<string, string> = {};
if (provider) {
Expand All @@ -485,6 +515,11 @@ export class InferenceGatewayClient {
}

const controller = new AbortController();

const combinedSignal = abortSignal
? AbortSignal.any([abortSignal, controller.signal])
: controller.signal;

const timeoutId = globalThis.setTimeout(
() => controller.abort(),
this.timeout
Expand All @@ -501,7 +536,7 @@ export class InferenceGatewayClient {
include_usage: true,
},
}),
signal: controller.signal,
signal: combinedSignal,
});

if (!response.ok) {
Expand Down