Skip to content

feat: make it possible to abort AI requests #1806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions packages/xl-ai/src/AIExtension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
suggestChanges,
} from "@blocknote/prosemirror-suggest-changes";
import { APICallError, LanguageModel, RetryError } from "ai";
import { Fragment, Slice } from "prosemirror-model";
import { Plugin, PluginKey } from "prosemirror-state";
import { fixTablesKey } from "prosemirror-tables";
import { createStore, StoreApi } from "zustand/vanilla";
Expand All @@ -17,7 +18,6 @@ import { LLMResponse } from "./api/LLMResponse.js";
import { PromptBuilder } from "./api/formats/PromptBuilder.js";
import { LLMFormat, llmFormats } from "./api/index.js";
import { createAgentCursorPlugin } from "./plugins/AgentCursorPlugin.js";
import { Fragment, Slice } from "prosemirror-model";

type MakeOptional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;

Expand Down Expand Up @@ -87,6 +87,7 @@ const PLUGIN_KEY = new PluginKey(`blocknote-ai-plugin`);

export class AIExtension extends BlockNoteExtension {
private previousRequestOptions: LLMRequestOptions | undefined;
private currentAbortController?: AbortController;

public static key(): string {
return "ai";
Expand Down Expand Up @@ -297,6 +298,24 @@ export class AIExtension extends BlockNoteExtension {
}
}

/**
 * Abort the in-flight LLM request, if any.
 *
 * May only be called while a request is executing, i.e. while the AI menu
 * status is "thinking" or "ai-writing"; otherwise an error is thrown.
 */
public async abort() {
  const menuState = this.store.getState().aiMenuState;
  // A request is only executing while the menu is open in one of these states.
  const isExecuting =
    menuState !== "closed" &&
    (menuState.status === "thinking" || menuState.status === "ai-writing");

  if (!isExecuting) {
    throw new Error("abort() is only valid during LLM execution");
  }

  // Signal cancellation to the controller owned by the current request.
  this.currentAbortController?.abort();
}

/**
* Update the status of a call to an LLM
*
Expand Down Expand Up @@ -349,15 +368,20 @@ export class AIExtension extends BlockNoteExtension {
* Execute a call to an LLM and apply the result to the editor
*/
public async callLLM(opts: MakeOptional<LLMRequestOptions, "model">) {
const startState = this.store.getState().aiMenuState;
this.setAIResponseStatus("thinking");
this.editor.forkYDocPlugin?.fork();

// Create abort controller for this request
this.currentAbortController = new AbortController();

let ret: LLMResponse | undefined;
try {
const requestOptions = {
...this.options.getState(),
...opts,
previousResponse: this.store.getState().llmResponse,
abortSignal: this.currentAbortController.signal,
};
this.previousRequestOptions = requestOptions;

Expand All @@ -383,18 +407,42 @@ export class AIExtension extends BlockNoteExtension {
llmResponse: ret,
});

await ret.execute();
await ret.execute(this.currentAbortController?.signal);

this.setAIResponseStatus("user-reviewing");
} catch (e) {
// TODO in error state, should we discard the forked document?
// Handle abort errors gracefully
if (e instanceof Error && e.name === "AbortError") {
// Request was aborted, don't set error status as abort() handles cleanup
const state = this.store.getState().aiMenuState;
if (state === "closed" || startState === "closed") {
throw new Error(
"Unexpected: AbortError occurred while the AI menu was closed",
);
}
if (state.status === "ai-writing") {
// we were already writing. Set to reviewing to show the user the partial result
this.setAIResponseStatus("user-reviewing");
} else {
// we were not writing yet. Set to the previous state
if (startState.status === "error") {
this.setAIResponseStatus({ status: startState.status, error: e });
} else {
this.setAIResponseStatus(startState.status);
}
}
return ret;
}

this.setAIResponseStatus({
status: "error",
error: e,
});
// eslint-disable-next-line no-console
console.warn("Error calling LLM", e);
} finally {
// Clean up abort controller
this.currentAbortController = undefined;
}
return ret;
}
Expand Down
7 changes: 7 additions & 0 deletions packages/xl-ai/src/api/LLMRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ export type LLMRequestOptions = {
* @default true
*/
withDelays?: boolean;
/**
* AbortSignal to cancel the LLM request
*/
abortSignal?: AbortSignal;
/**
* Additional options to pass to the AI SDK `generateObject` function
* (only used when `stream` is `false`)
Expand Down Expand Up @@ -134,6 +138,7 @@ export async function doLLMRequest(
withDelays,
dataFormat,
previousResponse,
abortSignal,
...rest
} = {
maxRetries: 2,
Expand Down Expand Up @@ -226,6 +231,7 @@ export async function doLLMRequest(
streamTools,
{
messages,
abortSignal,
...rest,
},
() => {
Expand All @@ -238,6 +244,7 @@ export async function doLLMRequest(
} else {
response = await generateOperations(streamTools, {
messages,
abortSignal,
...rest,
});
if (deleteCursorBlock) {
Expand Down
8 changes: 4 additions & 4 deletions packages/xl-ai/src/api/LLMResponse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ export class LLMResponse {
*
* (this method consumes underlying streams in `llmResult`)
*/
async *applyToolCalls(abortSignal?: AbortSignal) {
  // Shape of each chunk produced by the operations pipeline.
  type OperationChunk = {
    operation: StreamToolCall<StreamTool<any>[]>;
    isUpdateToPreviousOperation: boolean;
    isPossiblyPartial: boolean;
  };

  // Start from the raw operations source and let every registered stream
  // tool wrap the pipeline in turn; each tool may observe the abort signal.
  let pipeline: AsyncIterable<OperationChunk> = this.llmResult.operationsSource;
  for (const streamTool of this.streamTools) {
    pipeline = streamTool.execute(pipeline, abortSignal);
  }

  yield* pipeline;
}
Expand All @@ -45,9 +45,9 @@ export class LLMResponse {
*
* (this method consumes underlying streams in `llmResult`)
*/
public async execute(abortSignal?: AbortSignal) {
  // Drain the tool-call stream; iteration itself drives execution, the
  // yielded values are irrelevant here.
  const operations = this.applyToolCalls(abortSignal);
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
  for await (const _ignored of operations) {
    // intentionally empty
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ export function createAddBlocksTool<T>(config: {
},
// Note: functionality mostly tested in jsontools.test.ts
// would be nicer to add a direct unit test
execute: async function* (operationsStream) {
execute: async function* (operationsStream, abortSignal?: AbortSignal) {
// An add operation has some complexity:
// - it can add multiple blocks in 1 operation
// (this is needed because you need an id as reference block - and if you want to insert multiple blocks you can only use an existing block as reference id)
Expand Down Expand Up @@ -266,6 +266,11 @@ export function createAddBlocksTool<T>(config: {
}

for (const step of agentSteps) {
if (abortSignal?.aborted) {
const error = new Error("Operation was aborted");
error.name = "AbortError";
throw error;
}
if (options.withDelays) {
await delayAgentStep(step);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ export function createUpdateBlockTool<T>(config: {
isUpdateToPreviousOperation: boolean;
isPossiblyPartial: boolean;
}>,
abortSignal?: AbortSignal,
) {
const STEP_SIZE = 50;
let minSize = STEP_SIZE;
Expand Down Expand Up @@ -254,6 +255,11 @@ export function createUpdateBlockTool<T>(config: {
const agentSteps = getStepsAsAgent(tr);

for (const step of agentSteps) {
if (abortSignal?.aborted) {
const error = new Error("Operation was aborted");
error.name = "AbortError";
throw error;
}
if (options.withDelays) {
await delayAgentStep(step);
}
Expand Down
7 changes: 6 additions & 1 deletion packages/xl-ai/src/api/formats/base-tools/delete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export const deleteBlockTool = (
},
// Note: functionality mostly tested in jsontools.test.ts
// would be nicer to add a direct unit test
execute: async function* (operationsStream) {
execute: async function* (operationsStream, abortSignal?: AbortSignal) {
for await (const chunk of operationsStream) {
if (chunk.operation.type !== "delete") {
// pass through non-delete operations
Expand All @@ -93,6 +93,11 @@ export const deleteBlockTool = (
const agentSteps = getStepsAsAgent(tr);

for (const step of agentSteps) {
if (abortSignal?.aborted) {
const error = new Error("Operation was aborted");
error.name = "AbortError";
throw error;
}
if (options.withDelays) {
await delayAgentStep(step);
}
Expand Down
7 changes: 7 additions & 0 deletions packages/xl-ai/src/blocknoteAIClient/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ const fetchViaBlockNoteAIServer =
body: init?.body || request.body,
method: request.method,
duplex: "half",
signal: request.signal,
} as any,
);
try {
const resp = await fetch(newRequest);
return resp;
} catch (e) {
// Temp fix for https://github.com/vercel/ai/issues/6370
if (
e instanceof Error &&
(e.name === "AbortError" || e.name === "TimeoutError")
) {
throw e;
}
throw new TypeError("fetch failed", {
cause: e,
});
Expand Down
11 changes: 7 additions & 4 deletions packages/xl-ai/src/components/AIMenu/AIMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,12 @@ export const AIMenu = (props: AIMenuProps) => {
const rightSection = useMemo(() => {
if (aiResponseStatus === "thinking" || aiResponseStatus === "ai-writing") {
return (
<Components.SuggestionMenu.Loader
className={"bn-suggestion-menu-loader bn-combobox-right-section"}
/>
// TODO
<div onClick={() => ai.abort()}>
<Components.SuggestionMenu.Loader
className={"bn-suggestion-menu-loader bn-combobox-right-section"}
/>
</div>
);
} else if (aiResponseStatus === "error") {
return (
Expand All @@ -117,7 +120,7 @@ export const AIMenu = (props: AIMenuProps) => {
}

return undefined;
}, [Components, aiResponseStatus]);
}, [Components, aiResponseStatus, ai]);

return (
<PromptSuggestionMenu
Expand Down
21 changes: 18 additions & 3 deletions packages/xl-ai/src/streamTool/callLLMWithStreamTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
AsyncIterableStream,
createAsyncIterableStream,
createAsyncIterableStreamFromAsyncIterable,
withAbort,
} from "../util/stream.js";
import { filterNewOrUpdatedOperations } from "./filterNewOrUpdatedOperations.js";
import {
Expand All @@ -27,6 +28,7 @@ type LLMRequestOptions = {
model: LanguageModel;
messages: CoreMessage[];
maxRetries: number;
abortSignal?: AbortSignal;
};

/**
Expand Down Expand Up @@ -117,6 +119,13 @@ export async function generateOperations<T extends StreamTool<any>[]>(

const ret = await generateObject<{ operations: any }>(options);

if (opts.abortSignal?.aborted) {
// throw abort error before stream processing starts and `onStart` is called
const error = new Error("Operation was aborted");
error.name = "AbortError";
throw error;
}

// because the rest of the codebase always expects a stream, we convert the object to a stream here
const stream = operationsToStream(ret.object);

Expand All @@ -132,7 +141,10 @@ export async function generateOperations<T extends StreamTool<any>[]>(
get operationsSource() {
if (!_operationsSource) {
_operationsSource = createAsyncIterableStreamFromAsyncIterable(
preprocessOperationsNonStreaming(stream.value, streamTools),
withAbort(
preprocessOperationsNonStreaming(stream.value, streamTools),
opts.abortSignal,
),
);
}
return _operationsSource;
Expand Down Expand Up @@ -258,8 +270,11 @@ export async function streamOperations<T extends StreamTool<any>[]>(
preprocessOperationsStreaming(
filterNewOrUpdatedOperations(
streamOnStartCallback(
partialObjectStreamThrowError(
createAsyncIterableStream(fullStream1),
withAbort(
partialObjectStreamThrowError(
createAsyncIterableStream(fullStream1),
),
opts.abortSignal,
),
onStart,
),
Expand Down
1 change: 1 addition & 0 deletions packages/xl-ai/src/streamTool/streamTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export type StreamTool<T extends { type: string }> = {
isUpdateToPreviousOperation: boolean;
isPossiblyPartial: boolean;
}>,
abortSignal?: AbortSignal,
) => AsyncIterable<{
operation: StreamToolCall<StreamTool<{ type: string }>[]>;
isUpdateToPreviousOperation: boolean;
Expand Down
25 changes: 21 additions & 4 deletions packages/xl-ai/src/util/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Converts an AsyncIterable to a ReadableStream
*/
export function asyncIterableToStream<T>(
iterable: AsyncIterable<T>
iterable: AsyncIterable<T>,
): ReadableStream<T> {
return new ReadableStream({
async start(controller) {
Expand All @@ -29,11 +29,11 @@ export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
* Creates an AsyncIterableStream from a ReadableStream
*/
export function createAsyncIterableStream<T>(
source: ReadableStream<T>
source: ReadableStream<T>,
): AsyncIterableStream<T> {
if (source.locked) {
throw new Error(
"Stream (source) is already locked and cannot be iterated."
"Stream (source) is already locked and cannot be iterated.",
);
}

Expand All @@ -60,7 +60,24 @@ export function createAsyncIterableStream<T>(
* Creates an AsyncIterableStream from an AsyncGenerator
*/
export function createAsyncIterableStreamFromAsyncIterable<T>(
source: AsyncIterable<T>
source: AsyncIterable<T>,
): AsyncIterableStream<T> {
return createAsyncIterableStream(asyncIterableToStream(source));
}

/**
 * Wraps an async iterable so that iteration throws an `AbortError` once the
 * given signal is aborted.
 *
 * The signal is checked before the first item is consumed and again before
 * each item is yielded, so an already-aborted signal throws without pulling
 * anything from the source (the original only checked after fetching an
 * item, so a pre-aborted signal still consumed one item, and an empty source
 * never threw at all). NOTE(review): a signal firing while the source is
 * awaiting its next item is only observed once that item arrives.
 *
 * @param iterable the source async iterable to wrap
 * @param signal optional AbortSignal; when omitted, items pass through as-is
 * @throws Error with name "AbortError" once the signal is aborted
 */
export async function* withAbort<T>(
  iterable: AsyncIterable<T>,
  signal?: AbortSignal,
) {
  const throwIfAborted = () => {
    if (signal?.aborted) {
      const error = new Error("Operation was aborted");
      error.name = "AbortError";
      throw error;
    }
  };

  // Fail fast: don't consume anything from the source if already aborted.
  throwIfAborted();

  for await (const item of iterable) {
    throwIfAborted();
    yield item;
  }
}
Loading