diff --git a/packages/xl-ai/src/AIExtension.ts b/packages/xl-ai/src/AIExtension.ts index e4ba5d684..0f07dd5da 100644 --- a/packages/xl-ai/src/AIExtension.ts +++ b/packages/xl-ai/src/AIExtension.ts @@ -9,6 +9,7 @@ import { suggestChanges, } from "@blocknote/prosemirror-suggest-changes"; import { APICallError, LanguageModel, RetryError } from "ai"; +import { Fragment, Slice } from "prosemirror-model"; import { Plugin, PluginKey } from "prosemirror-state"; import { fixTablesKey } from "prosemirror-tables"; import { createStore, StoreApi } from "zustand/vanilla"; @@ -17,7 +18,6 @@ import { LLMResponse } from "./api/LLMResponse.js"; import { PromptBuilder } from "./api/formats/PromptBuilder.js"; import { LLMFormat, llmFormats } from "./api/index.js"; import { createAgentCursorPlugin } from "./plugins/AgentCursorPlugin.js"; -import { Fragment, Slice } from "prosemirror-model"; type MakeOptional = Omit & Partial>; @@ -87,6 +87,7 @@ const PLUGIN_KEY = new PluginKey(`blocknote-ai-plugin`); export class AIExtension extends BlockNoteExtension { private previousRequestOptions: LLMRequestOptions | undefined; + private currentAbortController?: AbortController; public static key(): string { return "ai"; @@ -297,6 +298,24 @@ export class AIExtension extends BlockNoteExtension { } } + /** + * Abort the current LLM request. + * + * Only valid when executing an LLM request (status is "thinking" or "ai-writing") + */ + public async abort() { + const state = this.store.getState().aiMenuState; + if (state === "closed") { + throw new Error("abort() is only valid during LLM execution"); + } + if (state.status !== "thinking" && state.status !== "ai-writing") { + throw new Error("abort() is only valid during LLM execution"); + } + + // Abort the current request + this.currentAbortController?.abort(); + } + /** * Update the status of a call to an LLM * @@ -349,15 +368,20 @@ export class AIExtension extends BlockNoteExtension { * Execute a call to an LLM and apply the result to the editor */ public async callLLM(opts: MakeOptional) { + const startState = this.store.getState().aiMenuState; this.setAIResponseStatus("thinking"); this.editor.forkYDocPlugin?.fork(); + // Create abort controller for this request + this.currentAbortController = new AbortController(); + let ret: LLMResponse | undefined; try { const requestOptions = { ...this.options.getState(), ...opts, previousResponse: this.store.getState().llmResponse, + abortSignal: this.currentAbortController.signal, }; this.previousRequestOptions = requestOptions; @@ -383,11 +407,32 @@ export class AIExtension extends BlockNoteExtension { llmResponse: ret, }); - await ret.execute(); + await ret.execute(this.currentAbortController?.signal); this.setAIResponseStatus("user-reviewing"); } catch (e) { - // TODO in error state, should we discard the forked document? + // Handle abort errors gracefully + if (e instanceof Error && e.name === "AbortError") { + // Request was aborted, don't set error status as abort() handles cleanup + const state = this.store.getState().aiMenuState; + if (state === "closed" || startState === "closed") { + throw new Error( + "Unexpected: AbortError occurred while the AI menu was closed", + ); + } + if (state.status === "ai-writing") { + // we were already writing. Set to reviewing to show the user the partial result + this.setAIResponseStatus("user-reviewing"); + } else { + // we were not writing yet. Set to the previous state + if (startState.status === "error") { + this.setAIResponseStatus({ status: startState.status, error: e }); + } else { + this.setAIResponseStatus(startState.status); + } + } + return ret; + } this.setAIResponseStatus({ status: "error", @@ -395,6 +440,9 @@ export class AIExtension extends BlockNoteExtension { }); // eslint-disable-next-line no-console console.warn("Error calling LLM", e); + } finally { + // Clean up abort controller + this.currentAbortController = undefined; } return ret; } diff --git a/packages/xl-ai/src/api/LLMRequest.ts b/packages/xl-ai/src/api/LLMRequest.ts index c11a19d62..9407c05b9 100644 --- a/packages/xl-ai/src/api/LLMRequest.ts +++ b/packages/xl-ai/src/api/LLMRequest.ts @@ -102,6 +102,10 @@ export type LLMRequestOptions = { * @default true */ withDelays?: boolean; + /** + * AbortSignal to cancel the LLM request + */ + abortSignal?: AbortSignal; /** * Additional options to pass to the AI SDK `generateObject` function * (only used when `stream` is `false`) @@ -134,6 +138,7 @@ export async function doLLMRequest( withDelays, dataFormat, previousResponse, + abortSignal, ...rest } = { maxRetries: 2, @@ -226,6 +231,7 @@ export async function doLLMRequest( streamTools, { messages, + abortSignal, ...rest, }, () => { @@ -238,6 +244,7 @@ export async function doLLMRequest( } else { response = await generateOperations(streamTools, { messages, + abortSignal, ...rest, }); if (deleteCursorBlock) { diff --git a/packages/xl-ai/src/api/LLMResponse.ts b/packages/xl-ai/src/api/LLMResponse.ts index 1321ab5b9..b3f77dc71 100644 --- a/packages/xl-ai/src/api/LLMResponse.ts +++ b/packages/xl-ai/src/api/LLMResponse.ts @@ -28,14 +28,14 @@ export class LLMResponse { * * (this method consumes underlying streams in `llmResult`) */ - async *applyToolCalls() { + async *applyToolCalls(abortSignal?: AbortSignal) { let currentStream: AsyncIterable<{ operation: StreamToolCall[]>; isUpdateToPreviousOperation: boolean; isPossiblyPartial: boolean; }> = this.llmResult.operationsSource; for (const tool of this.streamTools) { - currentStream = tool.execute(currentStream); + currentStream = tool.execute(currentStream, abortSignal); } yield* currentStream; } @@ -45,9 +45,9 @@ export class LLMResponse { * * (this method consumes underlying streams in `llmResult`) */ - public async execute() { + public async execute(abortSignal?: AbortSignal) { // eslint-disable-next-line @typescript-eslint/no-unused-vars - for await (const _result of this.applyToolCalls()) { + for await (const _result of this.applyToolCalls(abortSignal)) { // no op } } diff --git a/packages/xl-ai/src/api/formats/base-tools/createAddBlocksTool.ts b/packages/xl-ai/src/api/formats/base-tools/createAddBlocksTool.ts index 6c1eca51d..8d3ed43a0 100644 --- a/packages/xl-ai/src/api/formats/base-tools/createAddBlocksTool.ts +++ b/packages/xl-ai/src/api/formats/base-tools/createAddBlocksTool.ts @@ -169,7 +169,7 @@ export function createAddBlocksTool(config: { }, // Note: functionality mostly tested in jsontools.test.ts // would be nicer to add a direct unit test - execute: async function* (operationsStream) { + execute: async function* (operationsStream, abortSignal?: AbortSignal) { // An add operation has some complexity: // - it can add multiple blocks in 1 operation // (this is needed because you need an id as reference block - and if you want to insert multiple blocks you can only use an existing block as reference id) @@ -266,6 +266,11 @@ export function createAddBlocksTool(config: { } for (const step of agentSteps) { + if (abortSignal?.aborted) { + const error = new Error("Operation was aborted"); + error.name = "AbortError"; + throw error; + } if (options.withDelays) { await delayAgentStep(step); } diff --git a/packages/xl-ai/src/api/formats/base-tools/createUpdateBlockTool.ts b/packages/xl-ai/src/api/formats/base-tools/createUpdateBlockTool.ts index 6a81d0d46..071256d44 100644 --- a/packages/xl-ai/src/api/formats/base-tools/createUpdateBlockTool.ts +++ b/packages/xl-ai/src/api/formats/base-tools/createUpdateBlockTool.ts @@ -178,6 +178,7 @@ export function createUpdateBlockTool(config: { isUpdateToPreviousOperation: boolean; isPossiblyPartial: boolean; }>, + abortSignal?: AbortSignal, ) { const STEP_SIZE = 50; let minSize = STEP_SIZE; @@ -254,6 +255,11 @@ export function createUpdateBlockTool(config: { const agentSteps = getStepsAsAgent(tr); for (const step of agentSteps) { + if (abortSignal?.aborted) { + const error = new Error("Operation was aborted"); + error.name = "AbortError"; + throw error; + } if (options.withDelays) { await delayAgentStep(step); } diff --git a/packages/xl-ai/src/api/formats/base-tools/delete.ts b/packages/xl-ai/src/api/formats/base-tools/delete.ts index 3f96642c6..7fd5728d5 100644 --- a/packages/xl-ai/src/api/formats/base-tools/delete.ts +++ b/packages/xl-ai/src/api/formats/base-tools/delete.ts @@ -76,7 +76,7 @@ export const deleteBlockTool = ( }, // Note: functionality mostly tested in jsontools.test.ts // would be nicer to add a direct unit test - execute: async function* (operationsStream) { + execute: async function* (operationsStream, abortSignal?: AbortSignal) { for await (const chunk of operationsStream) { if (chunk.operation.type !== "delete") { // pass through non-delete operations @@ -93,6 +93,11 @@ export const deleteBlockTool = ( const agentSteps = getStepsAsAgent(tr); for (const step of agentSteps) { + if (abortSignal?.aborted) { + const error = new Error("Operation was aborted"); + error.name = "AbortError"; + throw error; + } if (options.withDelays) { await delayAgentStep(step); } diff --git a/packages/xl-ai/src/blocknoteAIClient/client.ts b/packages/xl-ai/src/blocknoteAIClient/client.ts index c64137972..b17cd8f25 100644 --- a/packages/xl-ai/src/blocknoteAIClient/client.ts +++ b/packages/xl-ai/src/blocknoteAIClient/client.ts @@ -18,6 +18,7 @@ const fetchViaBlockNoteAIServer = body: init?.body || request.body, method: request.method, duplex: "half", + signal: request.signal, } as any, ); try { @@ -25,6 +26,12 @@ const fetchViaBlockNoteAIServer = return resp; } catch (e) { // Temp fix for https://github.com/vercel/ai/issues/6370 + if ( + e instanceof Error && + (e.name === "AbortError" || e.name === "TimeoutError") + ) { + throw e; + } throw new TypeError("fetch failed", { cause: e, }); diff --git a/packages/xl-ai/src/components/AIMenu/AIMenu.tsx b/packages/xl-ai/src/components/AIMenu/AIMenu.tsx index 555325ef5..3fc4db478 100644 --- a/packages/xl-ai/src/components/AIMenu/AIMenu.tsx +++ b/packages/xl-ai/src/components/AIMenu/AIMenu.tsx @@ -94,9 +94,12 @@ export const AIMenu = (props: AIMenuProps) => { const rightSection = useMemo(() => { if (aiResponseStatus === "thinking" || aiResponseStatus === "ai-writing") { return ( - + // TODO +
ai.abort()}> + +
); } else if (aiResponseStatus === "error") { return ( @@ -117,7 +120,7 @@ export const AIMenu = (props: AIMenuProps) => { } return undefined; - }, [Components, aiResponseStatus]); + }, [Components, aiResponseStatus, ai]); return ( []>( const ret = await generateObject<{ operations: any }>(options); + if (opts.abortSignal?.aborted) { + // throw abort error before stream processing starts and `onStart` is called + const error = new Error("Operation was aborted"); + error.name = "AbortError"; + throw error; + } + // because the rest of the codebase always expects a stream, we convert the object to a stream here const stream = operationsToStream(ret.object); @@ -132,7 +141,10 @@ export async function generateOperations[]>( get operationsSource() { if (!_operationsSource) { _operationsSource = createAsyncIterableStreamFromAsyncIterable( - preprocessOperationsNonStreaming(stream.value, streamTools), + withAbort( + preprocessOperationsNonStreaming(stream.value, streamTools), + opts.abortSignal, + ), ); } return _operationsSource; @@ -258,8 +270,11 @@ export async function streamOperations[]>( preprocessOperationsStreaming( filterNewOrUpdatedOperations( streamOnStartCallback( - partialObjectStreamThrowError( - createAsyncIterableStream(fullStream1), + withAbort( + partialObjectStreamThrowError( + createAsyncIterableStream(fullStream1), + ), + opts.abortSignal, ), onStart, ), diff --git a/packages/xl-ai/src/streamTool/streamTool.ts b/packages/xl-ai/src/streamTool/streamTool.ts index d907d1315..2165d7564 100644 --- a/packages/xl-ai/src/streamTool/streamTool.ts +++ b/packages/xl-ai/src/streamTool/streamTool.ts @@ -51,6 +51,7 @@ export type StreamTool = { isUpdateToPreviousOperation: boolean; isPossiblyPartial: boolean; }>, + abortSignal?: AbortSignal, ) => AsyncIterable<{ operation: StreamToolCall[]>; isUpdateToPreviousOperation: boolean; diff --git a/packages/xl-ai/src/util/stream.ts b/packages/xl-ai/src/util/stream.ts index 03ac85c23..9fbf70b1b 100644 --- a/packages/xl-ai/src/util/stream.ts +++ b/packages/xl-ai/src/util/stream.ts @@ -2,7 +2,7 @@ * Converts an AsyncIterable to a ReadableStream */ export function asyncIterableToStream( - iterable: AsyncIterable + iterable: AsyncIterable, ): ReadableStream { return new ReadableStream({ async start(controller) { @@ -29,11 +29,11 @@ export type AsyncIterableStream = AsyncIterable & ReadableStream; * Creates an AsyncIterableStream from a ReadableStream */ export function createAsyncIterableStream( - source: ReadableStream + source: ReadableStream, ): AsyncIterableStream { if (source.locked) { throw new Error( - "Stream (source) is already locked and cannot be iterated." + "Stream (source) is already locked and cannot be iterated.", ); } @@ -60,7 +60,24 @@ export function createAsyncIterableStream( * Creates an AsyncIterableStream from an AsyncGenerator */ export function createAsyncIterableStreamFromAsyncIterable( - source: AsyncIterable + source: AsyncIterable, ): AsyncIterableStream { return createAsyncIterableStream(asyncIterableToStream(source)); } + +/** + * Helper to wrap an async iterable and throw if the abort signal is triggered. + */ +export async function* withAbort( + iterable: AsyncIterable, + signal?: AbortSignal, +) { + for await (const item of iterable) { + if (signal?.aborted) { + const error = new Error("Operation was aborted"); + error.name = "AbortError"; + throw error; + } + yield item; + } +}