diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 7f9db1944..0b0b67808 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -19,6 +19,7 @@ import { defaultCreateConversationCustomData, defaultAddMessageToConversationCustomData, makeVerifiedAnswerGenerateResponse, + addMessageToConversationVerifiedAnswerStream, } from "mongodb-chatbot-server"; import cookieParser from "cookie-parser"; import { blockGetRequests } from "./middleware/blockGetRequests"; @@ -40,7 +41,6 @@ import { import { AzureOpenAI } from "mongodb-rag-core/openai"; import { MongoClient } from "mongodb-rag-core/mongodb"; import { - ANALYZER_ENV_VARS, AZURE_OPENAI_ENV_VARS, PREPROCESSOR_ENV_VARS, TRACING_ENV_VARS, @@ -53,7 +53,10 @@ import { import { useSegmentIds } from "./middleware/useSegmentIds"; import { makeSearchTool } from "./tools/search"; import { makeMongoDbInputGuardrail } from "./processors/mongoDbInputGuardrail"; -import { makeGenerateResponseWithSearchTool } from "./processors/generateResponseWithSearchTool"; +import { + addMessageToConversationStream, + makeGenerateResponseWithSearchTool, +} from "./processors/generateResponseWithSearchTool"; import { makeBraintrustLogger } from "mongodb-rag-core/braintrust"; import { makeMongoDbScrubbedMessageStore } from "./tracing/scrubbedMessages/MongoDbScrubbedMessageStore"; import { MessageAnalysis } from "./tracing/scrubbedMessages/analyzeMessage"; @@ -231,6 +234,7 @@ export const generateResponse = wrapTraced( references: verifiedAnswer.references.map(addReferenceSourceType), }; }, + stream: addMessageToConversationVerifiedAnswerStream, onNoVerifiedAnswerFound: wrapTraced( makeGenerateResponseWithSearchTool({ languageModel, @@ -253,6 +257,7 @@ export const generateResponse = wrapTraced( searchTool: makeSearchTool(findContent), toolChoice: "auto", maxSteps: 5, + stream: addMessageToConversationStream, }), { name: "generateResponseWithSearchTool" } ), diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts index 3951d8141..723998986 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts @@ -351,18 +351,21 @@ describe("generateResponseWithSearchTool", () => { describe("streaming mode", () => { // Create a mock DataStreamer implementation const makeMockDataStreamer = () => { - const mockStreamData = jest.fn(); const mockConnect = jest.fn(); const mockDisconnect = jest.fn(); + const mockStreamData = jest.fn(); + const mockStreamResponses = jest.fn(); const mockStream = jest.fn().mockImplementation(async () => { // Process the stream and return a string result return "Hello"; }); + const dataStreamer = { connected: false, connect: mockConnect, disconnect: mockDisconnect, streamData: mockStreamData, + streamResponses: mockStreamResponses, stream: mockStream, } as DataStreamer; diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts index 074184d5d..692ba6e37 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts @@ -6,7 +6,6 @@ import { AssistantMessage, ToolMessage, } from "mongodb-rag-core"; - import { CoreAssistantMessage, CoreMessage, @@ -28,6 +27,7 @@ import { GenerateResponse, GenerateResponseReturnValue, InputGuardrailResult, + type StreamFunction, } from "mongodb-chatbot-server"; import { MongoDbSearchToolArgs, @@ -52,8 +52,59 @@ export interface GenerateResponseWithSearchToolParams { search_content: SearchTool; }>; searchTool: SearchTool; + stream?: { + onLlmNotWorking: StreamFunction<{ notWorkingMessage: string }>; + onLlmRefusal: StreamFunction<{ refusalMessage: string }>; + onReferenceLinks: StreamFunction<{ references: References }>; + onTextDelta: StreamFunction<{ delta: string }>; + }; } +export const addMessageToConversationStream: GenerateResponseWithSearchToolParams["stream"] = + { + onLlmNotWorking({ dataStreamer, notWorkingMessage }) { + dataStreamer?.streamData({ + type: "delta", + data: notWorkingMessage, + }); + }, + onLlmRefusal({ dataStreamer, refusalMessage }) { + dataStreamer?.streamData({ + type: "delta", + data: refusalMessage, + }); + }, + onReferenceLinks({ dataStreamer, references }) { + dataStreamer?.streamData({ + type: "references", + data: references, + }); + }, + onTextDelta({ dataStreamer, delta }) { + dataStreamer?.streamData({ + type: "delta", + data: delta, + }); + }, + }; + +// TODO: implement this +export const responsesApiStream: GenerateResponseWithSearchToolParams["stream"] = + { + onLlmNotWorking() { + throw new Error("not yet implemented"); + }, + onLlmRefusal() { + throw new Error("not yet implemented"); + }, + onReferenceLinks() { + throw new Error("not yet implemented"); + }, + onTextDelta() { + throw new Error("not yet implemented"); + }, + }; + /** Generate chatbot response using RAG and a search tool named {@link SEARCH_TOOL_NAME}. */ @@ -69,6 +120,7 @@ export function makeGenerateResponseWithSearchTool({ maxSteps = 2, searchTool, toolChoice, + stream, }: GenerateResponseWithSearchToolParams): GenerateResponse { return async function generateResponseWithSearchTool({ conversation, @@ -80,9 +132,11 @@ export function makeGenerateResponseWithSearchTool({ dataStreamer, request, }) { - if (shouldStream) { - assert(dataStreamer, "dataStreamer is required for streaming"); - } + const streamingModeActive = + shouldStream === true && + dataStreamer !== undefined && + stream !== undefined; + const userMessage: UserMessage = { role: "user", content: latestMessageText, @@ -179,10 +233,10 @@ export function makeGenerateResponseWithSearchTool({ switch (chunk.type) { case "text-delta": - if (shouldStream) { - dataStreamer?.streamData({ - data: chunk.textDelta, - type: "delta", + if (streamingModeActive) { + stream.onTextDelta({ + dataStreamer, + delta: chunk.textDelta, }); } break; @@ -202,10 +256,10 @@ export function makeGenerateResponseWithSearchTool({ // Stream references if we have any and weren't aborted if (references.length > 0 && !generationController.signal.aborted) { - if (shouldStream) { - dataStreamer?.streamData({ - data: references, - type: "references", + if (streamingModeActive) { + stream.onReferenceLinks({ + dataStreamer, + references, }); } } @@ -238,10 +292,10 @@ export function makeGenerateResponseWithSearchTool({ ...userMessageCustomData, ...guardrailResult, }; - if (shouldStream) { - dataStreamer?.streamData({ - type: "delta", - data: llmRefusalMessage, + if (streamingModeActive) { + stream.onLlmRefusal({ + dataStreamer, + refusalMessage: llmRefusalMessage, }); } return handleReturnGeneration({ @@ -293,10 +347,10 @@ export function makeGenerateResponseWithSearchTool({ }); } } catch (error: unknown) { - if (shouldStream) { - dataStreamer?.streamData({ - type: "delta", - data: llmNotWorkingMessage, + if (streamingModeActive) { + stream.onLlmNotWorking({ + dataStreamer, + notWorkingMessage: llmNotWorkingMessage, }); } diff --git a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts index c5618c9d2..90d005c1f 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts @@ -1,5 +1,8 @@ import { ObjectId } from "mongodb-rag-core/mongodb"; -import { makeVerifiedAnswerGenerateResponse } from "./makeVerifiedAnswerGenerateResponse"; +import { + makeVerifiedAnswerGenerateResponse, + type StreamFunction, +} from "./makeVerifiedAnswerGenerateResponse"; import { VerifiedAnswer, WithScore, DataStreamer } from "mongodb-rag-core"; import { GenerateResponseReturnValue } from "./GenerateResponse"; @@ -24,6 +27,29 @@ describe("makeVerifiedAnswerGenerateResponse", () => { }, ] satisfies GenerateResponseReturnValue["messages"]; + const streamVerifiedAnswer: StreamFunction<{ + verifiedAnswer: VerifiedAnswer; + }> = async ({ dataStreamer, verifiedAnswer }) => { + dataStreamer.streamData({ + type: "metadata", + data: { + verifiedAnswer: { + _id: verifiedAnswer._id, + created: verifiedAnswer.created, + updated: verifiedAnswer.updated, + }, + }, + }); + dataStreamer.streamData({ + type: "delta", + data: verifiedAnswer.answer, + }); + dataStreamer.streamData({ + type: "references", + data: verifiedAnswer.references, + }); + }; + // Create a mock verified answer const createMockVerifiedAnswer = (): WithScore => ({ answer: verifiedAnswerContent, @@ -55,6 +81,7 @@ describe("makeVerifiedAnswerGenerateResponse", () => { connect: jest.fn(), disconnect: jest.fn(), stream: jest.fn(), + streamResponses: jest.fn(), }); // Create base request parameters @@ -79,6 +106,9 @@ describe("makeVerifiedAnswerGenerateResponse", () => { onNoVerifiedAnswerFound: async () => ({ messages: noVerifiedAnswerFoundMessages, }), + stream: { + onVerifiedAnswerFound: streamVerifiedAnswer, + }, }); it("uses onNoVerifiedAnswerFound if no verified answer is found", async () => { diff --git a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts index 01d3be4f6..d8df30147 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts @@ -1,4 +1,8 @@ -import { VerifiedAnswer, FindVerifiedAnswerFunc } from "mongodb-rag-core"; +import { + VerifiedAnswer, + FindVerifiedAnswerFunc, + DataStreamer, +} from "mongodb-rag-core"; import { strict as assert } from "assert"; import { GenerateResponse, @@ -17,8 +21,40 @@ export interface MakeVerifiedAnswerGenerateResponseParams { onVerifiedAnswerFound?: (verifiedAnswer: VerifiedAnswer) => VerifiedAnswer; onNoVerifiedAnswerFound: GenerateResponse; + + stream?: { + onVerifiedAnswerFound: StreamFunction<{ verifiedAnswer: VerifiedAnswer }>; + }; } +export type StreamFunction = ( + params: { dataStreamer: DataStreamer } & Params +) => void; + +export const addMessageToConversationVerifiedAnswerStream: MakeVerifiedAnswerGenerateResponseParams["stream"] = + { + onVerifiedAnswerFound: ({ verifiedAnswer, dataStreamer }) => { + dataStreamer.streamData({ + type: "metadata", + data: { + verifiedAnswer: { + _id: verifiedAnswer._id, + created: verifiedAnswer.created, + updated: verifiedAnswer.updated, + }, + }, + }); + dataStreamer.streamData({ + type: "delta", + data: verifiedAnswer.answer, + }); + dataStreamer.streamData({ + type: "references", + data: verifiedAnswer.references, + }); + }, + }; + /** Searches for verified answers for the user query. If no verified answer can be found for the given query, the @@ -28,6 +64,7 @@ export const makeVerifiedAnswerGenerateResponse = ({ findVerifiedAnswer, onVerifiedAnswerFound, onNoVerifiedAnswerFound, + stream, }: MakeVerifiedAnswerGenerateResponseParams): GenerateResponse => { return async (args) => { const { latestMessageText, shouldStream, dataStreamer } = args; @@ -54,17 +91,11 @@ export const makeVerifiedAnswerGenerateResponse = ({ if (shouldStream) { assert(dataStreamer, "Must have dataStreamer if shouldStream=true"); - dataStreamer.streamData({ - type: "metadata", - data: metadata, - }); - dataStreamer.streamData({ - type: "delta", - data: answer, - }); - dataStreamer.streamData({ - type: "references", - data: references, + assert(stream, "Must have stream if shouldStream=true"); + + stream.onVerifiedAnswerFound({ + dataStreamer, + verifiedAnswer, }); } diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts index 6d00cfc21..4690a6225 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts @@ -1,154 +1,178 @@ import "dotenv/config"; -import request from "supertest"; -import type { Express } from "express"; -import type { Conversation, SomeMessage } from "mongodb-rag-core"; -import { DEFAULT_API_PREFIX, type AppConfig } from "../../app"; -import { makeTestApp } from "../../test/testHelpers"; -import { basicResponsesRequestBody } from "../../test/testConfig"; -import { ERROR_TYPE, ERROR_CODE } from "./errors"; +import type { Server } from "http"; +import { ObjectId } from "mongodb"; +import type { + Conversation, + ConversationsService, + SomeMessage, +} from "mongodb-rag-core"; +import { type AppConfig } from "../../app"; +import { + makeTestLocalServer, + makeOpenAiClient, + makeCreateResponseRequestStream, + type Stream, +} from "../../test/testHelpers"; +import { makeDefaultConfig } from "../../test/testConfig"; import { ERR_MSG, type CreateResponseRequest } from "./createResponse"; +import { ERROR_CODE, ERROR_TYPE } from "./errors"; jest.setTimeout(100000); describe("POST /responses", () => { - const endpointUrl = `${DEFAULT_API_PREFIX}/responses`; - let app: Express; let appConfig: AppConfig; + let server: Server; let ipAddress: string; let origin: string; + let conversations: ConversationsService; beforeEach(async () => { - ({ app, ipAddress, origin, appConfig } = await makeTestApp()); + appConfig = await makeDefaultConfig(); + + ({ conversations } = appConfig.responsesRouterConfig.createResponse); + + // use a unique port so this doesn't collide with other test suites + const testPort = 5200; + ({ server, ipAddress, origin } = await makeTestLocalServer( + appConfig, + testPort + )); }); - afterEach(() => { + afterEach(async () => { + server?.listening && server?.close(); jest.restoreAllMocks(); }); - const makeCreateResponseRequest = ( - body?: Partial, - appOverride?: Express + const makeClientAndRequest = ( + body?: Partial ) => { - // TODO: update this to use the openai client - return request(appOverride ?? app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ ...basicResponsesRequestBody, ...body }); + const openAiClient = makeOpenAiClient(origin, ipAddress); + return makeCreateResponseRequestStream(openAiClient, body); }; describe("Valid requests", () => { - it("Should return 200 given a string input", async () => { - const response = await makeCreateResponseRequest(); + it("Should return responses given a string input", async () => { + const stream = await makeClientAndRequest(); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody: {}, stream }); }); - it("Should return 200 given a message array input", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a message array input", async () => { + const requestBody: Partial = { input: [ { role: "system", content: "You are a helpful assistant." }, { role: "user", content: "What is MongoDB?" }, { role: "assistant", content: "MongoDB is a document database." }, { role: "user", content: "What is a document database?" }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 given a valid request with instructions", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a valid request with instructions", async () => { + const requestBody: Partial = { instructions: "You are a helpful chatbot.", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with valid max_output_tokens", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with valid max_output_tokens", async () => { + const requestBody: Partial = { max_output_tokens: 4000, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with valid metadata", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with valid metadata", async () => { + const requestBody: Partial = { metadata: { key1: "value1", key2: "value2" }, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with valid temperature", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with valid temperature", async () => { + const requestBody: Partial = { temperature: 0, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with previous_response_id", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - initialMessages: [{ role: "user", content: "What is MongoDB?" }], - }); + it("Should return responses with previous_response_id", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + ]; + const { messages } = await conversations.create({ initialMessages }); - const previousResponseId = conversation.messages[0].id; - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId.toString(), - }); + const previous_response_id = messages.at(-1)?.id.toString(); + const requestBody: Partial = { + previous_response_id, + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 if previous_response_id is the latest message", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - initialMessages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a document database." }, - { role: "user", content: "What is a document database?" }, - ], - }); + it("Should return responses if previous_response_id is the latest message", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + { role: "assistant", content: "Initial response!" }, + { role: "user", content: "Another message!" }, + ]; + const { messages } = await conversations.create({ initialMessages }); - const previousResponseId = conversation.messages[2].id; - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId.toString(), - }); + const previous_response_id = messages.at(-1)?.id.toString(); + const requestBody: Partial = { + previous_response_id, + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with user", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with user", async () => { + const requestBody: Partial = { user: "some-user-id", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with store=false", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with store=false", async () => { + const requestBody: Partial = { store: false, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with store=true", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with store=true", async () => { + const requestBody: Partial = { store: true, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with tools and tool_choice", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with tools and tool_choice", async () => { + const requestBody: Partial = { tools: [ { + type: "function", + strict: true, name: "test-tool", description: "A tool for testing.", parameters: { @@ -161,15 +185,18 @@ describe("POST /responses", () => { }, ], tool_choice: "auto", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with a specific function tool_choice", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with a specific function tool_choice", async () => { + const requestBody: Partial = { tools: [ { + type: "function", + strict: true, name: "test-tool", description: "A tool for testing.", parameters: { @@ -185,30 +212,32 @@ describe("POST /responses", () => { type: "function", name: "test-tool", }, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 given a message array with function_call", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a message array with function_call", async () => { + const requestBody: Partial = { input: [ { role: "user", content: "What is MongoDB?" }, { type: "function_call", - id: "call123", + call_id: "call123", name: "my_function", arguments: `{"query": "value"}`, status: "in_progress", }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 given a message array with function_call_output", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a message array with function_call_output", async () => { + const requestBody: Partial = { input: [ { role: "user", content: "What is MongoDB?" }, { @@ -218,103 +247,91 @@ describe("POST /responses", () => { status: "completed", }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with tool_choice 'none'", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with a valid tool_choice", async () => { + const requestBody: Partial = { tool_choice: "none", - }); - - expect(response.statusCode).toBe(200); - }); - - it("Should return 200 with tool_choice 'only'", async () => { - const response = await makeCreateResponseRequest({ - tool_choice: "only", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with an empty tools array", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with an empty tools array", async () => { + const requestBody: Partial = { tools: [], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); it("Should store conversation messages if `storeMessageContent: undefined` and `store: true`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const storeMessageContent = undefined; - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - storeMessageContent, - initialMessages: [{ role: "user", content: "What is MongoDB?" }], - }); + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + ]; + const { _id, messages } = await conversations.create({ + storeMessageContent, + initialMessages, + }); const store = true; - const previousResponseId = conversation.messages[0].id.toString(); - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId, + const previous_response_id = messages.at(-1)?.id.toString(); + const requestBody: Partial = { + previous_response_id, store, - }); + }; + const stream = await makeClientAndRequest(requestBody); + + const updatedConversation = await conversations.findById({ _id }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + await expectValidResponses({ requestBody, stream }); - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual( + expect(updatedConversation?.storeMessageContent).toEqual( storeMessageContent ); - testDefaultMessageContent({ - createdConversation, - addedMessages, + expectDefaultMessageContent({ + initialMessages, + updatedConversation, store, }); }); it("Should store conversation messages when `store: true`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const store = true; const userId = "customUserId"; const metadata = { customMessage1: "customMessage1", customMessage2: "customMessage2", }; - const response = await makeCreateResponseRequest({ + const requestBody: Partial = { store, metadata, user: userId, - }); + }; + const stream = await makeClientAndRequest(requestBody); - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + const results = await expectValidResponses({ requestBody, stream }); + + const updatedConversation = await conversations.findByMessageId({ + messageId: getMessageIdFromResults(results), + }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual(store); - testDefaultMessageContent({ - createdConversation, - addedMessages, + expect(updatedConversation.storeMessageContent).toEqual(store); + expectDefaultMessageContent({ + updatedConversation, userId, store, metadata, @@ -322,35 +339,31 @@ describe("POST /responses", () => { }); it("Should not store conversation messages when `store: false`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const store = false; const userId = "customUserId"; const metadata = { customMessage1: "customMessage1", customMessage2: "customMessage2", }; - const response = await makeCreateResponseRequest({ + const requestBody: Partial = { store, metadata, user: userId, - }); + }; + const stream = await makeClientAndRequest(requestBody); - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + const results = await expectValidResponses({ requestBody, stream }); - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual(store); - testDefaultMessageContent({ - createdConversation, - addedMessages, + const updatedConversation = await conversations.findByMessageId({ + messageId: getMessageIdFromResults(results), + }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } + + expect(updatedConversation.storeMessageContent).toEqual(store); + expectDefaultMessageContent({ + updatedConversation, userId, store, metadata, @@ -358,24 +371,15 @@ describe("POST /responses", () => { }); it("Should store function_call messages when `store: true`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const store = true; const functionCallType = "function_call"; const functionCallOutputType = "function_call_output"; - const response = await makeCreateResponseRequest({ + const requestBody: Partial = { store, input: [ { type: functionCallType, - id: "call123", + call_id: "call123", name: "my_function", arguments: `{"query": "value"}`, status: "in_progress", @@ -387,139 +391,138 @@ describe("POST /responses", () => { status: "completed", }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + const results = await expectValidResponses({ requestBody, stream }); - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual(store); + const updatedConversation = await conversations.findByMessageId({ + messageId: getMessageIdFromResults(results), + }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } + + expect(updatedConversation.storeMessageContent).toEqual(store); - expect(addedMessages[0].role).toEqual("system"); - expect(addedMessages[1].role).toEqual("system"); + expect(updatedConversation.messages[0].role).toEqual("system"); + expect(updatedConversation.messages[0].content).toEqual(functionCallType); - expect(addedMessages[0].content).toEqual(functionCallType); - expect(addedMessages[1].content).toEqual(functionCallOutputType); + expect(updatedConversation.messages[1].role).toEqual("system"); + expect(updatedConversation.messages[1].content).toEqual( + functionCallOutputType + ); }); }); describe("Invalid requests", () => { - it("Should return 400 with an empty input string", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if empty input string", async () => { + const stream = await makeClientAndRequest({ input: "", }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.input - ${ERR_MSG.INPUT_STRING}`) - ); + await expectInvalidResponses({ + stream, + message: `Path: body.input - ${ERR_MSG.INPUT_STRING}`, + }); }); - it("Should return 400 with an empty message array", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if empty message array", async () => { + const stream = await makeClientAndRequest({ input: [], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.input - ${ERR_MSG.INPUT_ARRAY}`) - ); - }); - - it("Should return 400 if model is not mongodb-chat-latest", async () => { - const response = await makeCreateResponseRequest({ - model: "gpt-4o-mini", + await expectInvalidResponses({ + stream, + message: `Path: body.input - ${ERR_MSG.INPUT_ARRAY}`, }); - - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.MODEL_NOT_SUPPORTED("gpt-4o-mini")) - ); }); - it("Should return 400 if stream is not true", async () => { - const response = await makeCreateResponseRequest({ - stream: false, + it("Should return error responses if model is not supported via config", async () => { + const invalidModel = "invalid-model"; + const stream = await makeClientAndRequest({ + model: invalidModel, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.stream - ${ERR_MSG.STREAM}`) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MODEL_NOT_SUPPORTED(invalidModel), + }); }); - it("Should return 400 if max_output_tokens is > allowed limit", async () => { + it("Should return error responses if max_output_tokens is > allowed limit", async () => { const max_output_tokens = 4001; - - const response = await makeCreateResponseRequest({ + const stream = await makeClientAndRequest({ max_output_tokens, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.MAX_OUTPUT_TOKENS(max_output_tokens, 4000)) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MAX_OUTPUT_TOKENS(max_output_tokens, 4000), + }); }); - it("Should return 400 if metadata has too many fields", async () => { + it("Should return error responses if metadata has too many fields", async () => { const metadata: Record = {}; for (let i = 0; i < 17; i++) { metadata[`key${i}`] = "value"; } - const response = await makeCreateResponseRequest({ + const stream = await makeClientAndRequest({ metadata, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.metadata - ${ERR_MSG.METADATA_LENGTH}`) - ); + await expectInvalidResponses({ + stream, + message: `Path: body.metadata - ${ERR_MSG.METADATA_LENGTH}`, + }); }); - it("Should return 400 if metadata value is too long", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if metadata value is too long", async () => { + const stream = await makeClientAndRequest({ metadata: { key1: "a".repeat(513) }, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - "Path: body.metadata.key1 - String must contain at most 512 character(s)" - ) - ); + await expectInvalidResponses({ + stream, + message: + "Path: body.metadata.key1 - String must contain at most 512 character(s)", + }); }); - it("Should return 400 if temperature is not 0", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if temperature is not 0", async () => { + const stream = await makeClientAndRequest({ temperature: 0.5 as any, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.temperature - ${ERR_MSG.TEMPERATURE}`) - ); + await expectInvalidResponses({ + stream, + message: `Path: body.temperature - ${ERR_MSG.TEMPERATURE}`, + }); }); - it("Should return 400 if messages contain an invalid role", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if messages contain an invalid role", async () => { + const stream = await makeClientAndRequest({ input: [ { role: "user", content: "What is MongoDB?" }, - { role: "invalid-role" as any, content: "This is an invalid role." }, + { + role: "invalid-role" as any, + content: "This is an invalid role.", + }, ], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.input - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.input - Invalid input", + }); }); - it("Should return 400 if function_call has an invalid status", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if function_call has an invalid status", async () => { + const stream = await makeClientAndRequest({ input: [ { type: "function_call", - id: "call123", + call_id: "call123", name: "my_function", arguments: `{"query": "value"}`, status: "invalid_status" as any, @@ -527,14 +530,14 @@ describe("POST /responses", () => { ], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.input - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.input - Invalid input", + }); }); - it("Should return 400 if function_call_output has an invalid status", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if function_call_output has an invalid status", async () => { + const stream = await makeClientAndRequest({ input: [ { type: "function_call_output", @@ -545,196 +548,328 @@ describe("POST /responses", () => { ], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.input - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.input - Invalid input", + }); }); - it("Should return 400 with an invalid tool_choice string", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses with an invalid tool_choice string", async () => { + const stream = await makeClientAndRequest({ tool_choice: "invalid_choice" as any, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.tool_choice - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.tool_choice - Invalid input", + }); }); - it("Should return 400 if max_output_tokens is negative", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if max_output_tokens is negative", async () => { + const stream = await makeClientAndRequest({ max_output_tokens: -1, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - "Path: body.max_output_tokens - Number must be greater than or equal to 0" - ) - ); + await expectInvalidResponses({ + stream, + message: + "Path: body.max_output_tokens - Number must be greater than or equal to 0", + }); }); - it("Should return 400 if previous_response_id is not a valid ObjectId", async () => { - const messageId = "some-id"; - - const response = await makeCreateResponseRequest({ - previous_response_id: messageId, + it("Should return error responses if previous_response_id is not a valid ObjectId", async () => { + const previous_response_id = "some-id"; + const stream = await makeClientAndRequest({ + previous_response_id, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.INVALID_OBJECT_ID(messageId)) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.INVALID_OBJECT_ID(previous_response_id), + }); }); - it("Should return 400 if previous_response_id is not found", async () => { - const messageId = "123456789012123456789012"; - - const response = await makeCreateResponseRequest({ - previous_response_id: messageId, + it("Should return error responses if previous_response_id is not found", async () => { + const previous_response_id = "123456789012123456789012"; + const stream = await makeClientAndRequest({ + previous_response_id, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.MESSAGE_NOT_FOUND(messageId)) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MESSAGE_NOT_FOUND(previous_response_id), + }); }); - it("Should return 400 if previous_response_id is not the latest message", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - initialMessages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a document database." }, - { role: "user", content: "What is a document database?" }, - ], - }); + it("Should return error responses if previous_response_id is not the latest message", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + { role: "assistant", content: "Initial response!" }, + { role: "user", content: "Another message!" }, + ]; + const { messages } = await conversations.create({ initialMessages }); - const previousResponseId = conversation.messages[0].id; - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId.toString(), + const previous_response_id = messages[0].id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - ERR_MSG.MESSAGE_NOT_LATEST(previousResponseId.toString()) - ) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MESSAGE_NOT_LATEST(previous_response_id), + }); }); - it("Should return 400 if there are too many messages in the conversation", async () => { - const maxUserMessagesInConversation = 0; - const newApp = await makeTestApp({ - responsesRouterConfig: { - ...appConfig.responsesRouterConfig, - createResponse: { - ...appConfig.responsesRouterConfig.createResponse, - maxUserMessagesInConversation, - }, - }, + it("Should return error responses if there are too many messages in the conversation", async () => { + const { maxUserMessagesInConversation } = + appConfig.responsesRouterConfig.createResponse; + + const initialMessages = Array(maxUserMessagesInConversation).fill({ + role: "user", + content: "Initial message!", }); + const { messages } = await conversations.create({ initialMessages }); - const response = await makeCreateResponseRequest({}, newApp.app); + const previous_response_id = messages.at(-1)?.id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, + }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - ERR_MSG.TOO_MANY_MESSAGES(maxUserMessagesInConversation) - ) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.TOO_MANY_MESSAGES(maxUserMessagesInConversation), + }); }); - }); - it("Should return 400 if user id has changed since the conversation was created", async () => { - const userId1 = "user1"; - const userId2 = "user2"; - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - userId: userId1, - initialMessages: [{ role: "user", content: "What is MongoDB?" }], + it("Should return error responses if user id has changed since the conversation was created", async () => { + const userId = "user1"; + const badUserId = "user2"; + + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + ]; + const { messages } = await conversations.create({ + userId, + initialMessages, + }); + + const previous_response_id = messages.at(-1)?.id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, + user: badUserId, }); - const previousResponseId = conversation.messages[0].id.toString(); - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId, - user: userId2, + await expectInvalidResponses({ + stream, + message: ERR_MSG.CONVERSATION_USER_ID_CHANGED, + }); }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.CONVERSATION_USER_ID_CHANGED) - ); - }); + it("Should return error responses if `store: false` and `previous_response_id` is provided", async () => { + const stream = await makeClientAndRequest({ + previous_response_id: "123456789012123456789012", + store: false, + }); - it("Should return 400 if `store: false` and `previous_response_id` is provided", async () => { - const response = await makeCreateResponseRequest({ - previous_response_id: "123456789012123456789012", - store: false, + await expectInvalidResponses({ + stream, + message: ERR_MSG.STORE_NOT_SUPPORTED, + }); }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.STORE_NOT_SUPPORTED) - ); - }); - - it("Should return 400 if `store: true` and `storeMessageContent: false`", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ + it("Should return error responses if `store: true` and `storeMessageContent: false`", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + ]; + const { messages } = await conversations.create({ storeMessageContent: false, - initialMessages: [{ role: "user", content: "" }], + initialMessages, }); - const previousResponseId = conversation.messages[0].id.toString(); - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId, - store: true, - }); + const previous_response_id = messages.at(-1)?.id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, + store: true, + }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.CONVERSATION_STORE_MISMATCH) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.CONVERSATION_STORE_MISMATCH, + }); + }); }); }); // --- HELPERS --- -const badRequestError = (message: string) => ({ - type: ERROR_TYPE, - code: ERROR_CODE.INVALID_REQUEST_ERROR, +const getMessageIdFromResults = (results?: Array) => { + if (!results?.length) throw new Error("No results found"); + + const messageId = results.at(-1)?.response?.id; + + if (typeof messageId !== "string") throw new Error("Message ID not found"); + + return new ObjectId(messageId); +}; + +interface ExpectInvalidResponsesParams { + stream: Stream; + message: string; +} + +const expectInvalidResponses = async ({ + stream, message, -}); +}: ExpectInvalidResponsesParams) => { + const responses: any[] = []; + try { + for await (const event of stream) { + responses.push(event); + } + + fail("expected error"); + } catch (err: any) { + expect(err.type).toBe(ERROR_TYPE); + expect(err.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); + expect(err.error.type).toBe(ERROR_TYPE); + expect(err.error.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); + expect(err.error.message).toBe(message); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(0); +}; -interface TestDefaultMessageContentParams { - createdConversation: Conversation; - addedMessages: SomeMessage[]; +interface ExpectValidResponsesParams { + stream: Stream; + requestBody: Partial; +} + +const expectValidResponses = async ({ + stream, + requestBody, +}: ExpectValidResponsesParams) => { + const responses: any[] = []; + for await (const event of stream) { + responses.push(event); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(3); + + expect(responses[0].type).toBe("response.created"); + expect(responses[1].type).toBe("response.in_progress"); + expect(responses[2].type).toBe("response.completed"); + + responses.forEach(({ response, sequence_number }, index) => { + // basic response properties + expect(sequence_number).toBe(index); + expect(typeof response.id).toBe("string"); + expect(typeof response.created_at).toBe("number"); + expect(response.object).toBe("response"); + expect(response.error).toBeNull(); + expect(response.incomplete_details).toBeNull(); + expect(response.model).toBe("mongodb-chat-latest"); + expect(response.output_text).toBe(""); + expect(response.output).toEqual([]); + expect(response.parallel_tool_calls).toBe(true); + expect(response.temperature).toBe(0); + expect(response.stream).toBe(true); + expect(response.top_p).toBeNull(); + + // conditional upon request body properties + if (requestBody.instructions) { + expect(response.instructions).toBe(requestBody.instructions); + } else { + expect(response.instructions).toBeNull(); + } + if (requestBody.max_output_tokens) { + expect(response.max_output_tokens).toBe(requestBody.max_output_tokens); + } else { + expect(response.max_output_tokens).toBe(1000); + } + if (requestBody.previous_response_id) { + expect(response.previous_response_id).toBe( + requestBody.previous_response_id + ); + } else { + expect(response.previous_response_id).toBeNull(); + } + if (typeof requestBody.store === "boolean") { + expect(response.store).toBe(requestBody.store); + } else { + expect(response.store).toBe(true); + } + if (requestBody.tool_choice) { + expect(response.tool_choice).toEqual(requestBody.tool_choice); + } else { + expect(response.tool_choice).toBe("auto"); + } + if (requestBody.tools) { + expect(response.tools).toEqual(requestBody.tools); + } else { + expect(response.tools).toEqual([]); + } + if (requestBody.user) { + expect(response.user).toBe(requestBody.user); + } else { + expect(response.user).toBeUndefined(); + } + if (requestBody.metadata) { + expect(response.metadata).toEqual(requestBody.metadata); + } else { + expect(response.metadata).toBeNull(); + } + }); + + return responses; +}; + +interface ExpectDefaultMessageContentParams { + initialMessages?: Array; + updatedConversation: Conversation; store: boolean; userId?: string; - metadata?: Record; + metadata?: Record | null; } -const testDefaultMessageContent = ({ - createdConversation, - addedMessages, +const expectDefaultMessageContent = ({ + initialMessages, + updatedConversation, store, userId, - metadata, -}: TestDefaultMessageContentParams) => { - expect(createdConversation.userId).toEqual(userId); - - expect(addedMessages[0].role).toBe("user"); - expect(addedMessages[1].role).toEqual("user"); - expect(addedMessages[2].role).toEqual("assistant"); - - expect(addedMessages[0].content).toBe(store ? "What is MongoDB?" : ""); - expect(addedMessages[1].content).toBeFalsy(); - expect(addedMessages[2].content).toEqual(store ? "some content" : ""); - - expect(addedMessages[0].metadata).toEqual(metadata); - expect(addedMessages[1].metadata).toEqual(metadata); - expect(addedMessages[2].metadata).toEqual(metadata); - if (metadata) expect(createdConversation.customData).toEqual({ metadata }); + metadata = null, +}: ExpectDefaultMessageContentParams) => { + expect(updatedConversation.userId).toEqual(userId); + if (metadata) expect(updatedConversation.customData).toEqual({ metadata }); + + const defaultMessagesLength = 3; + const initialMessagesLength = initialMessages?.length ?? 0; + const totalMessagesLength = defaultMessagesLength + initialMessagesLength; + + const { messages } = updatedConversation; + expect(messages.length).toEqual(totalMessagesLength); + + initialMessages?.forEach((initialMessage, index) => { + expect(messages[index].role).toEqual(initialMessage.role); + expect(messages[index].content).toEqual(initialMessage.content); + expect(messages[index].metadata).toEqual(initialMessage.metadata); + expect(messages[index].customData).toEqual(initialMessage.customData); + }); + + const firstMessage = messages[initialMessagesLength]; + const secondMessage = messages[initialMessagesLength + 1]; + const thirdMessage = messages[initialMessagesLength + 2]; + + expect(firstMessage.role).toBe("user"); + expect(firstMessage.content).toBe(store ? "What is MongoDB?" : ""); + expect(firstMessage.metadata).toEqual(metadata); + + expect(secondMessage.role).toEqual("user"); + expect(secondMessage.content).toBeFalsy(); + expect(secondMessage.metadata).toEqual(metadata); + + expect(thirdMessage.role).toEqual("assistant"); + expect(thirdMessage.content).toEqual(store ? "some content" : ""); + expect(thirdMessage.metadata).toEqual(metadata); }; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts index aa0fec8c5..7d2654c43 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -4,11 +4,11 @@ import type { Response as ExpressResponse, } from "express"; import { ObjectId } from "mongodb"; -import type { APIError } from "mongodb-rag-core/openai"; -import type { - ConversationsService, - Conversation, - SomeMessage, +import type { OpenAI } from "mongodb-rag-core/openai"; +import { + type ConversationsService, + type Conversation, + makeDataStreamer, } from "mongodb-rag-core"; import { SomeExpressRequest } from "../../middleware"; import { getRequestId } from "../../utils"; @@ -19,8 +19,22 @@ import { generateZodErrorMessage, sendErrorResponse, ERROR_TYPE, + type SomeOpenAIAPIError, } from "./errors"; +type StreamCreatedMessage = Omit< + OpenAI.Responses.ResponseCreatedEvent, + "sequence_number" +>; +type StreamInProgressMessage = Omit< + OpenAI.Responses.ResponseInProgressEvent, + "sequence_number" +>; +type StreamCompletedMessage = Omit< + OpenAI.Responses.ResponseCompletedEvent, + "sequence_number" +>; + export const ERR_MSG = { INPUT_STRING: "Input must be a non-empty string", INPUT_ARRAY: @@ -64,10 +78,7 @@ const CreateResponseRequestBodySchema = z.object({ // function tool call z.object({ type: z.literal("function_call"), - id: z - .string() - .optional() - .describe("Unique ID of the function tool call"), + call_id: z.string().describe("Unique ID of the function tool call"), name: z.string().describe("Name of the function tool to call"), arguments: z .string() @@ -123,11 +134,11 @@ const CreateResponseRequestBodySchema = z.object({ .default(0), tool_choice: z .union([ - z.enum(["none", "only", "auto"]), + z.enum(["none", "auto", "required"]), z .object({ - name: z.string(), type: z.literal("function"), + name: z.string(), }) .describe("Function tool choice"), ]) @@ -137,6 +148,8 @@ const CreateResponseRequestBodySchema = z.object({ tools: z .array( z.object({ + type: z.literal("function"), + strict: z.boolean(), name: z.string(), description: z.string().optional(), parameters: z @@ -184,8 +197,11 @@ export function makeCreateResponseRoute({ ) => { const reqId = getRequestId(req); const headers = req.headers as Record; + const dataStreamer = makeDataStreamer(); try { + dataStreamer.connect(res); + // --- INPUT VALIDATION --- const { error, data } = CreateResponseRequestSchema.safeParse(req); if (error) { @@ -266,6 +282,31 @@ export function makeCreateResponseRoute({ }); } + // generate responseId to use in conversation DB AND Responses API stream + const responseId = new ObjectId(); + const baseResponse = makeBaseResponseData({ + responseId, + data: data.body, + }); + + const createdMessage: StreamCreatedMessage = { + type: "response.created", + response: { + ...baseResponse, + created_at: Date.now(), + }, + }; + dataStreamer.streamResponses(createdMessage); + + const inProgressMessage: StreamInProgressMessage = { + type: "response.in_progress", + response: { + ...baseResponse, + created_at: Date.now(), + }, + }; + dataStreamer.streamResponses(inProgressMessage); + // TODO: actually implement this call const { messages } = await generateResponse({} as any); @@ -277,20 +318,39 @@ export function makeCreateResponseRoute({ metadata, input, messages, + responseId, }); - return res.status(200).send({ status: "ok" }); + const completedMessage: StreamCompletedMessage = { + type: "response.completed", + response: { + ...baseResponse, + created_at: Date.now(), + }, + }; + dataStreamer.streamResponses(completedMessage); } catch (error) { const standardError = - (error as APIError)?.type === ERROR_TYPE - ? (error as APIError) + (error as SomeOpenAIAPIError)?.type === ERROR_TYPE + ? (error as SomeOpenAIAPIError) : makeInternalServerError({ error: error as Error, headers }); - sendErrorResponse({ - res, - reqId, - error: standardError, - }); + if (dataStreamer.connected) { + dataStreamer.streamResponses({ + ...standardError, + type: ERROR_TYPE, + }); + } else { + sendErrorResponse({ + res, + reqId, + error: standardError, + }); + } + } finally { + if (dataStreamer.connected) { + dataStreamer.disconnect(); + } } }; } @@ -385,13 +445,18 @@ const hasConversationUserIdChanged = ( return conversation.userId !== userId; }; +type MessagesParam = Parameters< + ConversationsService["addManyConversationMessages"] +>[0]["messages"]; + interface AddMessagesToConversationParams { conversations: ConversationsService; conversation: Conversation; store: boolean; metadata?: Record; input: CreateResponseRequest["body"]["input"]; - messages: Array; + messages: MessagesParam; + responseId: ObjectId; } const saveMessagesToConversation = async ({ @@ -401,13 +466,19 @@ const saveMessagesToConversation = async ({ metadata, input, messages, + responseId, }: AddMessagesToConversationParams) => { const messagesToAdd = [ ...convertInputToDBMessages(input, store, metadata), ...messages.map((message) => formatMessage(message, store, metadata)), ]; + // handle setting the response id for the last message + // this corresponds to the response id in the response stream + if (messagesToAdd.length > 0) { + messagesToAdd[messagesToAdd.length - 1].id = responseId; + } - await conversations.addManyConversationMessages({ + return await conversations.addManyConversationMessages({ conversationId: conversation._id, messages: messagesToAdd, }); @@ -417,7 +488,7 @@ const convertInputToDBMessages = ( input: CreateResponseRequest["body"]["input"], store: boolean, metadata?: Record -): Array => { +): MessagesParam => { if (typeof input === "string") { return [formatMessage({ role: "user", content: input }, store, metadata)]; } @@ -433,14 +504,52 @@ const convertInputToDBMessages = ( }; const formatMessage = ( - message: SomeMessage, + message: MessagesParam[number], store: boolean, metadata?: Record -): SomeMessage => { +): MessagesParam[number] => { + // store a placeholder string if we're not storing message data + const content = store ? message.content : ""; + // handle cleaning custom data if we're not storing message data + const customData = { + ...message.customData, + query: store ? message.customData?.query : "", + reason: store ? message.customData?.reason : "", + }; + return { ...message, - // store a placeholder string if we're not storing message data - content: store ? message.content : "", + content, metadata, + customData, + }; +}; + +interface BaseResponseData { + responseId: ObjectId; + data: CreateResponseRequest["body"]; +} + +const makeBaseResponseData = ({ responseId, data }: BaseResponseData) => { + return { + id: responseId.toString(), + object: "response" as const, + error: null, + incomplete_details: null, + instructions: data.instructions ?? null, + max_output_tokens: data.max_output_tokens ?? null, + model: data.model, + output_text: "", + output: [], + parallel_tool_calls: true, + previous_response_id: data.previous_response_id ?? null, + store: data.store, + temperature: data.temperature, + stream: data.stream, + tool_choice: data.tool_choice, + tools: data.tools ?? [], + top_p: null, + user: data.user, + metadata: data.metadata ?? null, }; }; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/errors.ts b/packages/mongodb-chatbot-server/src/routes/responses/errors.ts index f5b6822e9..e4fd783c4 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/errors.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/errors.ts @@ -1,5 +1,5 @@ import { - APIError, + type APIError, BadRequestError, InternalServerError, NotFoundError, @@ -43,6 +43,13 @@ export enum ERROR_CODE { } // --- OPENAI ERROR WRAPPERS --- +export type SomeOpenAIAPIError = + | APIError + | BadRequestError + | NotFoundError + | RateLimitError + | InternalServerError; + interface MakeOpenAIErrorParams { error: Error; headers: Record; @@ -51,7 +58,7 @@ interface MakeOpenAIErrorParams { export const makeInternalServerError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Internal server error"; const _error = { ...error, @@ -65,7 +72,7 @@ export const makeInternalServerError = ({ export const makeBadRequestError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Bad request"; const _error = { ...error, @@ -79,7 +86,7 @@ export const makeBadRequestError = ({ export const makeNotFoundError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Not found"; const _error = { ...error, @@ -93,7 +100,7 @@ export const makeNotFoundError = ({ export const makeRateLimitError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Rate limit exceeded"; const _error = { ...error, diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts index 9bfb9b29e..38c735e8b 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts @@ -1,131 +1,189 @@ -import type { Express } from "express"; -import request from "supertest"; -import { AppConfig } from "../../app"; -import { DEFAULT_API_PREFIX } from "../../app"; -import { makeTestApp } from "../../test/testHelpers"; -import { makeTestAppConfig } from "../../test/testHelpers"; -import { basicResponsesRequestBody } from "../../test/testConfig"; -import { ERROR_TYPE, ERROR_CODE, makeBadRequestError } from "./errors"; -import { CreateResponseRequest } from "./createResponse"; +import type { Server } from "http"; +import { + makeTestLocalServer, + makeOpenAiClient, + makeCreateResponseRequestStream, + type Stream, +} from "../../test/testHelpers"; +import { makeDefaultConfig } from "../../test/testConfig"; +import { + ERROR_CODE, + ERROR_TYPE, + makeBadRequestError, + type SomeOpenAIAPIError, +} from "./errors"; jest.setTimeout(60000); describe("Responses Router", () => { - const ipAddress = "127.0.0.1"; - const responsesEndpoint = DEFAULT_API_PREFIX + "/responses"; - let appConfig: AppConfig; - - beforeAll(async () => { - ({ appConfig } = await makeTestAppConfig()); + let server: Server; + let ipAddress: string; + let origin: string; + + afterEach(async () => { + if (server?.listening) { + await new Promise((resolve) => { + server.close(() => resolve()); + }); + } + jest.clearAllMocks(); }); - const makeCreateResponseRequest = ( - app: Express, - origin: string, - body?: Partial - ) => { - return request(app) - .post(responsesEndpoint) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ ...basicResponsesRequestBody, ...body }); - }; - - it("should return 200 given a valid request", async () => { - const { app, origin } = await makeTestApp(appConfig); + it("should return responses given a valid request", async () => { + ({ server, ipAddress, origin } = await makeTestLocalServer()); - const res = await makeCreateResponseRequest(app, origin); + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); - expect(res.status).toBe(200); + await expectValidResponses({ stream }); }); - it("should return 500 when handling an unknown error", async () => { + it("should return an OpenAI error when handling an unknown error", async () => { const errorMessage = "Unknown error"; - const { app, origin } = await makeTestApp({ - ...appConfig, - responsesRouterConfig: { - ...appConfig.responsesRouterConfig, - createResponse: { - ...appConfig.responsesRouterConfig.createResponse, - generateResponse: () => Promise.reject(new Error(errorMessage)), - }, - }, - }); - const res = await makeCreateResponseRequest(app, origin); + const appConfig = await makeDefaultConfig(); + appConfig.responsesRouterConfig.createResponse.generateResponse = () => { + throw new Error(errorMessage); + }; - expect(res.status).toBe(500); - expect(res.body.type).toBe(ERROR_TYPE); - expect(res.body.code).toBe(ERROR_CODE.SERVER_ERROR); - expect(res.body.error).toEqual({ - type: ERROR_TYPE, - code: ERROR_CODE.SERVER_ERROR, - message: errorMessage, - }); - }); + ({ server, ipAddress, origin } = await makeTestLocalServer(appConfig)); - it("should return the openai error when service throws an openai error", async () => { - const errorMessage = "Bad request input"; - const { app, origin } = await makeTestApp({ - ...appConfig, - responsesRouterConfig: { - ...appConfig.responsesRouterConfig, - createResponse: { - ...appConfig.responsesRouterConfig.createResponse, - generateResponse: () => - Promise.reject( - makeBadRequestError({ - error: new Error(errorMessage), - headers: {}, - }) - ), - }, + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); + + await expectInvalidResponses({ + stream, + error: { + type: ERROR_TYPE, + code: ERROR_CODE.SERVER_ERROR, + message: errorMessage, }, }); + }); - const res = await makeCreateResponseRequest(app, origin); + it("should return the OpenAI error when service throws an OpenAI error", async () => { + const errorMessage = "Bad request input"; - expect(res.status).toBe(400); - expect(res.body.type).toBe(ERROR_TYPE); - expect(res.body.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); - expect(res.body.error).toEqual({ - type: ERROR_TYPE, - code: ERROR_CODE.INVALID_REQUEST_ERROR, - message: errorMessage, + const appConfig = await makeDefaultConfig(); + appConfig.responsesRouterConfig.createResponse.generateResponse = () => + Promise.reject( + makeBadRequestError({ + error: new Error(errorMessage), + headers: {}, + }) + ); + + ({ server, ipAddress, origin } = await makeTestLocalServer(appConfig)); + + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); + + await expectInvalidResponses({ + stream, + error: { + type: ERROR_TYPE, + code: ERROR_CODE.INVALID_REQUEST_ERROR, + message: errorMessage, + }, }); }); - test("Should apply responses router rate limit and return an openai error", async () => { + it("Should return an OpenAI error when rate limit is hit", async () => { const rateLimitErrorMessage = "Error: rate limit exceeded!"; - const { app, origin } = await makeTestApp({ - responsesRouterConfig: { - rateLimitConfig: { - routerRateLimitConfig: { - windowMs: 50000, // Big window to cover test duration - max: 1, // Only one request should be allowed - message: rateLimitErrorMessage, - }, - }, + const appConfig = await makeDefaultConfig(); + appConfig.responsesRouterConfig.rateLimitConfig = { + routerRateLimitConfig: { + windowMs: 500000, // Big window to cover test duration + max: 1, // Only one request should be allowed + message: rateLimitErrorMessage, }, - }); + }; - const successRes = await makeCreateResponseRequest(app, origin); - const rateLimitedRes = await makeCreateResponseRequest(app, origin); + ({ server, ipAddress, origin } = await makeTestLocalServer(appConfig)); - expect(successRes.status).toBe(200); - expect(successRes.error).toBeFalsy(); + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); - expect(rateLimitedRes.status).toBe(429); - expect(rateLimitedRes.error).toBeTruthy(); - expect(rateLimitedRes.body.type).toBe(ERROR_TYPE); - expect(rateLimitedRes.body.code).toBe(ERROR_CODE.RATE_LIMIT_ERROR); - expect(rateLimitedRes.body.error).toEqual({ - type: ERROR_TYPE, - code: ERROR_CODE.RATE_LIMIT_ERROR, - message: rateLimitErrorMessage, - }); - expect(rateLimitedRes.body.headers["x-forwarded-for"]).toBe(ipAddress); - expect(rateLimitedRes.body.headers["origin"]).toBe(origin); + try { + await makeCreateResponseRequestStream(openAiClient); + + fail("expected rate limit error"); + } catch (error) { + expect((error as SomeOpenAIAPIError).status).toBe(429); + expect((error as SomeOpenAIAPIError).error).toEqual({ + type: ERROR_TYPE, + code: ERROR_CODE.RATE_LIMIT_ERROR, + message: rateLimitErrorMessage, + }); + } + + await expectValidResponses({ stream }); }); }); + +// --- HELPERS --- + +interface ExpectValidResponsesParams { + stream: Stream; +} + +const expectValidResponses = async ({ stream }: ExpectValidResponsesParams) => { + const responses: any[] = []; + for await (const event of stream) { + responses.push(event); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(3); + + expect(responses[0].type).toBe("response.created"); + expect(responses[1].type).toBe("response.in_progress"); + expect(responses[2].type).toBe("response.completed"); + + responses.forEach(({ sequence_number, response }, index) => { + expect(sequence_number).toBe(index); + expect(typeof response.id).toBe("string"); + expect(response.object).toBe("response"); + expect(response.error).toBeNull(); + expect(response.model).toBe("mongodb-chat-latest"); + }); +}; + +interface ExpectInvalidResponsesParams { + stream: Stream; + error: { + type: string; + code: string; + message: string; + }; +} + +const expectInvalidResponses = async ({ + stream, + error, +}: ExpectInvalidResponsesParams) => { + const responses: any[] = []; + try { + for await (const event of stream) { + responses.push(event); + } + + fail("expected error"); + } catch (err: any) { + expect(err.type).toBe(error.type); + expect(err.code).toBe(error.code); + expect(err.error.type).toBe(error.type); + expect(err.error.code).toBe(error.code); + expect(err.error.message).toBe(error.message); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(2); + + expect(responses[0].type).toBe("response.created"); + expect(responses[1].type).toBe("response.in_progress"); + + expect(responses[0].sequence_number).toBe(0); + expect(responses[1].sequence_number).toBe(1); +}; diff --git a/packages/mongodb-chatbot-server/src/test/testConfig.ts b/packages/mongodb-chatbot-server/src/test/testConfig.ts index 42064bcb8..b826cf63f 100644 --- a/packages/mongodb-chatbot-server/src/test/testConfig.ts +++ b/packages/mongodb-chatbot-server/src/test/testConfig.ts @@ -175,7 +175,6 @@ export const MONGO_CHAT_MODEL = "mongodb-chat-latest"; export const basicResponsesRequestBody = { model: MONGO_CHAT_MODEL, - stream: true, input: "What is MongoDB?", }; diff --git a/packages/mongodb-chatbot-server/src/test/testHelpers.ts b/packages/mongodb-chatbot-server/src/test/testHelpers.ts index 5f68b1aff..156a9af43 100644 --- a/packages/mongodb-chatbot-server/src/test/testHelpers.ts +++ b/packages/mongodb-chatbot-server/src/test/testHelpers.ts @@ -1,6 +1,13 @@ import { strict as assert } from "assert"; -import { AppConfig, makeApp } from "../app"; -import { makeDefaultConfig, memoryDb, systemPrompt } from "./testConfig"; +import { OpenAI } from "mongodb-rag-core/openai"; +import { AppConfig, DEFAULT_API_PREFIX, makeApp } from "../app"; +import { + makeDefaultConfig, + memoryDb, + systemPrompt, + basicResponsesRequestBody, +} from "./testConfig"; +import type { CreateResponseRequest } from "../routes/responses/createResponse"; export async function makeTestAppConfig( defaultConfigOverrides?: PartialAppConfig @@ -33,9 +40,11 @@ export type PartialAppConfig = Omit< > & { conversationsRouterConfig?: Partial; responsesRouterConfig?: Partial; + port?: number; }; -export const TEST_ORIGIN = "http://localhost:5173"; +export const TEST_PORT = 5173; +export const TEST_ORIGIN = `http://localhost:`; /** Helper function to quickly make an app for testing purposes. Can't be called @@ -45,7 +54,7 @@ export const TEST_ORIGIN = "http://localhost:5173"; export async function makeTestApp(defaultConfigOverrides?: PartialAppConfig) { // ip address for local host const ipAddress = "127.0.0.1"; - const origin = TEST_ORIGIN; + const origin = TEST_ORIGIN + (defaultConfigOverrides?.port ?? TEST_PORT); const { appConfig, systemPrompt, mongodb } = await makeTestAppConfig( defaultConfigOverrides @@ -63,6 +72,53 @@ export async function makeTestApp(defaultConfigOverrides?: PartialAppConfig) { }; } +export const TEST_OPENAI_API_KEY = "test-api-key"; + +/** + Helper function to quickly make a local server for testing purposes. + Builds on the other helpers for app/config stuff. + @param defaultConfigOverrides - optional overrides for default app config + */ +export const makeTestLocalServer = async ( + defaultConfigOverrides?: PartialAppConfig, + port?: number +) => { + const testAppResult = await makeTestApp({ + ...defaultConfigOverrides, + port, + }); + + const server = testAppResult.app.listen(port ?? TEST_PORT); + + return { ...testAppResult, server }; +}; + +export const makeOpenAiClient = (origin: string, ipAddress: string) => { + return new OpenAI({ + baseURL: origin + DEFAULT_API_PREFIX, + apiKey: TEST_OPENAI_API_KEY, + defaultHeaders: { + Origin: origin, + "X-Forwarded-For": ipAddress, + }, + }); +}; + +export type Stream = Awaited< + ReturnType +>; + +export const makeCreateResponseRequestStream = ( + openAiClient: OpenAI, + body?: Omit, "stream"> +) => { + return openAiClient.responses.create({ + ...basicResponsesRequestBody, + ...body, + stream: true, + }); +}; + /** Create a URL to represent a client-side route on the test origin. @param path - path to append to the origin base URL. diff --git a/packages/mongodb-rag-core/package.json b/packages/mongodb-rag-core/package.json index 720bfbcba..b6892a1ff 100644 --- a/packages/mongodb-rag-core/package.json +++ b/packages/mongodb-rag-core/package.json @@ -101,7 +101,7 @@ "ignore": "^5.3.2", "langchain": "^0.3.5", "mongodb": "^6.3.0", - "openai": "^4.95.0", + "openai": "^5.9.1", "rimraf": "^6.0.1", "simple-git": "^3.27.0", "toml": "^3.0.0", diff --git a/packages/mongodb-rag-core/src/DataStreamer.test.ts b/packages/mongodb-rag-core/src/DataStreamer.test.ts index b38b97a3d..a661cdbd2 100644 --- a/packages/mongodb-rag-core/src/DataStreamer.test.ts +++ b/packages/mongodb-rag-core/src/DataStreamer.test.ts @@ -1,16 +1,23 @@ -import { DataStreamer, makeDataStreamer } from "./DataStreamer"; +import { + DataStreamer, + makeDataStreamer, + type ResponsesStreamParams, +} from "./DataStreamer"; import { OpenAI } from "openai"; import { createResponse } from "node-mocks-http"; import { EventEmitter } from "events"; import { Response } from "express"; -let res: ReturnType & Response; -const dataStreamer = makeDataStreamer(); describe("Data Streaming", () => { + let dataStreamer: DataStreamer; + let res: ReturnType & Response; + + beforeAll(() => { + dataStreamer = makeDataStreamer(); + }); + beforeEach(() => { - res = createResponse({ - eventEmitter: EventEmitter, - }); + res = createResponse({ eventEmitter: EventEmitter }); dataStreamer.connect(res); }); @@ -79,6 +86,30 @@ describe("Data Streaming", () => { `data: {"type":"delta","data":"Once upon"}\n\ndata: {"type":"delta","data":" a time there was a"}\n\ndata: {"type":"delta","data":" very long string."}\n\n` ); }); + + it("Streams Responses API events as valid SSE events to the client", () => { + dataStreamer.streamResponses({ + type: "response.created", + id: "test1", + } as ResponsesStreamParams); + dataStreamer.streamResponses({ + type: "response.in_progress", + id: "test2", + } as ResponsesStreamParams); + dataStreamer.streamResponses({ + type: "response.output_text.delta", + id: "test3", + } as ResponsesStreamParams); + dataStreamer.streamResponses({ + type: "response.completed", + id: "test4", + } as ResponsesStreamParams); + + const data = res._getData(); + expect(data).toBe( + `event: response.created\ndata: {"type":"response.created","id":"test1","sequence_number":0}\n\nevent: response.in_progress\ndata: {"type":"response.in_progress","id":"test2","sequence_number":1}\n\nevent: response.output_text.delta\ndata: {"type":"response.output_text.delta","id":"test3","sequence_number":2}\n\nevent: response.completed\ndata: {"type":"response.completed","id":"test4","sequence_number":3}\n\n` + ); + }); }); function createChatCompletionWithDelta( diff --git a/packages/mongodb-rag-core/src/DataStreamer.ts b/packages/mongodb-rag-core/src/DataStreamer.ts index 423e6ec21..12d56b2bf 100644 --- a/packages/mongodb-rag-core/src/DataStreamer.ts +++ b/packages/mongodb-rag-core/src/DataStreamer.ts @@ -16,6 +16,7 @@ interface ServerSentEventDispatcher { disconnect(): void; sendData(data: Data): void; sendEvent(eventType: string, data: Data): void; + sendResponsesEvent(data: OpenAI.Responses.ResponseStreamEvent): void; } type ServerSentEventData = object | string; @@ -43,6 +44,10 @@ function makeServerSentEventDispatcher< res.write(`event: ${eventType}\n`); res.write(`data: ${JSON.stringify(data)}\n\n`); }, + sendResponsesEvent(data) { + res.write(`event: ${data.type}\n`); + res.write(`data: ${JSON.stringify(data)}\n\n`); + }, }; } @@ -53,6 +58,10 @@ interface StreamParams { type StreamEvent = { type: string; data: unknown }; +export type ResponsesStreamParams = + | Omit + | Omit; + /** Event when server streams additional message response to the client. */ @@ -122,6 +131,7 @@ export interface DataStreamer { disconnect(): void; streamData(data: SomeStreamEvent): void; stream(params: StreamParams): Promise; + streamResponses(data: ResponsesStreamParams): void; } /** @@ -130,6 +140,7 @@ export interface DataStreamer { export function makeDataStreamer(): DataStreamer { let connected = false; let sse: ServerSentEventDispatcher | undefined; + let responseSequenceNumber = 0; return { get connected() { @@ -161,7 +172,7 @@ export function makeDataStreamer(): DataStreamer { /** Streams single item of data in an event stream. */ - streamData(data: SomeStreamEvent) { + streamData(data) { if (!this.connected) { throw new Error( `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` @@ -173,7 +184,7 @@ export function makeDataStreamer(): DataStreamer { /** Streams all message events in an event stream. */ - async stream({ stream }: StreamParams) { + async stream({ stream }) { if (!this.connected) { throw new Error( `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` @@ -197,5 +208,19 @@ export function makeDataStreamer(): DataStreamer { } return streamedData; }, + + async streamResponses(data) { + if (!this.connected) { + throw new Error( + `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` + ); + } + sse?.sendResponsesEvent({ + ...data, + sequence_number: responseSequenceNumber, + } as OpenAI.Responses.ResponseStreamEvent); + + responseSequenceNumber++; + }, }; }