diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 5bb7fa705..3ee981c03 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -31,7 +31,8 @@ import { import { redactConnectionUri } from "./middleware/redactConnectionUri"; import path from "path"; import express from "express"; -import { logger } from "mongodb-rag-core"; +import { logger, makeMongoDbSearchResultsStore } from "mongodb-rag-core"; +import { createAzure } from "mongodb-rag-core/aiSdk"; import { wrapOpenAI, wrapTraced, @@ -40,7 +41,6 @@ import { import { AzureOpenAI } from "mongodb-rag-core/openai"; import { MongoClient } from "mongodb-rag-core/mongodb"; import { - ANALYZER_ENV_VARS, AZURE_OPENAI_ENV_VARS, PREPROCESSOR_ENV_VARS, TRACING_ENV_VARS, @@ -57,7 +57,7 @@ import { makeGenerateResponseWithSearchTool } from "./processors/generateRespons import { makeBraintrustLogger } from "mongodb-rag-core/braintrust"; import { makeMongoDbScrubbedMessageStore } from "./tracing/scrubbedMessages/MongoDbScrubbedMessageStore"; import { MessageAnalysis } from "./tracing/scrubbedMessages/analyzeMessage"; -import { createAzure } from "mongodb-rag-core/aiSdk"; +import { makeFindContentWithMongoDbMetadata } from "./processors/findContentWithMongoDbMetadata"; export const { MONGODB_CONNECTION_URI, @@ -120,6 +120,11 @@ export const embeddedContentStore = makeMongoDbEmbeddedContentStore({ }, }); +export const searchResultsStore = makeMongoDbSearchResultsStore({ + connectionUri: MONGODB_CONNECTION_URI, + databaseName: MONGODB_DATABASE_NAME, +}); + export const verifiedAnswerConfig = { embeddingModel: OPENAI_VERIFIED_ANSWER_EMBEDDING_DEPLOYMENT, findNearestNeighborsOptions: { @@ -307,6 +312,13 @@ export async function closeDbConnections() { logger.info(`Segment logging is ${segmentConfig ? "enabled" : "disabled"}`); export const config: AppConfig = { + contentRouterConfig: { + findContent: makeFindContentWithMongoDbMetadata({ + findContent, + classifierModel: languageModel, + }), + searchResultsStore, + }, conversationsRouterConfig: { middleware: [ blockGetRequests, diff --git a/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.test.ts b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.test.ts new file mode 100644 index 000000000..547254a6f --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.test.ts @@ -0,0 +1,81 @@ +// Mocks +jest.mock("mongodb-rag-core/mongoDbMetadata", () => { + const actual = jest.requireActual("mongodb-rag-core/mongoDbMetadata"); + return { + ...actual, + classifyMongoDbProgrammingLanguageAndProduct: jest.fn(), + }; +}); + +jest.mock("mongodb-rag-core", () => { + const actual = jest.requireActual("mongodb-rag-core"); + return { + ...actual, + updateFrontMatter: jest.fn(), + }; +}); + +import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; +import { + makeFindContentWithMongoDbMetadata, +} from "./findContentWithMongoDbMetadata"; +import { classifyMongoDbProgrammingLanguageAndProduct } from "mongodb-rag-core/mongoDbMetadata"; + + +const mockedClassify = + classifyMongoDbProgrammingLanguageAndProduct as jest.Mock; +const mockedUpdateFrontMatter = updateFrontMatter as jest.Mock; + +function makeMockFindContent(result: string[]): FindContentFunc { + return jest.fn().mockResolvedValue(result); +} + +afterEach(() => { + jest.resetAllMocks(); +}); + +describe("makeFindContentWithMongoDbMetadata", () => { + test("enhances query with front matter and classification", async () => { + const inputQuery = "How do I use MongoDB with TypeScript?"; + const expectedQuery = `--- +product: driver +programmingLanguage: typescript +--- +How do I use MongoDB with TypeScript?`; + const fakeResult = ["doc1", "doc2"]; + + mockedClassify.mockResolvedValue({ + product: "driver", + programmingLanguage: "typescript", + }); + mockedUpdateFrontMatter.mockReturnValue(expectedQuery); + + const findContentMock = makeMockFindContent(fakeResult); + + const wrappedFindContent = makeFindContentWithMongoDbMetadata({ + findContent: findContentMock, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + classifierModel: {} as any, + }); + + const result = await wrappedFindContent({ + query: inputQuery, + filters: { sourceName: ["docs"] }, + limit: 3, + }); + + expect(mockedClassify).toHaveBeenCalledWith(expect.anything(), inputQuery); + expect(mockedUpdateFrontMatter).toHaveBeenCalledWith(inputQuery, { + product: "driver", + programmingLanguage: "typescript", + }); + + expect(findContentMock).toHaveBeenCalledWith({ + query: expectedQuery, + filters: { sourceName: ["docs"] }, + limit: 3, + }); + + expect(result).toEqual(fakeResult); + }); +}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.ts b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.ts new file mode 100644 index 000000000..a08d9f21f --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/findContentWithMongoDbMetadata.ts @@ -0,0 +1,38 @@ +import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; +import { LanguageModel } from "mongodb-rag-core/aiSdk"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; +import { classifyMongoDbProgrammingLanguageAndProduct } from "mongodb-rag-core/mongoDbMetadata"; + +export const makeFindContentWithMongoDbMetadata = ({ + findContent, + classifierModel, +}: { + findContent: FindContentFunc; + classifierModel: LanguageModel; +}) => { + const wrappedFindContent: FindContentFunc = wrapTraced( + async ({ query, filters, limit }) => { + const { product, programmingLanguage } = + await classifyMongoDbProgrammingLanguageAndProduct( + classifierModel, + query + ); + + const preProcessedQuery = updateFrontMatter(query, { + ...(product ? { product } : {}), + ...(programmingLanguage ? { programmingLanguage } : {}), + }); + + const res = await findContent({ + query: preProcessedQuery, + filters, + limit, + }); + return res; + }, + { + name: "makeFindContentWithMongoDbMetadata", + } + ); + return wrappedFindContent; +}; diff --git a/packages/mongodb-chatbot-server/src/app.ts b/packages/mongodb-chatbot-server/src/app.ts index c2ae01a5e..af36afc80 100644 --- a/packages/mongodb-chatbot-server/src/app.ts +++ b/packages/mongodb-chatbot-server/src/app.ts @@ -17,6 +17,7 @@ import { ObjectId } from "mongodb-rag-core/mongodb"; import { getRequestId, logRequest, sendErrorResponse } from "./utils"; import { CorsOptions } from "cors"; import cloneDeep from "lodash.clonedeep"; +import { makeContentRouter, MakeContentRouterParams } from "./routes"; /** Configuration for the server Express.js app. @@ -27,6 +28,11 @@ export interface AppConfig { */ conversationsRouterConfig: ConversationsRouterParams; + /** + Configuration for the content router. + */ + contentRouterConfig?: MakeContentRouterParams; + /** Maximum time in milliseconds for a request to complete before timing out. Defaults to 60000 (1 minute). @@ -119,6 +125,7 @@ export const makeApp = async (config: AppConfig): Promise => { corsOptions, apiPrefix = DEFAULT_API_PREFIX, expressAppConfig, + contentRouterConfig, } = config; logger.info("Server has the following configuration:"); logger.info( @@ -141,6 +148,10 @@ export const makeApp = async (config: AppConfig): Promise => { makeConversationsRouter(conversationsRouterConfig) ); + if (contentRouterConfig) { + app.use(`${apiPrefix}/content`, makeContentRouter(contentRouterConfig)); + } + app.get("/health", (_req, res) => { const data = { uptime: process.uptime(), diff --git a/packages/mongodb-chatbot-server/src/routes/content/contentRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.test.ts new file mode 100644 index 000000000..306b6bdca --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.test.ts @@ -0,0 +1,54 @@ +import request from "supertest"; +import { makeTestApp } from "../../test/testHelpers"; +import type { MakeContentRouterParams } from "./contentRouter"; +import type { + FindContentFunc, + MongoDbSearchResultsStore, +} from "mongodb-rag-core"; + +// Minimal in-memory mock for SearchResultsStore for testing purposes +const mockSearchResultsStore: MongoDbSearchResultsStore = { + drop: jest.fn(), + close: jest.fn(), + metadata: { + databaseName: "mock", + collectionName: "mock", + }, + saveSearchResult: jest.fn(), + init: jest.fn(), +}; + +const findContentMock = jest.fn().mockResolvedValue({ + content: [], + queryEmbedding: [], +}) satisfies FindContentFunc; + +// Helper to build contentRouterConfig for the test app +function makeContentRouterConfig( + overrides: Partial = {} +) { + return { + findContent: findContentMock, + searchResultsStore: mockSearchResultsStore, + ...overrides, + } satisfies MakeContentRouterParams; +} + +describe("contentRouter", () => { + const searchEndpoint = "/api/v1/content/search"; + + it("should call custom middleware if provided", async () => { + const mockMiddleware = jest.fn((_req, _res, next) => next()); + const { app, origin } = await makeTestApp({ + contentRouterConfig: makeContentRouterConfig({ + middleware: [mockMiddleware], + }), + }); + await request(app) + .post(searchEndpoint) + .set("req-id", "test-req-id") + .set("Origin", origin) + .send({ query: "mongodb" }); + expect(mockMiddleware).toHaveBeenCalled(); + }); +}); diff --git a/packages/mongodb-chatbot-server/src/routes/content/contentRouter.ts b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.ts new file mode 100644 index 000000000..31fd333ac --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/contentRouter.ts @@ -0,0 +1,61 @@ +import { RequestHandler, Router } from "express"; +import { ParamsDictionary } from "express-serve-static-core"; +import { FindContentFunc, MongoDbSearchResultsStore } from "mongodb-rag-core"; + +import validateRequestSchema from "../../middleware/validateRequestSchema"; +import { SearchContentRequest, makeSearchContentRoute } from "./searchContent"; + +/** + Middleware to put in front of all the routes in the contentRouter. + Useful for authentication, data validation, logging, etc. + It exposes the app's {@link ContentRouterLocals} via {@link Response.locals} + ([docs](https://expressjs.com/en/api.html#res.locals)). + You can use or modify `res.locals.customData` in your middleware, and this data + will be available to subsequent middleware and route handlers. + */ +export type SearchContentMiddleware = RequestHandler< + ParamsDictionary, + unknown, + unknown, + unknown, + SearchContentRouterLocals +>; + +/** + Local variables provided by Express.js for single request-response cycle + + Keeps track of data for authentication or dynamic data validation. + */ +export interface SearchContentRouterLocals { + customData: Record; +} + +export interface MakeContentRouterParams { + findContent: FindContentFunc; + searchResultsStore: MongoDbSearchResultsStore; + // TODO: Add default middleware along with customData as in conversationsRouter + middleware?: SearchContentMiddleware[]; +} + +export function makeContentRouter({ + findContent, + searchResultsStore, + middleware = [], +}: MakeContentRouterParams) { + const contentRouter = Router(); + + // Add middleware to the conversationsRouter. + middleware?.forEach((middleware) => contentRouter.use(middleware)); + + // Create new conversation. + contentRouter.post( + "/search", + validateRequestSchema(SearchContentRequest), + makeSearchContentRoute({ + findContent, + searchResultsStore, + }) + ); + + return contentRouter; +} diff --git a/packages/mongodb-chatbot-server/src/routes/content/index.ts b/packages/mongodb-chatbot-server/src/routes/content/index.ts new file mode 100644 index 000000000..692a707f5 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/index.ts @@ -0,0 +1,2 @@ +export * from "./contentRouter"; +export * from "./searchContent"; diff --git a/packages/mongodb-chatbot-server/src/routes/content/searchContent.test.ts b/packages/mongodb-chatbot-server/src/routes/content/searchContent.test.ts new file mode 100644 index 000000000..b2ca79963 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/searchContent.test.ts @@ -0,0 +1,155 @@ +import { makeSearchContentRoute } from "./searchContent"; +import type { FindContentFunc, FindContentResult } from "mongodb-rag-core"; +import type { MongoDbSearchResultsStore } from "mongodb-rag-core"; +import { createRequest, createResponse } from "node-mocks-http"; + +// Helper to create a mock FindContentFunc +function makeMockFindContent(result: FindContentResult) { + return jest.fn().mockResolvedValue(result) satisfies FindContentFunc; +} + +// Helper to create a mock MongoDbSearchResultsStore +function makeMockMongoDbSearchResultsStore() { + return { + drop: jest.fn(), + close: jest.fn(), + metadata: { databaseName: "mock", collectionName: "mock" }, + saveSearchResult: jest.fn().mockResolvedValue(undefined), + init: jest.fn(), + } satisfies MongoDbSearchResultsStore; +} + +describe("makeSearchContentRoute", () => { + const baseReqBody = { + query: "What is aggregation?", + limit: 2, + dataSources: [{ name: "source1", type: "docs", versionLabel: "v1" }], + }; + // Add all required EmbeddedContent fields for the mock result + const baseFindContentResult: FindContentResult = { + queryEmbedding: [0.1, 0.2, 0.3], + content: [ + { + url: "https://www.mongodb.com/docs/manual/aggregation", + text: "Look at all this aggregation", + metadata: { pageTitle: "Aggregation Operations" }, + sourceName: "source1", + tokenCount: 8, + embeddings: { test: [0.1, 0.2, 0.3] }, + updated: new Date(), + score: 0.8, + }, + { + url: "https://mongodb.com/docs", + text: "MongoDB Docs", + metadata: { pageTitle: "MongoDB" }, + sourceName: "source1", + tokenCount: 10, + embeddings: { test: [0.1, 0.2, 0.3] }, + updated: new Date(), + score: 0.6, + }, + ], + }; + + it("should return search results for a valid request", async () => { + const findContent = makeMockFindContent(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + + const data = res._getJSONData(); + expect(data).toHaveProperty("results"); + expect(Array.isArray(data.results)).toBe(true); + expect(data.results.length).toBe(2); + expect(data.results[0].url).toBe( + "https://www.mongodb.com/docs/manual/aggregation" + ); + }); + + it("should call findContent with correct arguments", async () => { + const findContent = jest.fn().mockResolvedValue(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + + expect(findContent).toHaveBeenCalledWith({ + query: baseReqBody.query, + filters: expect.any(Object), + limit: baseReqBody.limit, + }); + }); + + it("should call searchResultsStore.saveSearchResult", async () => { + const findContent = makeMockFindContent(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + expect(searchResultsStore.saveSearchResult).toHaveBeenCalledWith( + expect.objectContaining({ + query: baseReqBody.query, + results: baseFindContentResult.content, + dataSources: baseReqBody.dataSources, + limit: baseReqBody.limit, + }) + ); + }); + + it("should handle errors from findContent and throw", async () => { + const findContent = jest.fn().mockRejectedValue(new Error("fail")); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: baseReqBody, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await expect(handler(req, res as any)).rejects.toMatchObject({ + message: "Unable to query search database", + httpStatus: 500, + name: "RequestError", + }); + }); + + it("should respect `limit` and `dataSources` parameters", async () => { + const findContent = jest.fn().mockResolvedValue(baseFindContentResult); + const searchResultsStore = makeMockMongoDbSearchResultsStore(); + const handler = makeSearchContentRoute({ findContent, searchResultsStore }); + const req = createRequest({ + body: { ...baseReqBody, limit: 1, dataSources: [{ name: "source2" }] }, + headers: { "req-id": "test-req-id" }, + }); + const res = createResponse(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await handler(req, res as any); + expect(findContent).toHaveBeenCalledWith( + expect.objectContaining({ + limit: 1, + filters: expect.objectContaining({ sourceName: ["source2"] }), + }) + ); + }); +}); diff --git a/packages/mongodb-chatbot-server/src/routes/content/searchContent.ts b/packages/mongodb-chatbot-server/src/routes/content/searchContent.ts new file mode 100644 index 000000000..7de4a2a89 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/content/searchContent.ts @@ -0,0 +1,139 @@ +import { + Request as ExpressRequest, + Response as ExpressResponse, +} from "express"; +import { + FindContentFunc, + FindContentResult, + MongoDbSearchResultsStore, + QueryFilters, + SearchRecordDataSource, + SearchRecordDataSourceSchema, +} from "mongodb-rag-core"; +import { z } from "zod"; + +import { SomeExpressRequest } from "../../middleware"; +import { makeRequestError } from "../conversations/utils"; +import { SearchContentRouterLocals } from "./contentRouter"; + +export const SearchContentRequestBody = z.object({ + query: z.string(), + dataSources: z.array(SearchRecordDataSourceSchema).optional(), + limit: z.number().int().min(1).max(500).optional().default(5), +}); + +export const SearchContentRequest = SomeExpressRequest.merge( + z.object({ + headers: z.object({ + "req-id": z.string(), + }), + body: SearchContentRequestBody, + }) +); + +export type SearchContentRequest = z.infer; +export type SearchContentRequestBody = z.infer; + +export interface MakeSearchContentRouteParams { + findContent: FindContentFunc; + searchResultsStore: MongoDbSearchResultsStore; +} + +interface SearchContentResponseChunk { + url: string; + title: string; + text: string; + metadata?: { + sourceName?: string; + sourceType?: string; + sourceVersionLabel?: string; + tags?: string[]; + [k: string]: unknown; + }; +} +interface SearchContentResponseBody { + results: SearchContentResponseChunk[]; +} + +export function makeSearchContentRoute({ + findContent, + searchResultsStore, +}: MakeSearchContentRouteParams) { + return async ( + req: ExpressRequest, + res: ExpressResponse + ) => { + try { + const { query, dataSources, limit } = req.body; + const results = await findContent({ + query, + filters: mapDataSourcesToFilters(dataSources), + limit, + }); + res.json(mapFindContentResultToSearchContentResponseChunk(results)); + await persistSearchResultsToDatabase({ + query, + results, + dataSources, + limit, + searchResultsStore, + }); + } catch (error) { + throw makeRequestError({ + httpStatus: 500, + message: "Unable to query search database", + }); + } + }; +} + +function mapFindContentResultToSearchContentResponseChunk( + result: FindContentResult +): SearchContentResponseBody { + return { + results: result.content.map(({ url, metadata, text }) => ({ + url, + title: metadata?.pageTitle ?? "", + text, + metadata, + })), + }; +} + +function mapDataSourcesToFilters( + dataSources?: SearchRecordDataSource[] +): QueryFilters { + if (!dataSources || dataSources.length === 0) { + return {}; + } + + const sourceNames = dataSources.map((ds) => ds.name); + const sourceTypes = dataSources + .map((ds) => ds.type) + .filter((t): t is string => !!t); + const versionLabels = dataSources + .map((ds) => ds.versionLabel) + .filter((v): v is string => !!v); + + return { + ...(sourceNames.length && { sourceName: sourceNames }), + ...(sourceTypes.length && { sourceType: sourceTypes }), + ...(versionLabels.length && { version: { label: versionLabels } }), + }; +} + +async function persistSearchResultsToDatabase(params: { + query: string; + results: FindContentResult; + dataSources: SearchRecordDataSource[]; + limit: number; + searchResultsStore: MongoDbSearchResultsStore; +}) { + params.searchResultsStore.saveSearchResult({ + query: params.query, + results: params.results.content, + dataSources: params.dataSources, + limit: params.limit, + createdAt: new Date(), + }); +} diff --git a/packages/mongodb-chatbot-server/src/routes/index.ts b/packages/mongodb-chatbot-server/src/routes/index.ts index b9f9da7be..cff392f1a 100644 --- a/packages/mongodb-chatbot-server/src/routes/index.ts +++ b/packages/mongodb-chatbot-server/src/routes/index.ts @@ -1 +1,2 @@ export * from "./conversations"; +export * from "./content"; diff --git a/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts b/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts index 4bbe948b1..18cc19c9d 100644 --- a/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts +++ b/packages/mongodb-rag-core/src/contentStore/EmbeddedContent.ts @@ -96,12 +96,9 @@ export interface GetSourcesMatchParams { Filters for querying the embedded content vector store. */ export type QueryFilters = { - sourceName?: string; - version?: { - current?: boolean; - label?: string; - }; - sourceType?: Page["sourceType"]; + sourceName?: string | string[]; + version?: { current?: boolean; label?: string | string[] }; + sourceType?: Page["sourceType"] | string[]; }; /** diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts index 147da6ebc..fdecea7ab 100644 --- a/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbEmbeddedContentStore.ts @@ -294,10 +294,10 @@ export function makeMongoDbEmbeddedContentStore({ } type MongoDbAtlasVectorSearchFilter = { - sourceName?: string; - "metadata.version.label"?: string; + sourceName?: string | { $in: string[] }; + "metadata.version.label"?: string | { $in: string[] }; "metadata.version.isCurrent"?: boolean | { $ne: boolean }; - sourceType?: string; + sourceType?: string | { $in: string[] }; }; const handleFilters = ( @@ -305,15 +305,21 @@ const handleFilters = ( ): MongoDbAtlasVectorSearchFilter => { const vectorSearchFilter: MongoDbAtlasVectorSearchFilter = {}; if (filter.sourceName) { - vectorSearchFilter["sourceName"] = filter.sourceName; + vectorSearchFilter["sourceName"] = Array.isArray(filter.sourceName) + ? { $in: filter.sourceName } + : filter.sourceName; } if (filter.sourceType) { - vectorSearchFilter["sourceType"] = filter.sourceType; + vectorSearchFilter["sourceType"] = Array.isArray(filter.sourceType) + ? { $in: filter.sourceType } + : filter.sourceType; } // Handle version filter. Note: unversioned embeddings (isCurrent: null) are treated as current const { current, label } = filter.version ?? {}; if (label) { - vectorSearchFilter["metadata.version.label"] = label; + vectorSearchFilter["metadata.version.label"] = Array.isArray(label) + ? { $in: label } + : label; } // Return current embeddings if either: // 1. current=true was explicitly requested, or diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.test.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.test.ts new file mode 100644 index 000000000..525fd44ea --- /dev/null +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.test.ts @@ -0,0 +1,91 @@ +import { strict as assert } from "assert"; +import "dotenv/config"; +import { MongoClient } from "mongodb"; +import { MONGO_MEMORY_SERVER_URI } from "../test/constants"; +import { + makeMongoDbSearchResultsStore, + MongoDbSearchResultsStore, + SearchResultRecord, +} from "./MongoDbSearchResultsStore"; + +const searchResultRecord: SearchResultRecord = { + query: "What is MongoDB Atlas?", + results: [ + { + url: "foo", + title: "bar", + text: "baz", + metadata: { + sourceName: "source", + }, + }, + ], + dataSources: [{ name: "source1", type: "docs" }], + createdAt: new Date(), +}; +const uri = MONGO_MEMORY_SERVER_URI; + +describe("MongoDbSearchResultsStore", () => { + let store: MongoDbSearchResultsStore | undefined; + + beforeAll(async () => { + store = makeMongoDbSearchResultsStore({ + connectionUri: uri, + databaseName: "test-search-content-database", + }); + }); + + afterEach(async () => { + await store?.drop(); + }); + afterAll(async () => { + await store?.close(); + }); + + it("has an overridable default collection name", async () => { + assert(store); + + expect(store.metadata.collectionName).toBe("search_results"); + + const storeWithCustomCollectionName = makeMongoDbSearchResultsStore({ + connectionUri: uri, + databaseName: store.metadata.databaseName, + collectionName: "custom-search_results", + }); + + expect(storeWithCustomCollectionName.metadata.collectionName).toBe( + "custom-search_results" + ); + }); + + it("creates indexes", async () => { + assert(store); + await store.init(); + + const mongoClient = new MongoClient(uri); + const coll = mongoClient + ?.db(store.metadata.databaseName) + .collection(store.metadata.collectionName); + const indexes = await coll?.listIndexes().toArray(); + + expect(indexes?.some((el) => el.name === "createdAt_-1")).toBe(true); + await mongoClient.close(); + }); + + it("saves search result records to db", async () => { + assert(store); + await store.saveSearchResult(searchResultRecord); + + // Check for record in db + const client = new MongoClient(uri); + await client.connect(); + const db = client.db(store.metadata.databaseName); + const collection = db.collection("search_results"); + const found = await collection.findOne(searchResultRecord); + + expect(found).toBeTruthy(); + expect(found).toMatchObject(searchResultRecord); + + await client.close(); + }); +}); diff --git a/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.ts b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.ts new file mode 100644 index 000000000..f9fa77c79 --- /dev/null +++ b/packages/mongodb-rag-core/src/contentStore/MongoDbSearchResultsStore.ts @@ -0,0 +1,110 @@ +import { z } from "zod"; +import { DatabaseConnection } from "../DatabaseConnection"; +import { + MakeMongoDbDatabaseConnectionParams, + makeMongoDbDatabaseConnection, +} from "../MongoDbDatabaseConnection"; +import { Document } from "mongodb"; + +export const SearchRecordDataSourceSchema = z.object({ + name: z.string(), + type: z.string().optional(), + versionLabel: z.string().optional(), +}); + +export type SearchRecordDataSource = z.infer< + typeof SearchRecordDataSourceSchema +>; + +export interface ResultChunk { + url: string; + title: string; + text: string; + metadata: { + sourceName: string; + sourceType?: string; + tags?: string[]; + [key: string]: unknown; // Accept additional unknown properties + }; +} + +export const ResultChunkSchema = z.object({ + url: z.string(), + title: z.string(), + text: z.string(), + metadata: z + .object({ + sourceName: z.string(), + sourceType: z.string().optional(), + tags: z.array(z.string()).optional(), + }) + .passthrough(), +}); + +export const SearchResultRecordSchema = z.object({ + query: z.string(), + results: z.array(ResultChunkSchema), + dataSources: z.array(SearchRecordDataSourceSchema).optional(), + limit: z.number().optional(), + createdAt: z.date(), +}); + +export interface SearchResultRecord { + query: string; + results: Document[]; + dataSources?: SearchRecordDataSource[]; + limit?: number; + createdAt: Date; +} + +export type MongoDbSearchResultsStore = DatabaseConnection & { + metadata: { + databaseName: string; + collectionName: string; + }; + saveSearchResult(record: SearchResultRecord): Promise; + init(): Promise; +}; + +export type MakeMongoDbSearchResultsStoreParams = + MakeMongoDbDatabaseConnectionParams & { + collectionName?: string; + }; + +export type ContentCustomData = Record | undefined; + +export function makeMongoDbSearchResultsStore({ + connectionUri, + databaseName, + collectionName = "search_results", +}: MakeMongoDbSearchResultsStoreParams): MongoDbSearchResultsStore { + const { db, drop, close } = makeMongoDbDatabaseConnection({ + connectionUri, + databaseName, + }); + const searchResultsCollection = + db.collection(collectionName); + return { + drop, + close, + metadata: { + databaseName, + collectionName, + }, + async saveSearchResult(record: SearchResultRecord) { + const insertResult = await searchResultsCollection.insertOne(record); + + if (!insertResult.acknowledged) { + throw new Error("Insert was not acknowledged by MongoDB"); + } + if (!insertResult.insertedId) { + throw new Error( + "No insertedId returned from MongoDbSearchResultsStore.saveSearchResult insertOne" + ); + } + }, + async init() { + await searchResultsCollection.createIndex({ createdAt: -1 }); + }, + }; +} diff --git a/packages/mongodb-rag-core/src/contentStore/index.ts b/packages/mongodb-rag-core/src/contentStore/index.ts index a0be7b876..3de61850a 100644 --- a/packages/mongodb-rag-core/src/contentStore/index.ts +++ b/packages/mongodb-rag-core/src/contentStore/index.ts @@ -2,6 +2,7 @@ export * from "./EmbeddedContent"; export * from "./getChangedPages"; export * from "./MongoDbEmbeddedContentStore"; export * from "./MongoDbPageStore"; +export * from "./MongoDbSearchResultsStore"; export * from "./MongoDbTransformedContentStore"; export * from "./Page"; export * from "./PageFormat"; diff --git a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts index 365b08560..8e66a4e15 100644 --- a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts +++ b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.test.ts @@ -95,4 +95,20 @@ describe("makeDefaultFindContent()", () => { expect(content.length).toBeGreaterThan(0); expect(embeddingModelName).toBe(OPENAI_RETRIEVAL_EMBEDDING_DEPLOYMENT); }); + test("should limit results", async () => { + const findContent = makeDefaultFindContent({ + embedder, + store: embeddedContentStore, + findNearestNeighborsOptions: { + minScore: 0.1, // low min, should return at least one result + }, + }); + const query = "MongoDB"; + const { content } = await findContent({ + query, + limit: 1, // limit to 1, should return 1 result + }); + expect(content).toBeDefined(); + expect(content.length).toBe(1); + }); }); diff --git a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts index 43661187c..572fe10da 100644 --- a/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts +++ b/packages/mongodb-rag-core/src/findContent/DefaultFindContent.ts @@ -20,7 +20,7 @@ export const makeDefaultFindContent = ({ findNearestNeighborsOptions, searchBoosters, }: MakeDefaultFindContentFuncArgs): FindContentFunc => { - return async ({ query, filters = {} }) => { + return async ({ query, filters = {}, limit }) => { const { embedding } = await embedder.embed({ text: query, }); @@ -28,6 +28,7 @@ export const makeDefaultFindContent = ({ let content = await store.findNearestNeighbors(embedding, { ...findNearestNeighborsOptions, filter: filters, + ...(limit ? { k: limit }: {}) }); for (const booster of searchBoosters ?? []) { diff --git a/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts b/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts index da01c3d54..ddb4d0f21 100644 --- a/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts +++ b/packages/mongodb-rag-core/src/findContent/FindContentFunc.ts @@ -4,6 +4,7 @@ import { WithScore } from "../VectorStore"; export type FindContentFuncArgs = { query: string; filters?: QueryFilters; + limit?: number; }; export type FindContentFunc = ( diff --git a/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts b/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts index 3ba058b3b..067595d09 100644 --- a/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts +++ b/packages/mongodb-rag-core/src/mongoDbMetadata/classifyMetadata.ts @@ -135,6 +135,20 @@ ${mongoDbTopics function nullOnErr() { return null; } + +export const classifyMongoDbProgrammingLanguageAndProduct = wrapTraced( + async (model: LanguageModel, data: string, maxRetries?: number) => { + const [programmingLanguage, product] = await Promise.all([ + classifyMongoDbProgrammingLanguage(model, data, maxRetries).catch( + nullOnErr + ), + classifyMongoDbProduct(model, data, maxRetries).catch(nullOnErr), + ]); + return { programmingLanguage, product }; + }, + { name: "classifyMongoDbProgrammingLanguageAndProduct" } +); + export const classifyMongoDbMetadata = wrapTraced( async (model: LanguageModel, data: string, maxRetries?: number) => { const [programmingLanguage, product, topic] = await Promise.all([