diff --git a/packages/benchmarks/src/nlPromptResponse/metrics.ts b/packages/benchmarks/src/nlPromptResponse/metrics.ts index 139cb8863..7eff31e6a 100644 --- a/packages/benchmarks/src/nlPromptResponse/metrics.ts +++ b/packages/benchmarks/src/nlPromptResponse/metrics.ts @@ -2,11 +2,14 @@ import { NlPromptResponseEvalScorer } from "./NlQuestionAnswerEval"; import { Factuality, Score, AnswerCorrectness } from "autoevals"; import { strict as assert } from "assert"; import { LlmOptions } from "mongodb-rag-core/executeCode"; +import { OpenAI } from "mongodb-rag-core/openai"; +import { openAi } from "mongodb-rag-core/langchain"; export const makeReferenceAlignment: ( + openAiClient: OpenAI, llmOptions: LlmOptions, name_postfix?: string -) => NlPromptResponseEvalScorer = (llmOptions, name_postfix) => +) => NlPromptResponseEvalScorer = (openAiClient, llmOptions, name_postfix) => async function ({ input, output, expected }) { const { response } = output; const { reference } = expected; @@ -29,7 +32,7 @@ export const makeReferenceAlignment: ( // Note: need to do the funky casting here // b/c of different `OpenAI` client typing // that is not relevant here. - client: llmOptions.openAiClient as unknown as Parameters< + client: openAiClient as unknown as Parameters< typeof Factuality >[0]["client"], model: llmOptions.model, @@ -62,9 +65,10 @@ function inflateFactualityScore(score: number | null | undefined) { } export const makeAnswerCorrectness: ( + openAiClient: OpenAI, llmOptions: LlmOptions, name_postfix?: string -) => NlPromptResponseEvalScorer = (llmOptions, name_postfix) => +) => NlPromptResponseEvalScorer = (openAiClient, llmOptions, name_postfix) => async function ({ input, output, expected }) { const { response } = output; const { reference } = expected; @@ -89,7 +93,7 @@ export const makeAnswerCorrectness: ( // b/c of different `OpenAI` client typing // that is not relevant here. 
- client: llmOptions.openAiClient as unknown as Parameters< + client: openAiClient as unknown as Parameters< typeof Factuality >[0]["client"], model: llmOptions.model, @@ -101,11 +105,12 @@ export const makeAnswerCorrectness: ( }; export const makeReferenceAlignmentCouncil: ( + openAiClient: OpenAI, llmOptions: LlmOptions[] -) => NlPromptResponseEvalScorer = (llmOptions) => { +) => NlPromptResponseEvalScorer = (openAiClient, llmOptions) => { assert(llmOptions.length > 0, "At least one LLM must be provided"); const factualityMetrics = llmOptions.map((llmOption) => - makeReferenceAlignment(llmOption) + makeReferenceAlignment(openAiClient, llmOption) ); return async function ({ input, output, expected }) { const name = "ReferenceAlignmentCouncil"; diff --git a/packages/benchmarks/src/nlPromptResponse/nlPromptCompletionTask.ts b/packages/benchmarks/src/nlPromptResponse/nlPromptCompletionTask.ts index 5886b9707..9395edc25 100644 --- a/packages/benchmarks/src/nlPromptResponse/nlPromptCompletionTask.ts +++ b/packages/benchmarks/src/nlPromptResponse/nlPromptCompletionTask.ts @@ -4,16 +4,18 @@ import { NlPromptResponseEvalTask } from "./NlQuestionAnswerEval"; import { OpenAI } from "mongodb-rag-core/openai"; interface MakeNlPromptCompletionTaskParams { + openAiClient: OpenAI; llmOptions: LlmOptions; initialMessages?: OpenAI.Chat.ChatCompletionMessageParam[]; } export function makeNlPromptCompletionTask({ + openAiClient, llmOptions, initialMessages, }: MakeNlPromptCompletionTaskParams): NlPromptResponseEvalTask { return async function (input) { - const { openAiClient, ...llmConfig } = llmOptions; + const { ...llmConfig } = llmOptions; const res = await openAiClient.chat.completions.create({ messages: [...(initialMessages ?? 
[]), ...input.messages], stream: false, diff --git a/packages/benchmarks/src/nlPromptResponse/runNlPromptResponseBenchmark.ts b/packages/benchmarks/src/nlPromptResponse/runNlPromptResponseBenchmark.ts index a5f8eddb1..f54fb96d9 100644 --- a/packages/benchmarks/src/nlPromptResponse/runNlPromptResponseBenchmark.ts +++ b/packages/benchmarks/src/nlPromptResponse/runNlPromptResponseBenchmark.ts @@ -60,21 +60,24 @@ export async function runNlPromptResponseBenchmark({ "At least one judge model must be configured in 'judgeModelsConfig'. Check your model labels in 'globalConfig.ts'." ); - const judgeClients = await Promise.all( + const judgeConfigs = await Promise.all( judgeModelsConfig.map(async (m) => { const endpointAndKey = await getOpenAiEndpointAndApiKey(m); console.log(`Judge model: ${m.label}`); return { openAiClient: new OpenAI(endpointAndKey), - model: m.deployment, - temperature: 0, - label: m.label, + config: { + model: m.deployment, + temperature: 0, + label: m.label, + }, }; }) ); - const judgeMetrics = judgeClients.map(({ label, ...config }) => - makeReferenceAlignment(config, label) - ); + const judgeMetrics = judgeConfigs.map(({ openAiClient, config }) => { + const { label, ...llmOptions } = config; + return makeReferenceAlignment(openAiClient, llmOptions, label); + }); await PromisePool.for(models) .withConcurrency(maxConcurrentExperiments) @@ -121,6 +124,7 @@ export async function runNlPromptResponseBenchmark({ ...staticLlmOptions, }, task: makeNlPromptCompletionTask({ + openAiClient, llmOptions, initialMessages: [systemMessage], }), diff --git a/packages/benchmarks/src/test/makeSampleLlmOptions.ts b/packages/benchmarks/src/test/makeSampleLlmOptions.ts index c607b9c95..5b494a5a7 100644 --- a/packages/benchmarks/src/test/makeSampleLlmOptions.ts +++ b/packages/benchmarks/src/test/makeSampleLlmOptions.ts @@ -2,14 +2,17 @@ import { assertEnvVars, BRAINTRUST_ENV_VARS } from "mongodb-rag-core"; import { LlmOptions } from "mongodb-rag-core/executeCode"; import { 
OpenAI } from "mongodb-rag-core/openai"; -export const makeSampleLlmOptions = () => { +export const makeSampleLlm = () => { const { BRAINTRUST_API_KEY, BRAINTRUST_ENDPOINT } = assertEnvVars(BRAINTRUST_ENV_VARS); + return new OpenAI({ + apiKey: BRAINTRUST_API_KEY, + baseURL: BRAINTRUST_ENDPOINT, + }); +}; + +export const makeSampleLlmOptions = () => { return { - openAiClient: new OpenAI({ - apiKey: BRAINTRUST_API_KEY, - baseURL: BRAINTRUST_ENDPOINT, - }), model: "gpt-4o", temperature: 0, } satisfies LlmOptions; diff --git a/packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/config.ts b/packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/config.ts new file mode 100644 index 000000000..3d3a29bd0 --- /dev/null +++ b/packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/config.ts @@ -0,0 +1,91 @@ +import { assertEnvVars, BRAINTRUST_ENV_VARS } from "mongodb-rag-core"; +import { TEXT_TO_DRIVER_ENV_VARS } from "../../../TextToDriverEnvVars"; +import { ModelConfig, models } from "mongodb-rag-core/models"; +import { strict as assert } from "assert"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; +import { + SchemaStrategy, + SystemPromptStrategy, +} from "../../../generateDriverCode/languagePrompts/PromptStrategies"; + +export const DATASET_NAME = "atlas-sample-dataset-claude"; + +export const PROJECT_NAME = "natural-language-to-mongosh"; + +export const EXPERIMENT_BASE_NAME = "mongosh-benchmark-claude"; + +export const { + BRAINTRUST_API_KEY, + BRAINTRUST_ENDPOINT, + MONGODB_TEXT_TO_DRIVER_CONNECTION_URI, +} = assertEnvVars({ + ...TEXT_TO_DRIVER_ENV_VARS, + ...BRAINTRUST_ENV_VARS, +}); + +export const schemaStrategies: SchemaStrategy[] = [ + "annotated", + "interpreted", + "none", +] as const; +export const systemPromptStrategies: SystemPromptStrategy[] = [ + "default", + "chainOfThought", + "lazy", +] as const; + +export const experimentTypes = [ + "toolCall", + "agentic", + 
"promptCompletion", +] as const; + +export const fewShot = [true, false]; + +export interface Experiment { + model: (typeof MODELS)[number]; + schemaStrategy: SchemaStrategy; + systemPromptStrategy: SystemPromptStrategy; + type: (typeof experimentTypes)[number]; + fewShot?: boolean; +} + +export const MAX_CONCURRENT_EXPERIMENTS = 2; + +export const MAX_CONCURRENT_MODELS = 2; + +export const MODELS: ModelConfig[] = ( + [ + // benchmark models + // "gpt-4.1-mini", + // "gemini-2.5-flash-preview-05-20", + // "gemini-2.5-pro-preview-05-06", + // "gpt-4.1-nano", + // "gpt-4.1", + // "o3", + // "o4-mini", + "anthropic/claude-opus-4", + "anthropic/claude-sonnet-4", + // "claude-opus-4", + // "claude-sonnet-4", + ] satisfies (typeof models)[number]["label"][] +).map((label) => { + const model = models.find((m) => m.label === label); + assert(model, `Model ${label} not found`); + return model; +}); + +export function makeLlmOptions(model: ModelConfig): LlmOptions { + // Different because o3-mini has slightly different options + if (model.label === "o3-mini") { + return { + model: model.deployment, + max_completion_tokens: 3000, + }; + } + return { + model: model.deployment, + temperature: 0, + max_tokens: 3000, + }; +} diff --git a/packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/promptCompletionAnnotatedSchema.ts b/packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/promptCompletionAnnotatedSchema.ts new file mode 100644 index 000000000..8367c3685 --- /dev/null +++ b/packages/benchmarks/src/textToDriver/bin/mongoshBenchmarks/claudeGenerated/promptCompletionAnnotatedSchema.ts @@ -0,0 +1,86 @@ +import { makeTextToDriverEval } from "../../../TextToDriverEval"; +import { loadTextToDriverBraintrustEvalCases } from "../../../loadBraintrustDatasets"; +import { + ReasonableOutput, + SuccessfulExecution, +} from "../../../evaluationMetrics"; +import { annotatedDbSchemas } from "../../../generateDriverCode/annotatedDbSchemas"; +import 
{ createOpenAI } from "@ai-sdk/openai"; +import { wrapAISDKModel } from "mongodb-rag-core/braintrust"; +import { + BRAINTRUST_API_KEY, + DATASET_NAME, + PROJECT_NAME, + MONGODB_TEXT_TO_DRIVER_CONNECTION_URI, + makeLlmOptions, + MAX_CONCURRENT_EXPERIMENTS, + MODELS, + EXPERIMENT_BASE_NAME, + Experiment, +} from "./config"; +import PromisePool from "@supercharge/promise-pool"; +import { makeGenerateMongoshCodePromptCompletionTask } from "../../../generateDriverCode/generateMongoshCodePromptCompletion"; +import { getOpenAiEndpointAndApiKey } from "mongodb-rag-core/models"; +import { makeExperimentName } from "../../../makeExperimentName"; + +async function main() { + await PromisePool.for(MODELS) + .withConcurrency(MAX_CONCURRENT_EXPERIMENTS) + .handleError((error, model) => { + console.error( + `Error running experiment for model ${model.label}:`, + JSON.stringify(error) + ); + }) + .process(async (model) => { + const llmOptions = makeLlmOptions(model); + const experiment: Experiment = { + model, + schemaStrategy: "annotated", + systemPromptStrategy: "default", + type: "promptCompletion", + }; + const experimentName = makeExperimentName({ + baseName: EXPERIMENT_BASE_NAME, + experimentType: experiment.type, + model: model.label, + systemPromptStrategy: experiment.systemPromptStrategy, + schemaStrategy: experiment.schemaStrategy, + }); + console.log(`Running experiment: ${experimentName}`); + + await makeTextToDriverEval({ + apiKey: BRAINTRUST_API_KEY, + projectName: PROJECT_NAME, + experimentName, + data: loadTextToDriverBraintrustEvalCases({ + apiKey: BRAINTRUST_API_KEY, + projectName: PROJECT_NAME, + datasetName: DATASET_NAME, + }), + maxConcurrency: model.maxConcurrency, + + task: makeGenerateMongoshCodePromptCompletionTask({ + uri: MONGODB_TEXT_TO_DRIVER_CONNECTION_URI, + databaseInfos: annotatedDbSchemas, + llmOptions, + systemPromptStrategy: experiment.systemPromptStrategy, + openai: wrapAISDKModel( + createOpenAI({ + ...(await 
getOpenAiEndpointAndApiKey(model)), + }).chat(llmOptions.model, { + structuredOutputs: true, + }) + ), + schemaStrategy: experiment.schemaStrategy, + }), + metadata: { + llmOptions, + ...experiment, + }, + scores: [SuccessfulExecution, ReasonableOutput], + }); + }); +} + +main(); diff --git a/packages/benchmarks/src/textToDriver/loadBraintrustDatasets.ts b/packages/benchmarks/src/textToDriver/loadBraintrustDatasets.ts index b0a335cba..0cee6f201 100644 --- a/packages/benchmarks/src/textToDriver/loadBraintrustDatasets.ts +++ b/packages/benchmarks/src/textToDriver/loadBraintrustDatasets.ts @@ -179,7 +179,12 @@ export async function loadTextToDriverBraintrustEvalCases({ apiKey, dataset: datasetName, }).fetchedData() - ).map((d) => TextToDriverEvalCaseSchema.parse(d)); + ).map((d) => { + if (d.tags === null) { + d.tags = undefined; + } + return TextToDriverEvalCaseSchema.parse(d); + }); return dataset; } diff --git a/packages/datasets/src/bin/cleanClaudeDatasets.ts b/packages/datasets/src/bin/cleanClaudeDatasets.ts new file mode 100644 index 000000000..0d18ea495 --- /dev/null +++ b/packages/datasets/src/bin/cleanClaudeDatasets.ts @@ -0,0 +1,118 @@ +import { + DatabaseNlQueryDatasetEntry, + convertDatabaseNlQueryDatasetEntryToBraintrust, + DatabaseNlQueryDatasetEntryBraintrust, + convertBraintrustDatabaseNlQueryDatasetEntryToFlat, +} from "../treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry"; +import * as fs from "fs"; +import * as path from "path"; +import * as readline from "readline"; +import { countAndLogUsage } from "../treeGeneration/databaseNlQueries/analyzeDataset"; + +async function readJsonlFile( + filePath: string +): Promise { + const entries: DatabaseNlQueryDatasetEntry[] = []; + const fileStream = fs.createReadStream(filePath); + const rl = readline.createInterface({ + input: fileStream, + crlfDelay: Infinity, + }); + + for await (const line of rl) { + if (line.trim()) { + try { + const entry = JSON.parse(line) as DatabaseNlQueryDatasetEntry; + 
entries.push(entry); + } catch (error) { + console.error(`Error parsing line in ${filePath}:`, error); + } + } + } + + return entries; +} + +async function main() { + // 1. Get dataset from dataOut/anthropic-big-jsonl + const dataOutDir = path.join(__dirname, "../../dataOut"); + const anthropicDir = path.join(dataOutDir, "anthropic-big-jsonl"); + + if (!fs.existsSync(anthropicDir)) { + console.error(`Directory not found: ${anthropicDir}`); + process.exit(1); + } + + // Get all .jsonl files + const jsonlFiles = fs + .readdirSync(anthropicDir) + .filter((file) => file.endsWith(".jsonl")) + .map((file) => path.join(anthropicDir, file)); + + console.log(`Found ${jsonlFiles.length} JSONL files to process`); + + // 2. For each file, load the array of DatabaseNlQueryDatasetEntry into memory + const allBraintrustEntries: DatabaseNlQueryDatasetEntryBraintrust[] = []; + + for (const filePath of jsonlFiles) { + const fileName = path.basename(filePath); + console.log(`Processing ${fileName}...`); + + const entries = await readJsonlFile(filePath); + console.log(` Loaded ${entries.length} entries`); + + // 3. Convert to Braintrust format and append filename to generationUuid + for (const entry of entries) { + const braintrustEntry = + convertDatabaseNlQueryDatasetEntryToBraintrust(entry); + + // Append the filename to the generationUuid field + if (braintrustEntry.metadata && braintrustEntry.metadata.generationUuid) { + braintrustEntry.metadata.generationUuid = `${braintrustEntry.metadata.generationUuid}_${fileName}`; + } + + allBraintrustEntries.push(braintrustEntry); + } + } + + console.log( + `\nTotal entries before deduplication: ${allBraintrustEntries.length}` + ); + + // 5. De-duplicate the array based on the NL query + const seenQueries = new Set(); + const deduplicatedEntries = allBraintrustEntries.filter((entry) => { + const nlQuery = entry.input?.nlQuery ?? 
""; + if (seenQueries.has(nlQuery)) { + return false; + } + seenQueries.add(nlQuery); + return true; + }); + + console.log( + `Total entries after deduplication: ${deduplicatedEntries.length}` + ); + console.log( + `Removed ${ + allBraintrustEntries.length - deduplicatedEntries.length + } duplicate entries` + ); + + // 6. Write the array of braintrust entries to a new file + const outputPath = path.join(dataOutDir, "anthropic-big-braintrust.json"); + fs.writeFileSync(outputPath, JSON.stringify(deduplicatedEntries)); + + console.log( + `\nSuccessfully wrote ${deduplicatedEntries.length} entries to ${outputPath}` + ); + countAndLogUsage( + deduplicatedEntries.map(convertBraintrustDatabaseNlQueryDatasetEntryToFlat) + ); +} + +// Run the main function +main().catch((error) => { + console.error("Error in main:", error); + process.exit(1); +}); diff --git a/packages/datasets/src/bin/fixRewrittenDatasetFormatting.ts b/packages/datasets/src/bin/fixRewrittenDatasetFormatting.ts new file mode 100644 index 000000000..280b1411e --- /dev/null +++ b/packages/datasets/src/bin/fixRewrittenDatasetFormatting.ts @@ -0,0 +1,26 @@ +import path from "path"; +import fs from "fs"; +import { DatabaseNlQueryDatasetEntryBraintrust } from "../treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry"; +import { RewriteClassification } from "../treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery"; +async function main() { + const dataOutDir = path.join(__dirname, "..", "..", "dataOut"); + const datasetInPath = path.join( + dataOutDir, + "atlas-sample-dataset-claude-rewritten.1752073270894.json" + ); + + console.log("Reading dataset from", datasetInPath); + + const dataset = JSON.parse(fs.readFileSync(datasetInPath, "utf-8")) as { + classification: RewriteClassification; + datasetEntry: DatabaseNlQueryDatasetEntryBraintrust; + }[]; + const fixedDataset = dataset.map((d) => d.datasetEntry); + const pathOut = path.join( + dataOutDir, + "atlas-sample-dataset-claude-rewritten.json" + 
); + console.log("Writing fixed dataset to", pathOut); + fs.writeFileSync(pathOut, JSON.stringify(fixedDataset, null, 2)); +} +main(); diff --git a/packages/datasets/src/bin/generateDbSchemas.ts b/packages/datasets/src/bin/generateDbSchemas.ts index 8af38d590..07114bb05 100644 --- a/packages/datasets/src/bin/generateDbSchemas.ts +++ b/packages/datasets/src/bin/generateDbSchemas.ts @@ -3,13 +3,12 @@ import yaml from "yaml"; import path from "path"; import fs from "fs"; import { MongoClient } from "mongodb-rag-core/mongodb"; -import { OpenAI } from "mongodb-rag-core/openai"; import { assertEnvVars } from "mongodb-rag-core"; -import { BRAINTRUST_ENV_VARS } from "mongodb-rag-core"; import { DATABASE_NL_QUERIES } from "../EnvVars"; import { generateAnnotatedDatabaseInfoNode } from "../treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo"; import { datasetDatabases } from "../treeGeneration/databaseNlQueries/datasetDatabases"; import { prettyPrintMongoDbDocument } from "mongodb-rag-core/executeCode"; +import { openAiClient } from "../openAi"; const dataOutDir = path.resolve(__dirname, "..", "..", "dataOut"); @@ -19,13 +18,8 @@ if (!fs.existsSync(dataOutDir)) { console.log(`Created directory: ${dataOutDir}`); } -const { - MONGODB_TEXT_TO_CODE_CONNECTION_URI, - BRAINTRUST_API_KEY, - BRAINTRUST_ENDPOINT, -} = assertEnvVars({ +const { MONGODB_TEXT_TO_CODE_CONNECTION_URI } = assertEnvVars({ ...DATABASE_NL_QUERIES, - ...BRAINTRUST_ENV_VARS, }); async function generateDbSchema(mongoClient: MongoClient, dbName: string) { @@ -40,15 +34,12 @@ async function generateDbSchema(mongoClient: MongoClient, dbName: string) { databaseName: dbName, numSamplesPerCollection: 2, }, - llm: { + llmOptions: { model, - openAiClient: new OpenAI({ - baseURL: BRAINTRUST_ENDPOINT, - apiKey: BRAINTRUST_API_KEY, - }), temperature: 0, max_completion_tokens: 2000, }, + openAiClient, }); fs.writeFileSync( diff --git a/packages/datasets/src/bin/generateMongoshDatasets.ts 
b/packages/datasets/src/bin/generateMongoshDatasets.ts index 89aeaa248..8ae7610f8 100644 --- a/packages/datasets/src/bin/generateMongoshDatasets.ts +++ b/packages/datasets/src/bin/generateMongoshDatasets.ts @@ -3,12 +3,13 @@ import { MongoClient } from "mongodb-rag-core/mongodb"; import { executeMongoshQuery, isReasonableResult, + LlmOptions, } from "mongodb-rag-core/executeCode"; import * as fs from "fs"; import * as path from "path"; import PromisePool from "@supercharge/promise-pool"; -import { OpenAI } from "mongodb-rag-core/openai"; -import { BRAINTRUST_ENV_VARS, assertEnvVars } from "mongodb-rag-core"; +import { openAiClient } from "../openAi"; +import { assertEnvVars } from "mongodb-rag-core"; import { DATABASE_NL_QUERIES } from "../EnvVars"; import { generateAnnotatedDatabaseInfoNode } from "../treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo"; import { generateDatabaseExecutionResult } from "../treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseExecutionResult"; @@ -17,7 +18,6 @@ import { generateMongoshCode } from "../treeGeneration/databaseNlQueries/databas import { generateNaturalLanguageQueries } from "../treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries"; import { generateDatabaseUseCases } from "../treeGeneration/databaseNlQueries/databaseNodes/generateUseCases"; import { makeMongoDbNodeStore } from "../treeGeneration/MongoDbNodeStore"; -import { LlmOptions } from "../treeGeneration/databaseNlQueries/databaseNodes/LlmOptions"; import { datasetDatabases } from "../treeGeneration/databaseNlQueries/datasetDatabases"; import { findMostFrequentAndPerformantDatabaseExecutionResult } from "../treeGeneration/databaseNlQueries/findMostFrequentAndPerformantDatabaseExecutionResult"; import { generateDatabaseNlQueryDatasetEntry } from "../treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry"; @@ -111,8 +111,9 @@ async function generateMongoshDataset({ console.log(`Generating database 
info for database ${dataset.databaseName}`); const databaseInfoNode = await generateAnnotatedDatabaseInfoNode({ mongoDb: dataset, - llm: llmConfigs.database.llmConfig, + llmOptions: llmConfigs.database.llmConfig, latestDate: dataset.latestDate, + openAiClient, }); await nodeStore.storeNodes({ nodes: [databaseInfoNode] }); @@ -139,7 +140,7 @@ async function generateMongoshDataset({ ); await nodeStore.storeNodes({ nodes: useCases }); console.log( - `Generated ${useCases.length} use cases for ${userNode.data.name}, ${userNode.data.jobTitle}` + `Generated ${useCases.length} use cases for ${userNode.data.name}, ${userNode.data.role}` ); return useCases; }); @@ -271,12 +272,7 @@ async function generateMongoshDataset({ async function main() { // Set up - const { - BRAINTRUST_API_KEY, - BRAINTRUST_ENDPOINT, - MONGODB_TEXT_TO_CODE_CONNECTION_URI, - } = assertEnvVars({ - ...BRAINTRUST_ENV_VARS, + const { MONGODB_TEXT_TO_CODE_CONNECTION_URI } = assertEnvVars({ ...DATABASE_NL_QUERIES, }); const mongoClient = new MongoClient(MONGODB_TEXT_TO_CODE_CONNECTION_URI); @@ -290,10 +286,6 @@ async function main() { } const defaultLlmConfig: LlmOptions = { - openAiClient: new OpenAI({ - apiKey: BRAINTRUST_API_KEY, - baseURL: BRAINTRUST_ENDPOINT, - }), model: "gpt-4o", temperature: 0.7, seed: 42, diff --git a/packages/datasets/src/bin/generateMongoshDatasetsClaude.ts b/packages/datasets/src/bin/generateMongoshDatasetsClaude.ts new file mode 100644 index 000000000..b58c251ac --- /dev/null +++ b/packages/datasets/src/bin/generateMongoshDatasetsClaude.ts @@ -0,0 +1,390 @@ +import "dotenv/config"; +import { MongoClient } from "mongodb-rag-core/mongodb"; +import { + executeMongoshQuery, + isReasonableResult, +} from "mongodb-rag-core/executeCode"; +import * as fs from "fs"; +import * as path from "path"; +import PromisePool from "@supercharge/promise-pool"; +import { BRAINTRUST_ENV_VARS, assertEnvVars } from "mongodb-rag-core"; +import { DATABASE_NL_QUERIES } from "../EnvVars"; +import { 
generateAnnotatedDatabaseInfoNode } from "../treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo"; +import { generateDatabaseExecutionResult } from "../treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseExecutionResult"; +import { generateDatabaseUsers } from "../treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseUsers"; +import { generateMongoshCode } from "../treeGeneration/databaseNlQueries/databaseNodes/generateMongoshCode"; +import { generateNaturalLanguageQueries } from "../treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries"; +import { generateDatabaseUseCases } from "../treeGeneration/databaseNlQueries/databaseNodes/generateUseCases"; +import { makeMongoDbNodeStore } from "../treeGeneration/MongoDbNodeStore"; +import { datasetDatabases } from "../treeGeneration/databaseNlQueries/datasetDatabases"; +import { findMostFrequentAndPerformantDatabaseExecutionResult } from "../treeGeneration/databaseNlQueries/findMostFrequentAndPerformantDatabaseExecutionResult"; +import { generateDatabaseNlQueryDatasetEntry } from "../treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry"; +import { initLogger } from "mongodb-rag-core/braintrust"; +import { openAiClient } from "../openAi"; +import { GenerateChildrenLlmOptions } from "../treeGeneration/generateChildren"; + +const DEFAULT_CONCURRENCY = 16; + +/** + Magic number to specify the max results array size to evaluate. 
+ */ +const MAX_RESULT_ARRAY_SIZE = 20; + +type LlmGenerationConfig = { + database: { + llmConfig: GenerateChildrenLlmOptions; + }; + users: { + llmConfig: GenerateChildrenLlmOptions; + numGenerations: number; + concurrency: number; + }; + useCases: { + llmConfig: GenerateChildrenLlmOptions; + numGenerations: number; + concurrency: number; + }; + nlQueries: { + llmConfig: GenerateChildrenLlmOptions; + numGenerations: number; + concurrency: number; + }; + dbQueries: { + llmConfig: GenerateChildrenLlmOptions; + numGenerations: number; + concurrency: number; + }; + dbExecutions: { + concurrency: number; + }; +}; +interface GenerateMongoshDatasetParams { + persistence: { + mongoClient: MongoClient; + databaseName: string; + collectionName: string; + }; + dataset: { + databaseName: string; + numSamplesPerCollection: number; + connectionUri: string; + latestDate: Date; + mongoClient: MongoClient; + }; + llmConfigs: LlmGenerationConfig; + datasetUuid: string; + writeToFile: { + dataOutDir: string; + }; + maxResultsArraySize?: number; + minClusterSize?: number; +} + +async function generateMongoshDataset({ + persistence, + dataset, + llmConfigs, + datasetUuid, + writeToFile, + maxResultsArraySize = MAX_RESULT_ARRAY_SIZE, + minClusterSize, +}: GenerateMongoshDatasetParams) { + console.log(`Generating dataset for database ${dataset.databaseName}`); + const datasetOutDir = path.resolve(writeToFile.dataOutDir, datasetUuid); + if (!fs.existsSync(datasetOutDir)) { + fs.mkdirSync(datasetOutDir, { recursive: true }); + console.log(`Created directory: ${datasetOutDir}`); + } + const referenceAnswersOutputPath = path.resolve( + datasetOutDir, + `referenceAnswers.dataset_${datasetUuid}.jsonl` + ); + // Write out each DB's dataset to a separate file + const textToMqlOutputPath = path.resolve( + datasetOutDir, + `text_to_mongosh.dataset_${datasetUuid}.${dataset.databaseName}.jsonl` + ); + + console.log( + `Writing data out to DB 
${persistence.databaseName}.${persistence.collectionName}` + ); + + const nodeStore = makeMongoDbNodeStore(persistence); + + console.log(`Generating database info for database ${dataset.databaseName}`); + const databaseInfoNode = await generateAnnotatedDatabaseInfoNode({ + mongoDb: dataset, + llmOptions: llmConfigs.database.llmConfig, + openAiClient, + latestDate: dataset.latestDate, + }); + await nodeStore.storeNodes({ nodes: [databaseInfoNode] }); + + // Generate database users + console.log("Generating database users..."); + const userNodes = await generateDatabaseUsers( + databaseInfoNode, + llmConfigs.users.llmConfig, + llmConfigs.users.numGenerations + ); + await nodeStore.storeNodes({ nodes: userNodes }); + + console.log(`Generated ${userNodes.length} database users:`); + console.log( + JSON.stringify( + userNodes.map(({ data }) => data), + null, + 2 + ) + ); + + // Generate use cases for each user + console.log("Generating use cases for each user..."); + const { results: useCaseNodesByUser } = await PromisePool.for(userNodes) + .withConcurrency(llmConfigs.useCases.concurrency ?? 5) + .process(async (userNode) => { + const useCases = await generateDatabaseUseCases( + userNode, + llmConfigs.useCases.llmConfig, + llmConfigs.useCases.numGenerations + ); + await nodeStore.storeNodes({ nodes: useCases }); + console.log( + `Generated ${useCases.length} use cases for ${userNode.data.name}, ${userNode.data.role}` + ); + console.log( + useCases.map( + ({ data }, i) => + `${i + 1}: ${data.title}: ${data.description.slice(0, 20)}...` + ) + ); + return useCases; + }); + + const useCaseNodes = useCaseNodesByUser.flat(); + console.log(`Created ${useCaseNodes.length} use cases.`); + + console.log("Generating natural language queries for each use case..."); + + // Process use cases in parallel with limited concurrency + const { results: nlQueryNodesByUseCase } = await PromisePool.withConcurrency( + llmConfigs.nlQueries.concurrency ?? 
DEFAULT_CONCURRENCY + ) + .for(useCaseNodes) + .handleError((err) => { + console.error(err); + }) + .process(async (useCaseNode) => { + const nlQueries = await generateNaturalLanguageQueries( + useCaseNode, + llmConfigs.nlQueries.llmConfig, + llmConfigs.nlQueries.numGenerations + ); + await nodeStore.storeNodes({ + nodes: nlQueries, + }); + console.log( + `Generated ${nlQueries.length} NL queries for use case: ${useCaseNode.data.title}` + ); + + return nlQueries; + }); + const nlQueryNodes = nlQueryNodesByUseCase.flat(); + + // Generate triplets for the NL queries + console.log("Generating query nodes for the NL queries..."); + const { results: dbQCodeNodesByNlQuery } = await PromisePool.for(nlQueryNodes) + .withConcurrency(llmConfigs.dbQueries.concurrency ?? DEFAULT_CONCURRENCY) + .process(async (nlQueryNode) => { + const dbCodeNodes = await generateMongoshCode( + nlQueryNode, + llmConfigs.dbQueries.llmConfig, + llmConfigs.dbQueries.numGenerations + ); + await nodeStore.storeNodes({ nodes: dbCodeNodes }); + + console.log( + `Generated ${dbCodeNodes.length} DB queries for NL query: ${nlQueryNode.data.query}` + ); + return dbCodeNodes; + }); + for (const dbCodeNodes of dbQCodeNodesByNlQuery) { + const { results: dbExecutions } = await PromisePool.for(dbCodeNodes) + .withConcurrency( + llmConfigs.dbExecutions.concurrency ?? 
DEFAULT_CONCURRENCY + ) + .process(async (dbCodeNode) => { + const dbExecution = await generateDatabaseExecutionResult({ + database: { + name: dataset.databaseName, + uri: dataset.connectionUri, + }, + generatedQuery: dbCodeNode, + executor: executeMongoshQuery, + }); + if ( + Array.isArray(dbExecution.data.result) && + dbExecution.data.result?.length > maxResultsArraySize + ) { + throw new Error("Result array is too large to process."); + } + console.log( + `Generated DB execution: ${dbExecution.data.result + ?.toString() + .slice(0, 20)} ...` + ); + const { success, reason } = isReasonableResult(dbExecution.data.result); + if (!success) { + throw new Error("Result is not reasonable. Reason: " + reason); + } + return dbExecution; + }); + console.log(`Generated ${dbExecutions.length} DB executions.`); + + try { + // Find the most frequent and performant database execution result + const { fastestMostFrequentIndex } = + findMostFrequentAndPerformantDatabaseExecutionResult( + dbExecutions.map((node) => node.data), + minClusterSize + ); + if ( + fastestMostFrequentIndex !== null && + dbExecutions[fastestMostFrequentIndex].data.result !== null + ) { + const dbResult = dbExecutions[fastestMostFrequentIndex].data.result; + if ( + (Array.isArray(dbResult) && dbResult.length > 0) || + !Array.isArray(dbResult) + ) { + dbExecutions[fastestMostFrequentIndex].data.isReferenceAnswer = true; + } + } + + await nodeStore.storeNodes({ nodes: dbExecutions }); + console.log(`Writing data out to ${textToMqlOutputPath}`); + for (const dbExecution of dbExecutions) { + const textToMqlDatasetEntry = + generateDatabaseNlQueryDatasetEntry(dbExecution); + fs.appendFileSync( + textToMqlOutputPath, + JSON.stringify(textToMqlDatasetEntry) + "\n" + ); + if (dbExecution.data.isReferenceAnswer) { + fs.appendFileSync( + referenceAnswersOutputPath, + JSON.stringify(textToMqlDatasetEntry) + "\n" + ); + } + } + } catch (error) { + console.error(error); + } + + console.log( + `Successfully wrote 
${dbCodeNodes.length} nodes to ${textToMqlOutputPath}` + ); + } +} + +async function main() { + // Set up + const { BRAINTRUST_API_KEY, MONGODB_TEXT_TO_CODE_CONNECTION_URI } = + assertEnvVars({ + ...BRAINTRUST_ENV_VARS, + ...DATABASE_NL_QUERIES, + }); + + initLogger({ + projectName: "generate-mongosh-dataset-claude", + apiKey: BRAINTRUST_API_KEY, + }); + const mongoClient = new MongoClient(MONGODB_TEXT_TO_CODE_CONNECTION_URI); + + const dataOutDir = path.resolve(__dirname, "..", "..", "dataOut"); + + // Validate that dataOutDir exists. Create if it doesn't + if (!fs.existsSync(dataOutDir)) { + fs.mkdirSync(dataOutDir, { recursive: true }); + console.log(`Created directory: ${dataOutDir}`); + } + + const defaultLlmConfig: GenerateChildrenLlmOptions = { + model: "claude-4-opus-20250514", + temperature: 0.7, + seed: 42, + }; + + // One point to control generations at each level. + // Useful for debugging. + const config = { + database: { + llmConfig: { + ...defaultLlmConfig, + temperature: 0, // Make this one deterministic + }, + }, + users: { + numGenerations: 8, + llmConfig: defaultLlmConfig, + concurrency: DEFAULT_CONCURRENCY, + }, + useCases: { + numGenerations: 8, + llmConfig: defaultLlmConfig, + concurrency: DEFAULT_CONCURRENCY, + }, + nlQueries: { + numGenerations: 8, + llmConfig: defaultLlmConfig, + concurrency: DEFAULT_CONCURRENCY, + }, + dbQueries: { + numGenerations: 8, + llmConfig: { + ...defaultLlmConfig, + __claudeMaxConcurrency: 1, + __claudeTemperatureVariation: 0.01, + }, + concurrency: DEFAULT_CONCURRENCY, + }, + dbExecutions: { + concurrency: 20, + }, + } as const satisfies LlmGenerationConfig; + + // Runnit + try { + const now = Date.now(); + await mongoClient.connect(); + // TODO: 2nd dataset, weather data throwing err. 
need to investigate + for (const db of datasetDatabases) { + await generateMongoshDataset({ + persistence: { + mongoClient, + databaseName: "mongosh_datasets", + collectionName: db.name, + }, + dataset: { + databaseName: db.name, + numSamplesPerCollection: 2, + mongoClient, + latestDate: db.latestDate, + connectionUri: MONGODB_TEXT_TO_CODE_CONNECTION_URI, + }, + llmConfigs: config, + datasetUuid: `${defaultLlmConfig.model}_temp_${ + defaultLlmConfig.temperature + }_${now.toString()}`, + writeToFile: { + dataOutDir, + }, + // If we have 10 generations, we want to find the most frequent and performant result that has at least 4 generations + minClusterSize: Math.ceil(config.dbQueries.numGenerations / 3), + }); + } + } finally { + await mongoClient.close(); + } +} +main(); diff --git a/packages/datasets/src/bin/rewriteNlToMongoshDatasetQueries.ts b/packages/datasets/src/bin/rewriteNlToMongoshDatasetQueries.ts new file mode 100644 index 000000000..107a7f8aa --- /dev/null +++ b/packages/datasets/src/bin/rewriteNlToMongoshDatasetQueries.ts @@ -0,0 +1,108 @@ +import { makeRewriteNlQueryPrompt } from "../treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery"; +import { openAiProvider } from "../openAi"; +import { wrapAISDKModel } from "mongodb-rag-core/braintrust"; +import { models } from "mongodb-rag-core/models"; +import { DatabaseNlQueryDatasetEntryBraintrustSchema } from "../treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry"; +import yaml from "yaml"; +import PromisePool from "@supercharge/promise-pool"; +import fs from "fs"; +import path from "path"; + +async function runRewriteNlQueryToMongoshDataset(config: { + modelDeployment: (typeof models)[number]["deployment"]; + maxConcurrency: number; + generationId: string; + shuffle?: boolean; + limit?: number; +}) { + // Note: maybe switch to Opus for the real run + const modelDeployment = config.modelDeployment; + + const model = wrapAISDKModel(openAiProvider(modelDeployment)); + + const 
rewriteNlQueryPrompt = makeRewriteNlQueryPrompt(model); + + const dataOutDir = path.join(__dirname, "..", "..", "dataOut"); + // rewriting the latest rewrite + const datasetInPath = path.join( + dataOutDir, + "atlas-sample-dataset-claude-rewritten.v1.json" + ); + + const intermediateDatasetOutPath = path.join( + dataOutDir, + `atlas-sample-dataset-claude-rewritten.${config.generationId}.yaml` + ); + const datasetOutPath = path.join( + dataOutDir, + `atlas-sample-dataset-claude-rewritten.${config.generationId}.json` + ); + + let datasetEntries = config.shuffle + ? shuffle( + DatabaseNlQueryDatasetEntryBraintrustSchema.array().parse( + JSON.parse(fs.readFileSync(datasetInPath, "utf-8")) + ) + ) + : DatabaseNlQueryDatasetEntryBraintrustSchema.array().parse( + JSON.parse(fs.readFileSync(datasetInPath, "utf-8")) + ); + if (config.limit) { + datasetEntries = datasetEntries.slice(0, config.limit); + } + + console.log("Processing", datasetEntries.length, "dataset entries"); + let start = 1; + + console.log("Writing intermediate results to", intermediateDatasetOutPath); + const { results } = await PromisePool.for(datasetEntries) + .withConcurrency(config.maxConcurrency) + .handleError((error) => { + console.error(error); + }) + .process(async (entry) => { + console.log(`Processing entry ${start++}/${datasetEntries.length}. 
+Entry NL query: ${entry.input.nlQuery}`); + const result = await rewriteNlQueryPrompt(entry); + fs.appendFileSync( + intermediateDatasetOutPath, + yaml.stringify([result.datasetEntry]) + ); + return result; + }); + + console.log(results.length, "total result(s)"); + console.log( + results.filter( + (result) => result.classification.classification === "ambiguous" + ).length, + "ambiguous result(s)" + ); + + console.log("Writing full dataset to", datasetOutPath); + fs.writeFileSync( + datasetOutPath, + JSON.stringify( + results.map((r) => r.datasetEntry), + null, + 2 + ) + ); +} + +function shuffle<T>(items: T[]) { + return items.sort(() => Math.random() - 0.5); +} + +async function main() { + const config = { + modelDeployment: "claude-opus-4-20250514", + maxConcurrency: 15, + generationId: Date.now().toString(), + // shuffle: true, + // limit: 200, + } satisfies Parameters<typeof runRewriteNlQueryToMongoshDataset>[0]; + await runRewriteNlQueryToMongoshDataset(config); +} + +main(); diff --git a/packages/datasets/src/openAi.ts b/packages/datasets/src/openAi.ts index c99c650a4..8b5e02359 100644 --- a/packages/datasets/src/openAi.ts +++ b/packages/datasets/src/openAi.ts @@ -1,18 +1,21 @@ -import { assertEnvVars } from "mongodb-rag-core"; -import { OPENAI_ENV_VARS } from "./EnvVars"; -import { AzureOpenAI } from "mongodb-rag-core/openai"; +import { assertEnvVars, BRAINTRUST_ENV_VARS } from "mongodb-rag-core"; +import { OpenAI } from "mongodb-rag-core/openai"; +import { wrapOpenAI, wrapAISDKModel } from "mongodb-rag-core/braintrust"; +import { createOpenAI } from "mongodb-rag-core/aiSdk"; -const { - OPENAI_API_KEY, - OPENAI_API_VERSION, - OPENAI_CHAT_COMPLETION_DEPLOYMENT, - OPENAI_ENDPOINT, -} = assertEnvVars(OPENAI_ENV_VARS); +const { BRAINTRUST_API_KEY, BRAINTRUST_ENDPOINT } = + assertEnvVars(BRAINTRUST_ENV_VARS); -export const openAiClient = new AzureOpenAI({ - apiKey: OPENAI_API_KEY, - apiVersion: OPENAI_API_VERSION, - endpoint: OPENAI_ENDPOINT, +export const openAiClient = wrapOpenAI( + new OpenAI({ +
apiKey: BRAINTRUST_API_KEY, + baseURL: BRAINTRUST_ENDPOINT, + }) +); + +export const openAiProvider = createOpenAI({ + apiKey: BRAINTRUST_API_KEY, + baseURL: BRAINTRUST_ENDPOINT, }); -export const model = OPENAI_CHAT_COMPLETION_DEPLOYMENT; +export const model = "gpt-4o"; diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry.ts index 1c8cf5b70..20049941f 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/DatabaseNlQueryDatasetEntry.ts @@ -78,7 +78,7 @@ export const DatabaseNlQueryDatasetEntryBraintrustSchema = z.object({ executionTimeMs: z.number().nullable(), }), - tags: z.array(z.string()), + tags: z.array(z.string()).nullable(), metadata: z.object({ methods: z.array(z.string()).optional(), queryOperators: z.array(z.string()).optional(), diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateAnnotatedCollectionSchema.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateAnnotatedCollectionSchema.ts index c529f1bcc..aa94c7c1b 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateAnnotatedCollectionSchema.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateAnnotatedCollectionSchema.ts @@ -1,9 +1,10 @@ import { OpenAI } from "mongodb-rag-core/openai"; import { CollectionInfo, DatabaseMetadata } from "./getDatabaseMetadata"; -import { LlmOptions } from "../databaseNodes/LlmOptions"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; import { z } from "zod"; import { getOpenAiFunctionResponse } from "./getOpenAiFunctionResponse"; import { prettyPrintMongoDbDocument } from "./prettyPrintMongoDbDocument"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; const systemPrompt = 
`You are an expert MongoDB database architect. Your task is to analyze the provided database metadata and generate clear, concise descriptions and an annotated schema for the specified collection. The descriptions that you generate will be used in the prompt of a LLM for performing database-related tasks. @@ -57,36 +58,41 @@ interface GenerateAnnotatedCollectionSchemaParams { llm: LlmOptions; } -export async function generateAnnotatedCollectionSchema({ - collectionMetadata, - databaseMetadata, - llm: llmOptions, -}: GenerateAnnotatedCollectionSchemaParams): Promise { - const messages = [ - { - role: "system", - content: systemPrompt, - }, - { - role: "user", - content: `Analyze the following collection: '${ - collectionMetadata.collectionName - }'. +export const makeGenerateAnnotatedCollectionSchema = (openAiClient: OpenAI) => + wrapTraced( + async function generateAnnotatedCollectionSchema({ + collectionMetadata, + databaseMetadata, + llm: llmOptions, + }: GenerateAnnotatedCollectionSchemaParams): Promise { + const messages = [ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: `Analyze the following collection: '${ + collectionMetadata.collectionName + }'. 
Database metadata: ${prettyPrintMongoDbDocument(databaseMetadata)} Again, analyze the collection named '${collectionMetadata.collectionName}'.`, - }, - ] satisfies OpenAI.ChatCompletionMessageParam[]; + }, + ] satisfies OpenAI.ChatCompletionMessageParam[]; - const result = await getOpenAiFunctionResponse({ - messages, - llmOptions, - schema: makeDetailedCollectionDescriptionSchema(collectionMetadata), - functionName, - functionDescription, - }); + const result = await getOpenAiFunctionResponse({ + messages, + llmOptions, + schema: makeDetailedCollectionDescriptionSchema(collectionMetadata), + functionName, + functionDescription, + openAiClient, + }); - return result; -} + return result; + }, + { name: "generateAnnotatedCollectionSchema" } + ); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateHighLevelDbDescriptions.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateHighLevelDbDescriptions.ts index 6544d1cca..87515ab88 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateHighLevelDbDescriptions.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/generateHighLevelDbDescriptions.ts @@ -1,8 +1,8 @@ import { z } from "zod"; import { DatabaseMetadata } from "./getDatabaseMetadata"; -import { LlmOptions } from "../databaseNodes/LlmOptions"; import { getOpenAiFunctionResponse } from "./getOpenAiFunctionResponse"; import { prettyPrintMongoDbDocument } from "./prettyPrintMongoDbDocument"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; import { OpenAI } from "mongodb-rag-core/openai"; const systemPrompt = `You are an expert MongoDB database architect. Your task is to analyze the provided database metadata and generate clear, concise descriptions. 
@@ -53,7 +53,8 @@ function createHighLevelDbDescriptionsSchema( */ export async function generateHighLevelDbDescriptions( databaseMetadata: DatabaseMetadata, - llmOptions: LlmOptions + llmOptions: LlmOptions, + openAiClient: OpenAI ) { const schema = createHighLevelDbDescriptionsSchema(databaseMetadata); @@ -69,6 +70,7 @@ ${prettyPrintMongoDbDocument(databaseMetadata)}`, }, ] satisfies OpenAI.ChatCompletionMessageParam[]; return await getOpenAiFunctionResponse({ + openAiClient, messages, llmOptions, schema, diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/getOpenAiFunctionResponse.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/getOpenAiFunctionResponse.ts index f04d158c5..21c76c8e8 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/getOpenAiFunctionResponse.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseMetadata/getOpenAiFunctionResponse.ts @@ -1,6 +1,6 @@ import { OpenAI } from "mongodb-rag-core/openai"; import { ZodSchema, z } from "zod"; -import { LlmOptions } from "../databaseNodes/LlmOptions"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; import zodToJsonSchema from "zod-to-json-schema"; interface GetOpenAiFunctionResponseParams<Schema extends ZodSchema> { @@ -9,6 +9,7 @@ schema: Schema; functionName: string; functionDescription?: string; + openAiClient: OpenAI; } export async function getOpenAiFunctionResponse<Schema extends ZodSchema>({ @@ -17,11 +18,12 @@ schema, functionName, functionDescription, + openAiClient, }: GetOpenAiFunctionResponseParams<Schema>): Promise<z.infer<Schema>> { const parameters = zodToJsonSchema(schema, { $refStrategy: "none", }); - const { openAiClient, ...createChatCompletionParams } = llmOptions; + const { ...createChatCompletionParams } = llmOptions; const res = await openAiClient.chat.completions.create({ messages, ...createChatCompletionParams, diff --git
a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/LlmOptions.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/LlmOptions.ts deleted file mode 100644 index 878c6103e..000000000 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/LlmOptions.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { OpenAI } from "mongodb-rag-core/openai"; - -export type LlmOptions = { - openAiClient: OpenAI; -} & Omit< - OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, - "messages" | "tools" | "tool_choice" ->; diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo.ts index 79b34f43b..498ee6f01 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateAnnotatedDatabaseInfo.ts @@ -1,7 +1,8 @@ import { MongoClient, ObjectId } from "mongodb-rag-core/mongodb"; import { DatabaseInfoNode } from "./nodeTypes"; -import { LlmOptions } from "./LlmOptions"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; import { generateAnnotatedDatabaseInfo } from "mongodb-rag-core/executeCode"; +import { OpenAI } from "mongodb-rag-core/openai"; export interface GenerateAnnotatedDatabaseInfoParams { mongoDb: { @@ -10,18 +11,21 @@ numSamplesPerCollection?: number; }; latestDate?: Date; - llm: LlmOptions; + llmOptions: LlmOptions; + openAiClient: OpenAI; } export async function generateAnnotatedDatabaseInfoNode({ mongoDb: { mongoClient, databaseName, numSamplesPerCollection = 2 }, latestDate = new Date(), - llm, + llmOptions, + openAiClient, }: GenerateAnnotatedDatabaseInfoParams): Promise<DatabaseInfoNode> { const annotatedDatabaseInfo = await generateAnnotatedDatabaseInfo({ mongoDb: { mongoClient,
databaseName, numSamplesPerCollection }, latestDate, - llm, + llmOptions, + openAiClient, }); // Return the final DatabaseInfoNode return { diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseExecutionResult.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseExecutionResult.ts index f56ab18d5..903070bfa 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseExecutionResult.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseExecutionResult.ts @@ -5,6 +5,7 @@ import { extractMongoDbMethods, extractMongoDbQueryOperators, } from "mongodb-rag-core/executeCode"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; export interface ExecuteGeneratedQueryParams { generatedQuery: DatabaseCodeNode; @@ -18,31 +19,36 @@ export interface ExecuteGeneratedQueryParams { /** Execute the generated query. */ -export async function generateDatabaseExecutionResult({ - generatedQuery, - database, - executor, -}: ExecuteGeneratedQueryParams): Promise { - const query = generatedQuery.data.code; - const { uri, name: databaseName } = database; +export const generateDatabaseExecutionResult = wrapTraced( + async function ({ + generatedQuery, + database, + executor, + }: ExecuteGeneratedQueryParams): Promise { + const query = generatedQuery.data.code; + const { uri, name: databaseName } = database; - const executionResult = await executor({ - query, - uri, - databaseName, - }); + const executionResult = await executor({ + query, + uri, + databaseName, + }); - const metadata = { - queryOperators: extractMongoDbQueryOperators(query), - methods: extractMongoDbMethods(query), - }; + const metadata = { + queryOperators: extractMongoDbQueryOperators(query), + methods: extractMongoDbMethods(query), + }; - const data = { ...executionResult, ...metadata }; - return { - _id: new ObjectId(), - parent: generatedQuery, - type: 
"database_execution_result", - updated: new Date(), - data, - }; -} + const data = { ...executionResult, ...metadata }; + return { + _id: new ObjectId(), + parent: generatedQuery, + type: "database_execution_result", + updated: new Date(), + data, + }; + }, + { + name: "generateDatabaseExecutionResult", + } +); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseUsers.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseUsers.ts index 7ff7194ba..258d5c57f 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseUsers.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateDatabaseUsers.ts @@ -1,37 +1,114 @@ +import { wrapTraced } from "mongodb-rag-core/braintrust"; import { makeGenerateChildrenWithOpenAi } from "../../generateChildren"; import { DatabaseInfoNode, DatabaseUserNode, DatabaseUserSchema, } from "./nodeTypes"; +import { openAiClient } from "../../../openAi"; -export const generateDatabaseUsers = makeGenerateChildrenWithOpenAi< - DatabaseInfoNode, - DatabaseUserNode ->({ - childType: "database_user", - makePromptMessages: async (parent, numChildren) => [ - { - role: "system", - content: `You are an experienced database administrator and organizational psychologist who specializes in modeling realistic user personas for enterprise systems. Given the database context, create diverse users who might interact with applications built on this database. Take into account specific user personas, jobs, roles, and backgrounds of the industry for a given database. -- Be sure to include industry-specifics to enrich the quality of the dataset. -Also include general jobs that would still be relevant for the database, with the balance toward industry-specific roles. -- Each user should have varying roles, expertise levels, and backgrounds that make sense for the database's domain. 
-- You can include both internal and external users, depending on the context. - -Focus on creating realistic and diverse personas that would actually use this type of database. +export const generateDatabaseUsers = wrapTraced( + makeGenerateChildrenWithOpenAi({ + openAiClient, + childType: "database_user", + makePromptMessages: async (parent, numChildren) => [ + { + role: "system", + content: `You are an experienced database administrator and organizational psychologist who specializes in modeling realistic user personas for enterprise systems. Given the database context, create diverse users who might interact with this database in ANY form. + + +Users can interact with the database in many ways: +- Direct database queries (developers, engineers, data scientists) +- Through internal business applications (employees, managers, analysts) +- Through consumer-facing applications (end users, customers, enthusiasts) +- Through analytical tools (business analysts, researchers, executives) +- Through third-party integrations (partners, vendors, external systems) + + + +Create a diverse mix of users that reflects real-world usage: +- Industry professionals specific to the database domain (e.g., for e-commerce: store managers, inventory specialists, suppliers) +- End consumers/enthusiasts (e.g., for e-commerce: online shoppers, bargain hunters, frequent buyers) +- Business/operational roles (managers, analysts, support staff) +- Technical roles (developers, data engineers, system administrators) +- External stakeholders (partners, vendors, researchers) +- Academic/educational users (students, teachers, researchers) + + +- Use names from diverse cultural backgrounds +- Include a mix of traditional and modern names +- Ensure gender balance across all role types +- Avoid stereotypical name-role associations + + + + + +- Aim for roughly 30% domain-specific professionals +- 30% end users/consumers +- 20% business/analytical roles +- 20% technical/development roles +- Include 
users from different regions, age groups, and backgrounds +- Consider both frequent power users and occasional casual users + + + +- Vary the experience levels (entry-level to senior/expert) +- Include some users with specific constraints or special needs +- Mix internal company users with external users +- Include users from different time zones if relevant +- Consider users with different access patterns (daily, weekly, seasonal) +- Add users who might misuse or stress the system (within ethical bounds) +- Include users from smaller/niche segments, not just mainstream + + + + +Some examples of users types for different types of databases: + + +For an e-commerce database, users might include: +- Store managers, inventory specialists, suppliers, merchandisers (domain professionals) +- Online shoppers, bargain hunters, gift buyers, wholesale customers (consumers) +- Sales analysts, customer service reps, marketing managers (business roles) +- E-commerce developers, data engineers, API integrators (technical roles) + + + +For a healthcare database, users might include: +- Doctors, nurses, pharmacists, lab technicians (domain professionals) +- Patients, caregivers, health-conscious individuals seeking medical info (consumers) +- Hospital administrators, insurance analysts, compliance officers (business roles) +- Health IT specialists, medical data scientists, EMR developers (technical roles) + + + +For a financial services database, users might include: +- Investment advisors, loan officers, branch managers, auditors (domain professionals) +- Bank customers, investors, loan applicants, small business owners (consumers) +- Risk analysts, fraud investigators, regulatory compliance managers (business roles) +- FinTech developers, blockchain engineers, payment system integrators (technical roles) + + + + +Focus on creating realistic personas that represent the full spectrum of database interactions. Each user should have a clear reason for accessing the database data. 
Generate exactly ${numChildren} user(s).`, - }, - { - role: "user", - content: `Generate users for database with the following info: + }, + { + role: "user", + content: `Generate users for database with the following info: ${JSON.stringify(parent.data, null, 2)}`, + }, + ], + response: { + schema: DatabaseUserSchema, + name: "generate_database_users", + description: "Generate a list of realistic database users", }, - ], - response: { - schema: DatabaseUserSchema, - name: "generate_database_users", - description: "Generate a list of realistic database users", - }, -}); + }), + { + name: "generateDatabaseUsers", + } +); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateMongoshCode.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateMongoshCode.ts index 13b7731c8..91702bae6 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateMongoshCode.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateMongoshCode.ts @@ -9,6 +9,8 @@ import { makePromptDbInfo, makePromptNaturalLanguageQueryInfo, } from "./makePromptComponents"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; +import { openAiClient } from "../../../openAi"; const abstractOutputExample: DatabaseCode = { queryPlan: "", @@ -16,62 +18,96 @@ const abstractOutputExample: DatabaseCode = { language: "mongosh", }; -const rules = [ +// Taken from the optimized mongosh query authoring prompt +export const mongoshQueryAuthoringTips = [ "Ensure proper use of MongoDB operators ($eq, $gt, $lt, etc.) 
and data types (ObjectId, ISODate)", - "Include appropriate projections to return only necessary fields", "For complex queries, use aggregation pipeline with proper stages ($match, $group, $lookup, etc.)", - "Consider performance by utilizing available indexes and optimizing query patterns", + "Consider performance by utilizing available indexes, avoiding $where and full collection scans, and using covered queries where possible", "Include sorting (.sort()) and limiting (.limit()) when appropriate for result set management", - "Handle null values and existence checks explicitly.", - "For date operations, use proper MongoDB date operators ($dateToString, $dateToParts, etc.)", + "Handle null values and existence checks explicitly with $exists and $type operators to differentiate between missing fields, null values, and empty arrays", + "Do not include `null` in results objects in aggregation, e.g. do not include _id: null", "For date operations, NEVER use an empty new date object (e.g. `new Date()`). ALWAYS specify the date, such as `new Date(\"2024-10-24\")`. Use the provided 'Latest Date' field to inform dates in queries.", - "Keep the output clean. Do not include value `null` in results objects in aggregation. Also don't include empty string values.", + "For Decimal128 operations, prefer range queries over exact equality", + "When querying arrays, use appropriate operators like $elemMatch for complex matching, $all to match multiple elements, or $size for array length checks", +]; + +// Taken from the optimized mongosh query authoring prompt +const chainOfThoughtTopics = [ + "Which collections are relevant to the query.", + "Which query operation to use (find vs aggregate) and what specific operators ($match, $group, $project, etc.) 
are needed", + "What fields are relevant to the query.", + "Which indexes you can use to improve performance.", + "Any specific transformations or projections.", + "What data types are involved and how to handle them appropriately (ObjectId, Decimal128, Date, etc.)", + "What edge cases to consider (empty results, null values, missing fields)", + "How to handle any array fields that require special operators ($elemMatch, $all, $size)", + "Any other relevant considerations.", ]; const nlQuerySystemPrompt = `You are an expert data analyst experienced at using MongoDB. Your job is to take information about a MongoDB database plus a natural language query and generate a MongoDB shell (mongosh) query to execute to retrieve the information needed to answer the natural language query. In your response include a query plan, where you think step-by-step about how to best execute the query before providing the final mongosh output. + Format the mongosh query in the following structure: \`db..find({/* query */})\` or \`db..aggregate({/* query */})\` + + Some general query-authoring tips: -${rules.map((rule, i) => `${i + 1}. ${rule}`).join("\n")} -Before writing the mongosh query, think step-by-step about what the query should do in the "queryPlan" field. -For the language field, always put 'mongosh'. For example the output should look like: +${markdownList(mongoshQueryAuthoringTips)} + -${JSON.stringify(abstractOutputExample)}`; + +Before writing the mongosh query, think step-by-step about what the query should do in the "queryPlan" field. In your thoughts consider: -export const generateMongoshCode = makeGenerateNChoiceChildrenWithOpenAi< - DatabaseNlQueryNode, - DatabaseCodeNode ->({ - childType: "database_code", - makePromptMessages: async ({ - data: naturalLanguageQuery, - parent: { +${markdownList(chainOfThoughtTopics)} + + + +For the language field, always put 'mongosh'. 
+ + + +For example the output should look like: +${JSON.stringify(abstractOutputExample)} +`; +function markdownList(items: string[]) { + return items.map((item, index) => `${index + 1}. ${item}`).join("\n"); +} +export const generateMongoshCode = wrapTraced( + makeGenerateNChoiceChildrenWithOpenAi({ + openAiClient, + childType: "database_code", + makePromptMessages: async ({ + data: naturalLanguageQuery, parent: { - parent: { data: databaseInfo }, + parent: { + parent: { data: databaseInfo }, + }, }, - }, - }) => { - const message = `Generate MongoDB Shell (mongosh) queries for the following database and natural language query: + }) => { + const message = `Generate MongoDB Shell (mongosh) queries for the following database and natural language query: ${makePromptDbInfo(databaseInfo)} ${makePromptNaturalLanguageQueryInfo(naturalLanguageQuery)} `; - return [ - { role: "system", content: nlQuerySystemPrompt }, - { role: "user", content: message }, - ]; - }, - response: { - schema: DatabaseCodeSchema, - name: "generate_db_code", - description: "A MongoDB Shell (mongosh) query for the database use case", - }, -}); + return [ + { role: "system", content: nlQuerySystemPrompt }, + { role: "user", content: message }, + ]; + }, + response: { + schema: DatabaseCodeSchema, + name: "generate_db_code", + description: "A MongoDB Shell (mongosh) query for the database use case", + }, + }), + { + name: "generateMongoshCode", + } +); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.eval.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.eval.ts index c452bb965..2140580db 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.eval.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.eval.ts @@ -56,6 +56,7 @@ async function main() { }), functionName: 
nlQueryResponseSchema.name, functionDescription: nlQueryResponseSchema.description, + openAiClient, }); return results; }, diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.ts index f775d8814..c3f3403dd 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateNaturalLanguageQueries.ts @@ -13,6 +13,8 @@ import { } from "./makePromptComponents"; import { DatabaseInfo } from "mongodb-rag-core/executeCode"; import { OpenAI } from "mongodb-rag-core/openai"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; +import { openAiClient } from "../../../openAi"; export interface MakeGenerateNaturalLanguageQueryPromptParams { numChildren: number; @@ -20,6 +22,7 @@ export interface MakeGenerateNaturalLanguageQueryPromptParams { user: DatabaseUser; databaseInfo: DatabaseInfo; } + export function makeGenerateNaturalLanguageQueryPrompt({ numChildren, useCase, @@ -30,10 +33,7 @@ export function makeGenerateNaturalLanguageQueryPrompt({ These queries will be used in a benchmarking dataset that measures large language models' ability to generate MongoDB database queries from natural language. The queries should be of similar type to those used in modern natural language-to-SQL benchmarks like Spider-SQL or BirdBench. -It is important that the natural language queries not be overly vague, such that there is a high degree of ambiguity in the query. For each natural language query, there should be a clear, unambiguous intent. This is so we can measure the accuracy of the generated queries against the intent and a reference answer. - -Be sure to include varying levels of complexity among the queries. 
- + For each natural language query: - Create a query that sounds natural and conversational, as if a real user was asking it - Make sure the query aligns with the user's intent and information needs described in the use case @@ -42,19 +42,117 @@ For each natural language query: - Provide variations that express the same information need in different ways - Ensure the query would retrieve the information needed to satisfy the use case - The query should sound conversational, as if the user were asking them to an AI chatbot. +- For most users, avoid using technical database terminology (e.g., "documents", "collections") - use domain-specific terms instead + -Limiting queries: -- IMPORTANT: Design queries that will return a specific small number of results (ideally fewer than 20 documents). - - For example, instead of saying "Find the most popular movies from 1997", say "Find 10 most popular movies from 1997" -- Include specific filtering conditions and constraints to narrow down the result set -- Avoid overly broad queries that might return the entire collection or a large subset of it + +Regarding the specificity of the natural language queries: +- It is important that the natural language queries not be overly vague. +- For each natural language query, there should be a clear, unambiguous intent. There should be a low degree of ambiguity in the query. This is so we can measure the accuracy of the generated MongoDB queries against the natural language query and a reference answer, which is the database output. To create more specific queries: - Use multiple filtering conditions when appropriate - For example, "Find total number of action movies from 2010 with a rating above 8.5" - Include specific values for fields (exact names, precise dates, specific ratings, etc.) 
- Use comparisons and ranges with specific thresholds +- Avoid wildcards or overly broad terms like "all", "any", "everything" without constraints + + + + +Be sure to include varying levels of complexity among the queries. The complexity levels are: + +- simple +- moderate +- complex + +Below is some specific guidance for each complexity level: + + +Simple queries should follow these criteria: + +Requirements: +- Single collection queries + +Guidelines: +- Basic filtering +- Use up to a few conditions (e.g., equality, comparison) +- Correspond to operations like: find, count, simple aggregation + +Examples: +- "Find all movies released in 2020" +- "Show me users with age greater than 25" +- "Get the product with SKU 'ABC123'" +- "Count how many orders were placed today" + + + + +Moderate queries should follow these criteria: + +Requirements: +- Single collection queries with multiple conditions +- Multiple filter conditions (3-5) with AND/OR logic +- Can correspond to find/aggregation/count, etc. +Guidelines: +- Aggregations (count, sum, average, min, max, etc.) 
+- Sorting and limiting results +- May access nested fields or arrays +- Simple grouping operations + +Examples: +- "Find the top 10 highest-rated movies from 2020 that are longer than 2 hours" +- "Show me all orders from last month with their customer details and total amount" +- "Get the average rating for each movie genre in the database" +- "List products that are low in stock (less than 10 units) and have been ordered more than 50 times" + + + + +Complex queries should follow these criteria: + +Requirements: +- The query has many parts (at least 3-4 distinct operations or conditions) +- Requires multiple stages of data transformation or analysis +- Cannot be solved with a simple find() or basic aggregation +- Often involves combining data from multiple sources or performing calculations across grouped data +- May require conditional logic or branching based on data values +- Results depend on intermediate calculations or derived fields + +Guidelines: +- Can include multiple collections (joins/lookups) +- Complex aggregation pipelines +- Advanced operations (window functions, array manipulations, text search) +- Multiple grouping levels or conditional aggregations +- Geospatial queries or text search with scoring +- Subqueries or complex conditional logic +- Time-series analysis or running calculations + +Examples: +- "Find movies where the lead actor has won an Oscar and appeared in at least 5 other movies rated above 8.0 in our database" +- "Calculate the month-over-month growth rate of sales for each product category in the last year" +- "Find all restaurants within 2km of midtown Manhattan that serve Italian cuisine, are open now, and have an average rating above 4.5 from users who have reviewed more than 10 restaurants" +- "Show me customers who spent more than the average customer spending in their region and have made purchases in at least 3 different categories this year" + + + + +- When generating the natural language queries, ensure a balanced 
distribution across all complexity levels to properly test the model's capabilities. +-You should generate a similar number of queries for each complexity level. + + + + + +Limiting queries: +- IMPORTANT: Design queries that will return a specific small number of results (ideally fewer than 20 documents). + - For example, instead of saying "Find the most popular movies from 1997", say "Find 10 most popular movies from 1997" +- Include specific filtering conditions and constraints to narrow down the result set +- Avoid overly broad queries that might return the entire collection or a large subset of it + + + If the use case relates to time windows or temporal analysis: - Include a variety of different time expressions (last 30 days, year 1998, past quarter, last summer, etc.) - Mix absolute time references (e.g., "in 2020", "during the 1990s") with relative ones (e.g., "in the past month", "over the last year") @@ -62,9 +160,15 @@ If the use case relates to time windows or temporal analysis: - Use both specific dates and general time periods - Make sure to keep the date provided in the 'Latest Date' field in mind when crafting the query. -Make use of the indexes available in the database when designing the query, if relevant to the use case. -For example if there are geo or text based indexes, consider writing queries that use them. However, only do this if relevant to the use case. +Note: not all databases will have temporal data. If the database does not have temporal data, ignore the temporal guidelines. + + +- Make use of the indexes available in the database when designing the query, if relevant to the use case. +- For example if there are geo or text based indexes, consider writing queries that use them. However, only use the indexes if relevant to the use case. 
+ + + For the output \`resultsSchema\`, include the actual type definition, for instance given the query "Find the titles and ratings of the 10 most popular movies from 2010 with a rating above 8.5", the output should be: \`\`\`typescript /** @@ -72,8 +176,11 @@ For the output \`resultsSchema\`, include the actual type definition, for instan */ type QueryResults = {title: string, rating: number}[]; \`\`\` + -Generate ${numChildren} natural language queries with multiple variations for each use case, maintaining the same complexity level as specified in the use case.`; + +Generate ${numChildren} natural language queries with multiple variations for each use case, maintaining the same complexity level as specified in the use case. +`; const message = `Generate natural language queries for the following database use case: @@ -94,26 +201,29 @@ export const nlQueryResponseSchema = { description: "A natural language query for the database use case", }; -export const generateNaturalLanguageQueries = makeGenerateChildrenWithOpenAi< - UseCaseNode, - DatabaseNlQueryNode ->({ - makePromptMessages: async ( - { - data: useCase, - parent: { - data: user, - parent: { data: databaseInfo }, +export const generateNaturalLanguageQueries = wrapTraced( + makeGenerateChildrenWithOpenAi({ + openAiClient, + makePromptMessages: async ( + { + data: useCase, + parent: { + data: user, + parent: { data: databaseInfo }, + }, }, + numChildren + ) => { + return makeGenerateNaturalLanguageQueryPrompt({ + numChildren, + useCase, + user, + databaseInfo, + }); }, - numChildren - ) => { - return makeGenerateNaturalLanguageQueryPrompt({ - numChildren, - useCase, - user, - databaseInfo, - }); - }, - response: nlQueryResponseSchema, -}); + response: nlQueryResponseSchema, + }), + { + name: "generateNaturalLanguageQueries", + } +); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateUseCases.ts 
b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateUseCases.ts index d327eaf0a..83449e483 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateUseCases.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/generateUseCases.ts @@ -5,16 +5,17 @@ import { UseCaseNode, } from "./nodeTypes"; import { makePromptDbInfo, makePromptDbUserInfo } from "./makePromptComponents"; +import { wrapTraced } from "mongodb-rag-core/braintrust"; +import { openAiClient } from "../../../openAi"; -export const generateDatabaseUseCases = makeGenerateChildrenWithOpenAi< - DatabaseUserNode, - UseCaseNode ->({ - makePromptMessages: async ( - { data: user, parent: { data: databaseInfo } }, - numMessages - ) => { - const systemPrompt = `You are an expert user researcher who understands how different professionals use database information in their roles. Given a database user profile, generate realistic use cases that describe what information they need to retrieve from a database and why. +export const generateDatabaseUseCases = wrapTraced( + makeGenerateChildrenWithOpenAi({ + openAiClient, + makePromptMessages: async ( + { data: user, parent: { data: databaseInfo } }, + numMessages + ) => { + const systemPrompt = `You are an expert user researcher who understands how different professionals use database information in their roles. Given a database user profile, generate realistic use cases that describe what information they need to retrieve from a database and why. For each use case: - Focus ONLY on information needs, not on specific queries or technical implementation @@ -37,7 +38,7 @@ For each use case: - ONLY make use cases that are supported by information in the given database. 
Generate ${numMessages} use case(s).`; - const message = `Generate information retrieval use cases for the following user: + const message = `Generate information retrieval use cases for the following user: ${makePromptDbUserInfo(user)} @@ -47,15 +48,19 @@ ${makePromptDbInfo(databaseInfo)} Based on this profile, what are the most realistic and specific information needs this person would have when working with the database?`; - return [ - { role: "system", content: systemPrompt }, - { role: "user", content: message }, - ]; - }, - response: { - schema: DatabaseUseCaseSchema, - name: "generate_use_cases", - description: "An array of information retrieval use cases for the user", - }, - childType: "database_use_case", -}); + return [ + { role: "system", content: systemPrompt }, + { role: "user", content: message }, + ]; + }, + response: { + schema: DatabaseUseCaseSchema, + name: "generate_use_cases", + description: "An array of information retrieval use cases for the user", + }, + childType: "database_use_case", + }), + { + name: "generateDatabaseUseCases", + } +); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/makePromptComponents.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/makePromptComponents.ts index 76935abde..692e063a1 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/makePromptComponents.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/makePromptComponents.ts @@ -35,7 +35,7 @@ export function makePromptDbUserInfo(user: DatabaseUser) { return `## User Information Name: ${user.name} -Job Title: ${user.jobTitle} +Role: ${user.role} Description: ${user.description}`; } diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/nodeTypes.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/nodeTypes.ts index 85bd674ad..f9aee707a 100644 --- 
a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/nodeTypes.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/nodeTypes.ts @@ -12,11 +12,11 @@ export type DatabaseInfoNode = WithParentNode< export const DatabaseUserSchema = z.object({ name: z.string().describe("Full name of the database user"), - jobTitle: z.string().describe("Current professional role or position"), + role: z.string().describe("Current professional role or position."), description: z .string() .describe( - "Background of the user. Consider including years of experience and relevant areas of expertise." + "Background of the user. Consider including years of experience and relevant areas of expertise/interest." ), }); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/sampleData.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/sampleData.ts index 5c2e0b352..1d77a9b5f 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/sampleData.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/databaseNodes/sampleData.ts @@ -6,19 +6,10 @@ import { DatabaseUserNode, NaturalLanguageQuery, } from "./nodeTypes"; -import { OpenAI } from "mongodb-rag-core/openai"; -import { LlmOptions } from "./LlmOptions"; -import { BRAINTRUST_ENV_VARS, assertEnvVars } from "mongodb-rag-core"; -import { DatabaseInfo } from "mongodb-rag-core/executeCode"; +import { DatabaseInfo, LlmOptions } from "mongodb-rag-core/executeCode"; export const makeSampleLlmOptions = () => { - const { BRAINTRUST_API_KEY, BRAINTRUST_ENDPOINT } = - assertEnvVars(BRAINTRUST_ENV_VARS); return { - openAiClient: new OpenAI({ - apiKey: BRAINTRUST_API_KEY, - baseURL: BRAINTRUST_ENDPOINT, - }), model: "gpt-4o-mini", temperature: 0.5, seed: 42, @@ -147,25 +138,25 @@ export const sampleMovieDbInfo: DatabaseInfo = { export const sampleDatabaseUsers = [ { name: "Alice Chen", - jobTitle: "Data Analyst", + role: "Data 
Analyst", description: "Alice analyzes user interaction data to derive meaningful insights and create reports on movie trends. In Analytics Department. Expertise areas: Data Analysis, SQL, Python, Data Visualization. 5 years of experience.", }, { name: "Daniel Garcia", - jobTitle: "Film Critic", + role: "Film Critic", description: "Daniel writes and publishes reviews and provides ratings for the latest movie releases on MovieDB. In Content Department. Expertise areas: Film Critique, Writing, Cinema Studies. 7 years of experience.", }, { name: "Samuel King", - jobTitle: "Film Festival Organizer", + role: "Film Festival Organizer", description: "Samuel organizes film festivals and uses MovieDB to discover films and directors to feature in his events. In Events Department. Expertise areas: Event Planning, Film Selection, Networking. 15 years of experience.", }, { name: "Tina Lewis", - jobTitle: "Talent Agent", + role: "Talent Agent", description: "Tina represents actors and helps them find roles in films. She uses MovieDB to stay updated on industry trends and opportunities. In Talent Management Department. Expertise areas: Talent Representation, Negotiation, Industry Knowledge. 
10 years of experience.", }, diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/findMostFrequentAndPerformantDatabaseExecutionResult.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/findMostFrequentAndPerformantDatabaseExecutionResult.ts index cbf4babed..007f992d7 100644 --- a/packages/datasets/src/treeGeneration/databaseNlQueries/findMostFrequentAndPerformantDatabaseExecutionResult.ts +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/findMostFrequentAndPerformantDatabaseExecutionResult.ts @@ -54,10 +54,13 @@ export function getClusterFastestExecutionTimeExecutionResultIndex( } export function findMostFrequentAndPerformantDatabaseExecutionResult( - items: DatabaseExecutionResultNode["data"][] + items: DatabaseExecutionResultNode["data"][], + minClusterSize = 1 ) { const clusters = fuzzyClusterDatabaseExecutionResults(items); const largestCluster = findLargestCluster(clusters); + if (largestCluster.length < minClusterSize) + return { clusters, fastestMostFrequentIndex: null }; const clusterItems = getClusterElementsFromIndexes(items, largestCluster); const fastestMostFrequentIndex = getClusterFastestExecutionTimeExecutionResultIndex(clusterItems); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery.eval.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery.eval.ts new file mode 100644 index 000000000..7c94481dc --- /dev/null +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery.eval.ts @@ -0,0 +1,419 @@ +import { DatabaseNlQueryDatasetEntryBraintrust } from "../DatabaseNlQueryDatasetEntry"; +import { Eval, EvalScorer, wrapAISDKModel } from "mongodb-rag-core/braintrust"; +import { + makeRewriteNlQueryPrompt, + RewriteClassification, +} from "./rewriteNlQuery"; +import { LanguageModel } from "mongodb-rag-core/aiSdk"; +import { openAiProvider } from "../../../openAi"; +import { models } from 
"mongodb-rag-core/models"; + +type RewriteNlQueryInput = DatabaseNlQueryDatasetEntryBraintrust; + +type RewriteNlQueryExpected = { + classification: RewriteClassification["classification"]; +}; + +type RewriteNlQueryOutput = { + classification: RewriteClassification; + datasetEntry: DatabaseNlQueryDatasetEntryBraintrust; +}; + +const correctRewriteClassification: EvalScorer< + RewriteNlQueryInput, + RewriteNlQueryOutput, + RewriteNlQueryExpected +> = function ({ output, expected }) { + return { + name: "CorrectRewriteClassification", + score: + output.classification.classification === expected.classification ? 1 : 0, + metadata: { + classification: output.classification, + expectedClassification: expected.classification, + }, + }; +}; + +interface RewriteNlQueryDatasetEntry { + input: RewriteNlQueryInput; + expected: RewriteNlQueryExpected; + tags?: string[]; +} + +function runRewriteNlQueryEval({ + model, + dataset, +}: { + model: LanguageModel; + dataset: RewriteNlQueryDatasetEntry[]; +}) { + const rewriteNlQuery = makeRewriteNlQueryPrompt(model); + return Eval< + RewriteNlQueryInput, + RewriteNlQueryOutput, + RewriteNlQueryExpected + >("rewrite-nl-query", { + data: dataset, + task: rewriteNlQuery, + scores: [correctRewriteClassification], + experimentName: `rewrite-nl-query-${model.modelId}`, + metadata: { + model: model.modelId, + }, + }); +} + +const dataset: RewriteNlQueryDatasetEntry[] = [ + // 5 Ambiguous examples + { + input: { + input: { + nlQuery: + "Find all shipwrecks within 5 kilometers of coordinates 45.483, -86.665.", + databaseName: "sample_geospatial", + }, + expected: { + dbQuery: + 'db.shipwrecks.find({coordinates: {$near: {$geometry: {type: "Point", coordinates: [-86.665, 45.483]}, $maxDistance: 5000}}}, {_id: 1, feature_type: 1, latdec: 1, londec: 1, depth: 1, history: 1, coordinates: 1})', + executionTimeMs: 5238, + result: [ + { + _id: "578f6fa3df35c7fbdbaf0030", + coordinates: [-86.6652573, 45.4831299], + depth: 0, + feature_type: "Wrecks - 
Submerged, dangerous", + history: "reported 1995", + latdec: 45.4831299, + londec: -86.6652573, + }, + ], + }, + tags: ["simple", "mongosh", "sample_geospatial"], + metadata: { + methods: ["find"], + queryOperators: ["$near", "$geometry", "$maxDistance"], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "ambiguous", + }, + }, + { + input: { + input: { + nlQuery: + "Show me the 15 largest sell transactions for Apple (AAPL) stock in 2016 with transaction amounts over 500 shares.", + databaseName: "sample_analytics", + }, + expected: { + dbQuery: + 'db.transactions.aggregate([{$unwind: "$transactions"}, {$match: {"transactions.date": {$gte: new Date("2016-01-01"), $lt: new Date("2017-01-01")}, "transactions.symbol": "aapl", "transactions.transaction_code": "sell", "transactions.amount": {$gt: 500}}}, {$sort: {"transactions.amount": -1}}, {$limit: 15}, {$project: {_id: 0, account_id: 1, date: "$transactions.date", amount: "$transactions.amount", price: "$transactions.price", total: "$transactions.total"}}])', + executionTimeMs: 1689, + result: [ + { + account_id: 418570, + amount: 9959, + date: "2016-12-19T00:00:00Z", + price: "113.9603089288650750177112058736383914947509765625", + total: "1134930.716622567282101385899", + }, + ], + }, + tags: ["moderate", "mongosh", "sample_analytics"], + metadata: { + methods: ["aggregate"], + queryOperators: [ + "$unwind", + "$match", + "$gte", + "$lt", + "$gt", + "$sort", + "$limit", + "$project", + ], + language: "mongosh", + complexity: "moderate", + }, + }, + expected: { + classification: "ambiguous", + }, + }, + { + input: { + input: { + nlQuery: "Show me the 5 most recent buy transactions for EBAY.", + databaseName: "sample_analytics", + }, + expected: { + dbQuery: + 'db.transactions.aggregate([{$unwind: "$transactions"}, {$match: {"transactions.symbol": "ebay", "transactions.transaction_code": "buy"}}, {$sort: {"transactions.date": -1}}, {$limit: 5}, {$project: {_id: 0, date: 
"$transactions.date", amount: "$transactions.amount", price: "$transactions.price", total: "$transactions.total", symbol: "$transactions.symbol", transaction_code: "$transactions.transaction_code"}}])', + executionTimeMs: 5032, + result: [ + { + amount: 7866, + date: "2017-01-06T00:00:00Z", + price: "30.7483446996547655771792051382362842559814453125", + symbol: "ebay", + total: "241866.4794074843860300916276", + transaction_code: "buy", + }, + ], + }, + tags: ["simple", "mongosh", "sample_analytics"], + metadata: { + methods: ["aggregate"], + queryOperators: ["$unwind", "$match", "$sort", "$limit", "$project"], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "ambiguous", + }, + }, + { + input: { + input: { + nlQuery: + "Get titles, directors, writers, and main cast for the top 15 movies with IMDb rating above 8.0 and more than 10000 votes.", + databaseName: "sample_mflix", + }, + expected: { + dbQuery: + 'db.movies.aggregate([{$match: {"imdb.rating": {$gt: 8.0}, "imdb.votes": {$gt: 10000}}}, {$sort: {"imdb.rating": -1}}, {$limit: 15}, {$project: {_id: 0, title: 1, directors: 1, writers: 1, cast: 1}}])', + executionTimeMs: 1728, + result: [ + { + cast: [ + "Scott Grimes", + "Matthew Leitch", + "Damian Lewis", + "Ron Livingston", + ], + title: "Band of Brothers", + }, + ], + }, + tags: ["complex", "mongosh", "sample_mflix"], + metadata: { + methods: ["aggregate"], + queryOperators: ["$match", "$gt", "$sort", "$limit", "$project"], + language: "mongosh", + complexity: "complex", + }, + }, + expected: { + classification: "ambiguous", + }, + }, + { + input: { + input: { + nlQuery: + "List 8 shared rooms under $35 that mention bus stops in their transit information with a location score of 9 or higher.", + databaseName: "sample_airbnb", + }, + expected: { + dbQuery: + 'db.listingsAndReviews.find({room_type: "Shared room", price: {$lt: Decimal128("35")}, transit: {$regex: /bus/i}, "review_scores.review_scores_location": {$gte: 9}}, 
{name: 1, price: 1, room_type: 1, transit: 1, "review_scores.review_scores_location": 1, _id: 0}).limit(8)', + executionTimeMs: 4061, + result: [ + { + name: "GuestHouse Antero de Quental -Triple Bedroom 1Bed", + price: { $numberDecimal: "25.00" }, + review_scores: { review_scores_location: 9 }, + room_type: "Shared room", + transit: "Walking distance from all type of transports...", + }, + ], + }, + tags: ["simple", "mongosh", "sample_airbnb"], + metadata: { + methods: ["find", "limit"], + queryOperators: ["$lt", "$regex", "$gte"], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "ambiguous", + }, + }, + // 5 Unambiguous examples + { + input: { + input: { + nlQuery: + "Show me the average customer satisfaction rating for online purchases versus in-store purchases.", + databaseName: "sample_supplies", + }, + expected: { + dbQuery: + 'db.sales.aggregate([{$group: {_id: "$purchaseMethod", averageSatisfaction: {$avg: "$customer.satisfaction"}}}])', + executionTimeMs: 1708, + result: [ + { _id: "Online", averageSatisfaction: 3.7709779179810727 }, + { _id: "Phone", averageSatisfaction: 3.837248322147651 }, + { _id: "In store", averageSatisfaction: 3.79638169563675 }, + ], + }, + tags: ["simple", "mongosh", "sample_supplies"], + metadata: { + methods: ["aggregate"], + queryOperators: ["$group", "$avg"], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "unambiguous", + }, + }, + { + input: { + input: { + nlQuery: "Which boroughs have fewer than 1000 restaurants?", + databaseName: "sample_restaurants", + }, + expected: { + dbQuery: + 'db.restaurants.aggregate([{$group: {_id: "$borough", count: {$sum: 1}}}, {$match: {count: {$lt: 1000}}}, {$project: {_id: 1, count: 1}}])', + executionTimeMs: 1903, + result: [ + { _id: "Staten Island", count: 969 }, + { _id: "Missing", count: 51 }, + ], + }, + tags: ["simple", "mongosh", "sample_restaurants"], + metadata: { + methods: ["aggregate"], + 
queryOperators: ["$group", "$sum", "$match", "$lt", "$project"], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "unambiguous", + }, + }, + { + input: { + input: { + nlQuery: + "Count the total number of transactions for each store location in the first quarter of 2017.", + databaseName: "sample_supplies", + }, + expected: { + dbQuery: + 'db.sales.aggregate([{$match: {saleDate: {$gte: ISODate("2017-01-01T00:00:00.000Z"), $lte: ISODate("2017-03-31T23:59:59.999Z")}}}, {$group: {_id: "$storeLocation", transactionCount: {$sum: 1}}}, {$project: {_id: 0, storeLocation: "$_id", transactionCount: 1}}, {$sort: {storeLocation: 1}}])', + executionTimeMs: 1318, + result: [ + { storeLocation: "Austin", transactionCount: 27 }, + { storeLocation: "Denver", transactionCount: 73 }, + ], + }, + tags: ["simple", "mongosh", "sample_supplies"], + metadata: { + methods: ["aggregate"], + queryOperators: [ + "$match", + "$gte", + "$lte", + "$group", + "$sum", + "$project", + "$sort", + ], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "unambiguous", + }, + }, + { + input: { + input: { + nlQuery: + "Show me how many award-winning movies (with at least 3 wins) we have in each genre, sorted by count descending, limit to top 12 genres.", + databaseName: "sample_mflix", + }, + expected: { + dbQuery: + 'db.movies.aggregate([{$match: {"awards.wins": {$gte: 3}}}, {$unwind: "$genres"}, {$group: {_id: "$genres", count: {$sum: 1}}}, {$sort: {count: -1}}, {$limit: 12}])', + executionTimeMs: 1257, + result: [ + { _id: "Drama", count: 5542 }, + { _id: "Comedy", count: 2122 }, + { _id: "Romance", count: 1369 }, + ], + }, + tags: ["moderate", "mongosh", "sample_mflix"], + metadata: { + methods: ["aggregate"], + queryOperators: [ + "$match", + "$gte", + "$unwind", + "$group", + "$sum", + "$sort", + "$limit", + ], + language: "mongosh", + complexity: "moderate", + }, + }, + expected: { + classification: "unambiguous", + }, + 
}, + { + input: { + input: { + nlQuery: "Which planets in our solar system have rings?", + databaseName: "sample_guides", + }, + expected: { + dbQuery: + "db.planets.find({hasRings: true}, {name: 1, hasRings: 1, _id: 0})", + executionTimeMs: 1270, + result: [ + { hasRings: true, name: "Neptune" }, + { hasRings: true, name: "Uranus" }, + { hasRings: true, name: "Jupiter" }, + { hasRings: true, name: "Saturn" }, + ], + }, + tags: ["simple", "mongosh", "sample_guides"], + metadata: { + methods: ["find"], + queryOperators: [], + language: "mongosh", + complexity: "simple", + }, + }, + expected: { + classification: "unambiguous", + }, + }, +].map( + (entry) => + ({ + ...entry, + tags: [entry.expected.classification, ...entry.input.tags], + } as RewriteNlQueryDatasetEntry) +); + +const modelDeployment: (typeof models)[number]["deployment"] = + "claude-opus-4-20250514"; +runRewriteNlQueryEval({ + model: wrapAISDKModel(openAiProvider(modelDeployment)), + dataset, +}); diff --git a/packages/datasets/src/treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery.ts b/packages/datasets/src/treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery.ts new file mode 100644 index 000000000..925607a92 --- /dev/null +++ b/packages/datasets/src/treeGeneration/databaseNlQueries/rewriteNlQuery/rewriteNlQuery.ts @@ -0,0 +1,364 @@ +import { z } from "zod"; +import { DatabaseNlQueryDatasetEntryBraintrust } from "../DatabaseNlQueryDatasetEntry"; +import { generateObject, LanguageModel } from "mongodb-rag-core/aiSdk"; + +const RewriteClassificationSchema = z.object({ + think: z.string(), + classification: z.enum(["unambiguous", "ambiguous"]), + rewrittenNaturalLanguageQuery: z + .string() + .optional() + .describe( + "Rewritten natural language query to be less ambiguous. 
Only include this field if the classification is 'ambiguous'" + ), +}); + +export type RewriteClassification = z.infer; + +const fewShotExamples: { + input: DatabaseNlQueryDatasetEntryBraintrust; + output: RewriteClassification; +}[] = [ + { + input: { + input: { + databaseName: "sample_restaurants", + nlQuery: + "Find restaurants that have received only C grades in their last 4 inspections with scores above 30", + }, + expected: { + dbQuery: + 'db.restaurants.aggregate([\n {\n $addFields: {\n sortedGrades: {\n $slice: [\n {\n $sortArray: {\n input: "$grades",\n sortBy: { date: -1 }\n }\n },\n 4\n ]\n }\n }\n },\n {\n $match: {\n $and: [\n { "sortedGrades.3": { $exists: true } }, // Ensure at least 4 inspections\n {\n sortedGrades: {\n $not: {\n $elemMatch: {\n $or: [\n { grade: { $ne: "C" } },\n { score: { $lte: 30 } }\n ]\n }\n }\n }\n }\n ]\n }\n },\n {\n $project: {\n name: 1,\n restaurant_id: 1,\n borough: 1,\n cuisine: 1,\n consecutive_c_grades: 4,\n latest_score: { $arrayElemAt: ["$sortedGrades.score", 0] }\n }\n }\n])', + executionTimeMs: 2020, + result: [ + { + _id: "5eb3d669b31de5d588f45ff8", + borough: "Queens", + cuisine: "Chinese", + latest_score: 32, + name: "Red Chopstick", + restaurant_id: "41602559", + }, + ], + }, + tags: [], + metadata: { + complexity: "moderate", + language: "mongosh", + methods: ["aggregate"], + }, + }, + output: { + think: + "1. **Analyze the natural language query**: The user asks to 'Find restaurants that have received only C grades in their last 4 inspections with scores above 30.' Key ambiguous terms: 'last 4 inspections' and 'scores above 30'.\n\n2. **Examine the MongoDB query structure**: The query uses $addFields to sort grades by date descending and slice to get 4 most recent, then $match with complex conditions ensuring exactly 4 inspections exist and all have grade C with scores > 30, finally $project returns name, restaurant_id, borough, cuisine, consecutive_c_grades, and latest_score.\n\n3. 
**Compare natural language with MongoDB behavior**: The natural language doesn't specify that 'last 4' means chronologically recent (requires date sorting), nor that 'scores above 30' applies to ALL 4 inspections (not just some).\n\n4. **Identify output expectations**: The natural language doesn't specify what fields should be returned. The MongoDB query returns 6 specific fields, but this isn't clear from the natural language.\n\n5. **Look for specific ambiguities**: 'last 4 inspections' (temporal reference unclear), 'scores above 30' (filtering logic unclear - all vs some), missing output field specifications.\n\n6. **Consider alternative interpretations**: A user might interpret 'last 4' as array position rather than chronological order, or 'scores above 30' as applying to any/some scores rather than all.\n\n7. **Make your determination**: This is ambiguous because multiple interpretations exist for temporal ordering, filtering logic, and output structure.", + classification: "ambiguous", + rewrittenNaturalLanguageQuery: + "Find restaurants where the 4 most recent inspections (sorted by date) all received C grades and all had scores greater than 30. 
Return the restaurant name, id, borough, cuisine, the count of consecutive C grades (4), and the latest inspection score.", + } satisfies RewriteClassification, + }, + { + input: { + input: { + databaseName: "sample_restaurants", + nlQuery: + "For Chinese restaurants, show me the distribution of health grades (A, B, C) and count how many inspections received each grade.", + }, + expected: { + dbQuery: + 'db.restaurants.aggregate([\n {\n $match: {\n cuisine: "Chinese"\n }\n },\n {\n $unwind: "$grades"\n },\n {\n $group: {\n _id: "$grades.grade",\n count: { $sum: 1 }\n }\n },\n {\n $project: {\n _id: 0,\n grade: "$_id",\n count: 1\n }\n },\n {\n $sort: {\n grade: 1\n }\n }\n])', + executionTimeMs: 1255, + result: [ + { + count: 6587, + grade: "A", + }, + { + count: 1871, + grade: "B", + }, + { + count: 452, + grade: "C", + }, + { + count: 59, + grade: "Not Yet Graded", + }, + { + count: 212, + grade: "P", + }, + { + count: 168, + grade: "Z", + }, + ], + }, + tags: [], + metadata: { + complexity: "moderate", + language: "mongosh", + methods: ["aggregate"], + }, + }, + output: { + think: + "1. **Analyze the natural language query**: The user asks 'For Chinese restaurants, show me the distribution of health grades (A, B, C) and count how many inspections received each grade.' This is very clear: 'Chinese restaurants' (cuisine filter), 'distribution of health grades' (grouping), and 'count how many inspections received each grade' (aggregation).\n\n2. **Examine the MongoDB query structure**: The query uses $match to filter for Chinese cuisine, $unwind to separate individual grades, $group by grade to count occurrences, $project to format output with grade and count fields, and $sort by grade.\n\n3. **Compare natural language with MongoDB behavior**: The natural language directly specifies what the MongoDB query does - filter for Chinese restaurants, group by grade, and count inspections for each grade.\n\n4. 
**Identify output expectations**: The natural language clearly indicates the output should show each grade with its count. The MongoDB query returns exactly this: grade and count for each grade type.\n\n5. **Look for specific ambiguities**: No significant ambiguities - the query is straightforward about cuisine type, what to group by, and what to count.\n\n6. **Consider alternative interpretations**: The query is very direct with minimal room for alternative interpretations. A user would expect exactly what the MongoDB query produces: a list of grades with their counts.\n\n7. **Make your determination**: This is unambiguous because the natural language clearly specifies the filtering criteria (Chinese restaurants), the grouping method (by grade), the aggregation (count), and the expected output structure (grade and count).", + classification: "unambiguous", + }, + }, + { + input: { + input: { + databaseName: "sample_restaurants", + nlQuery: + "Show me the top 10 restaurants in Manhattan that have received 3 or more C grades in their inspection history", + }, + expected: { + dbQuery: + 'db.restaurants.aggregate([\n {\n $match: {\n borough: "Manhattan"\n }\n },\n {\n $addFields: {\n totalCGrades: {\n $size: {\n $filter: {\n input: "$grades",\n cond: { $eq: ["$$this.grade", "C"] }\n }\n }\n }\n }\n },\n {\n $match: {\n totalCGrades: { $gte: 3 }\n }\n },\n {\n $sort: {\n totalCGrades: -1\n }\n },\n {\n $limit: 10\n },\n {\n $project: {\n name: 1,\n address: {\n street: "$address.street",\n building: "$address.building"\n },\n totalCGrades: 1,\n cuisine: 1,\n _id: 0\n }\n }\n])', + executionTimeMs: 1454, + result: [ + { + address: { + building: "519", + street: "8 Avenue", + }, + cuisine: "American", + name: "Amici 36", + totalCGrades: 4, + }, + { + address: { + building: "158", + street: "West 58 Street", + }, + cuisine: "Pizza", + name: "Bella Vita", + totalCGrades: 4, + }, + { + address: { + building: "146", + street: "East 46 Street", + }, + cuisine: "Italian", + 
name: "Nanni Restaurant", + totalCGrades: 4, + }, + { + address: { + building: "45", + street: "East 45 Street", + }, + cuisine: "American", + name: "The Roosevelt Hotel", + totalCGrades: 3, + }, + { + address: { + building: "108", + street: "St Nicholas Avenue", + }, + cuisine: "American", + name: "Hai Cheng Restaurant", + totalCGrades: 3, + }, + ], + }, + tags: [], + metadata: { + complexity: "moderate", + language: "mongosh", + methods: ["aggregate"], + }, + }, + output: { + think: + "1. **Analyze the natural language query**: The user asks to 'Show me the top 10 restaurants in Manhattan that have received 3 or more C grades in their inspection history.' Key ambiguous terms: 'top 10' (ranking criteria unclear).\n\n2. **Examine the MongoDB query structure**: The query uses $match to filter for Manhattan, $addFields to count C grades, another $match for >= 3 C grades, $sort by totalCGrades descending, $limit 10, and $project to return name, address (street/building), totalCGrades, cuisine with _id excluded.\n\n3. **Compare natural language with MongoDB behavior**: The natural language says 'top 10' but doesn't specify the ranking criteria. The MongoDB query sorts by totalCGrades descending, meaning 'top' refers to highest number of C grades, but this isn't explicit in the natural language.\n\n4. **Identify output expectations**: The natural language doesn't specify what fields should be returned. The MongoDB query returns 4 specific fields (name, address structure, totalCGrades, cuisine) but this isn't clear from the natural language.\n\n5. **Look for specific ambiguities**: Vague ranking criteria ('top 10' without specifying metric), missing output field specifications.\n\n6. **Consider alternative interpretations**: A user might interpret 'top 10' as highest-rated restaurants, most popular, or other criteria rather than most C grades. They might also expect different output fields.\n\n7. 
**Make your determination**: This is ambiguous because 'top 10' doesn't specify the ranking criteria (though the MongoDB query clarifies it's by C grade count), and the output structure isn't specified in the natural language.", + classification: "ambiguous", + rewrittenNaturalLanguageQuery: + "Show me the 10 restaurants in Manhattan with the highest number of C grades in their inspection history (at least 3 C grades), sorted by total C grades descending. Return the restaurant name, address (street and building), total count of C grades, and cuisine type.", + }, + }, +]; + +export const systemPrompt = `You are an expert data labeler employed by MongoDB. You are given a candidate item for a dataset. The dataset will be used to evaluate how well large language models (LLMs) generate MongoDB code and also perform reinforcement learning (RL) to improve the quality of the generated code. The generated code will be evaluated against the expected output using a fuzzy matcher. + + +Each item in the dataset contains the following: + +- Natural language query +- MongoDB shell (mongosh) query +- Database output for the query, formatted in EJSON +- Metadata about the dataset item. + + + + +Your task is to perform the following: +1. Determine if the natural language query is appropriate for the dataset. Each natural language query must be unambiguously answerable by the relevant MongoDB code. (More on what it means for a query to be 'ambiguous' below in the section) +2. If the natural language query is _not_ unambiguously answerable by the relevant MongoDB code, propose an alternative natural language query that is more appropriate for the dataset. + + + + + +Things to look out for that make a natural language query ambiguous: + +1. It's not clear what exactly the values in the output should be. +2. It's not clear what exactly the structure of the output should be. +3. Vague ordering or ranking criteria (e.g., "top 10", "best", "highest" without specifying the ranking metric).
+4. Ambiguous time-based references. +5. Unclear filtering conditions. +6. Missing specification of what fields/data should be returned in the final output. **This is very common** - queries like "Find the top 10 restaurants" don't specify what information about those restaurants should be returned. +7. Ambiguous aggregation scope (e.g., unclear whether calculations should be per-restaurant, per-borough, etc.). + + + +If you need to rewrite a natural language query, keep the following in mind: + +1. The rewritten natural language query must be unambiguously answerable by the provided mongosh query. +2. The rewritten natural language query must not directly mention the mongosh query. +3. **No need to specify keys for aggregation queries**: For aggregation queries ( \`db.collection.aggregate(...)\`), the rewritten natural language query does not need to use the same keys in the output documents. The fuzzy matcher that is used to assess equality handles this case, as long as the values are the same. +4. **Explicitly specify what data should be returned**: Include clear descriptions of which fields should be in the output (e.g., "Return the restaurant name, address, and cuisine type"). +5. **Clarify sorting and ranking criteria**: Replace vague terms like "top 10" with specific criteria (e.g., "sorted by total C grades descending"). +6. **Resolve temporal ambiguities**: Clarify time-based references (e.g., "4 most recent inspections sorted by date"). +7. **Specify filtering logic**: Make clear whether conditions apply to all items or just some (e.g., "all 4 grades must be C" vs "at least 3 grades are C"). + +**Quality standards for rewrites**: +- The rewritten query should be natural and readable, not overly technical +- The rewritten query should be fully specified for the given Mongosh query and output. Include all necessary details but avoid unnecessary verbosity.
+- Use domain-appropriate terminology (e.g., "health inspection scores" rather than just "scores") +- Ensure the rewrite would lead any reasonable person to write the same MongoDB query. Query specificity is important for this, especially with respect to the \`$project\` stage. + + + +If you modify the original natural language query, maintain the writing style of the original query, but make it more clear and unambiguous. + + + +When evaluating whether a natural language query is unambiguous: + +1. **Compare the natural language query with the MongoDB query output structure**: Does the natural language clearly indicate what fields should be returned based on the \`$project\` stage? +2. **Check if sorting/ordering is clear**: If the MongoDB query includes \`$sort\`, does the natural language specify the sorting criteria? +3. **Verify filtering clarity**: Are all \`$match\` conditions clearly implied by the natural language? +4. **Assess temporal references**: Are time-based operations (like "most recent" or "last N") unambiguous? +5. **Consider domain context**: Use your knowledge of the relevant domain to determine if terms would be clear to a typical user. + +A query is **unambiguous** if someone reading it would produce the same MongoDB query and expect the same output structure. A query is **ambiguous** if multiple reasonable interpretations exist. + +**For borderline cases**: If you're unsure whether a query is ambiguous, err on the side of classifying it as ambiguous. It's better to have a clear, explicit query than to risk multiple interpretations. + +**Critical reminder**: Just because you can reasonably infer what the output should contain doesn't make a query unambiguous. If the natural language doesn't explicitly specify output fields, it should generally be classified as ambiguous.
+ +**Common patterns that should be classified as AMBIGUOUS:** +- "Find the top 10 restaurants" (doesn't specify what info about restaurants to return) +- "Show me neighborhoods with the most restaurants" (doesn't specify what fields to include) +- "Which planet has the highest temperature?" (doesn't specify if you want just the name, or name + temperature, or additional details) +- "Count customers by tier" (doesn't specify if you want just counts, or tier names + counts, or additional fields) + +**Patterns that are typically UNAMBIGUOUS:** +- "Return the restaurant name and address for the top 10 restaurants by rating" (explicit output specification) +- "Show the planet name and temperature for the planet with highest temperature" (clear output fields) +- "Count customers by tier, showing tier name and count" (explicit about what to return) + + + + +In the "think" field, you should systematically analyze the query by addressing these key points: + +1. **Analyze the natural language query**: What is the user asking for? Are there any vague or ambiguous terms? +2. **Examine the MongoDB query structure**: What does the MongoDB query actually do? Pay attention to: + - Filtering conditions (\`$match\` stages) + - Sorting criteria (\`$sort\` stages) + - Aggregation operations (\`$group\`, \`$addFields\`, etc.) + - Output structure (\`$project\` stage) + +3. **Compare natural language with MongoDB behavior**: Does the natural language clearly specify what the MongoDB query does? Are there gaps or ambiguities? +4. **Identify output expectations**: What fields should be returned? Does the natural language make this clear, or would a user be unsure about the output structure? **Be strict here**. If the natural language doesn't explicitly mention what data to return, it's likely ambiguous even if the output seems "obvious" from context. +5. 
**Look for specific ambiguities**: Check for: + - Vague ranking criteria ("top 10", "best") + - Unclear temporal references ("last 4", "recent") + - Ambiguous filtering logic ("scores above 30" - all or some?) + - Missing output field specifications +6. **Consider alternative interpretations**: Could a reasonable user interpret this query differently? Would they expect different output fields or structure? +7. **Make your determination**: Based on the above analysis, classify as "ambiguous" if there are multiple reasonable interpretations, or "unambiguous" if the query clearly specifies the expected MongoDB behavior and output. +8. **Style check**: (only for ambiguous queries to modify) If you modify the original natural language query, think about how to maintain the writing style of the original query, but make it more clear and unambiguous. +**Your reasoning should be specific and reference concrete aspects of both the natural language query and the MongoDB code.** + + + + + +Here are some examples to help you understand the task better: + +${fewShotExamples.map((example, idx) => { + return ` + + + +${JSON.stringify(example.input, null, 2)} + + + +${JSON.stringify(example.output, null, 2)} + + + +`; +})} + + + + +- Focus on whether the natural language query unambiguously specifies what the MongoDB query actually does +- Pay special attention to the \`$project\` stage or \`.project()\` method. Does the natural language make it clear what fields should be returned? +- Consider edge cases and alternative interpretations a user might have in the "think" field. 
+- Remember that the goal is to create a dataset where LLMs can unambiguously generate the correct MongoDB code +- Only include "rewrittenNaturalLanguageQuery" when classification is "ambiguous" +- **Be strict about output specifications**: Don't assume that "obvious" outputs make a query unambiguous - explicit is better than implicit + +`; + +export function makeRewriteNlQueryPrompt(model: LanguageModel) { + return async ( + datasetEntry: DatabaseNlQueryDatasetEntryBraintrust + ): Promise<{ + classification: RewriteClassification; + datasetEntry: DatabaseNlQueryDatasetEntryBraintrust; + }> => { + const { object: classification } = await generateObject({ + model, + schema: RewriteClassificationSchema, + messages: [ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: `Here is a dataset entry: + + +${JSON.stringify(datasetEntry, null, 2)} +`, + }, + ], + }); + const datasetEntryOut = { + ...datasetEntry, + input: { + ...datasetEntry.input, + nlQuery: + classification.rewrittenNaturalLanguageQuery ?? 
+ datasetEntry.input.nlQuery, + }, + } satisfies DatabaseNlQueryDatasetEntryBraintrust; + + return { + classification, + datasetEntry: datasetEntryOut, + }; + }; +} diff --git a/packages/datasets/src/treeGeneration/generateChildren.test.ts b/packages/datasets/src/treeGeneration/generateChildren.test.ts index 7299668c0..39ad91bdf 100644 --- a/packages/datasets/src/treeGeneration/generateChildren.test.ts +++ b/packages/datasets/src/treeGeneration/generateChildren.test.ts @@ -5,7 +5,7 @@ import { makeGenerateChildrenWithOpenAi, makeGenerateNChoiceChildrenWithOpenAi, } from "./generateChildren"; -import { LlmOptions } from "./databaseNlQueries/databaseNodes/LlmOptions"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; // Mock OpenAI client const mockOpenAIClient = { @@ -46,7 +46,6 @@ function createParentNode(): ParentNode { function createLlmOptions(): LlmOptions { return { - openAiClient: mockOpenAIClient, model: "gpt-4", temperature: 0, max_tokens: 1000, @@ -80,6 +79,7 @@ describe("makeGenerateChildrenWithOpenAi", () => { ParentNode, ChildNode >({ + openAiClient: mockOpenAIClient, makePromptMessages: async (parent) => [ { role: "system", content: "You are a test assistant" }, { @@ -114,6 +114,7 @@ describe("makeGenerateChildrenWithOpenAi", () => { ParentNode, ChildNode >({ + openAiClient: mockOpenAIClient, makePromptMessages: async (parent) => [ { role: "system", content: "You are a test assistant" }, { @@ -229,6 +230,7 @@ describe("makeGenerateNChoiceChildrenWithOpenAi", () => { ParentNode, ChildNode >({ + openAiClient: mockOpenAIClient, makePromptMessages: async (parent) => [ { role: "system", content: "You are a test assistant" }, { @@ -260,6 +262,7 @@ describe("makeGenerateNChoiceChildrenWithOpenAi", () => { ParentNode, ChildNode >({ + openAiClient: mockOpenAIClient, makePromptMessages: () => Promise.resolve([ { role: "system", content: "You are a test assistant" }, @@ -281,4 +284,8 @@ describe("makeGenerateNChoiceChildrenWithOpenAi", () => { 
expect(children).toHaveLength(1); expect(children[0].data.value).toBe(10); }); + + it("should use claude temperature variation for models with 'claude' in the name", () => { + // TODO: add + }); }); diff --git a/packages/datasets/src/treeGeneration/generateChildren.ts b/packages/datasets/src/treeGeneration/generateChildren.ts index ead70bd1a..236c6161b 100644 --- a/packages/datasets/src/treeGeneration/generateChildren.ts +++ b/packages/datasets/src/treeGeneration/generateChildren.ts @@ -4,7 +4,12 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { PromisePool } from "@supercharge/promise-pool"; import { ObjectId } from "mongodb-rag-core/mongodb"; import { GenerationNode, WithParentNode } from "./GenerationNode"; -import { LlmOptions } from "./databaseNlQueries/databaseNodes/LlmOptions"; +import { LlmOptions } from "mongodb-rag-core/executeCode"; + +export type GenerateChildrenLlmOptions = LlmOptions & { + __claudeMaxConcurrency?: number; + __claudeTemperatureVariation?: number; +}; export type GenerateChildren< ParentNode extends GenerationNode | null, @@ -14,7 +19,7 @@ export type GenerateChildren< > > = ( parent: ParentNode, - llmOptions: LlmOptions, + llmOptions: GenerateChildrenLlmOptions, numChildren: number ) => Promise; @@ -59,6 +64,7 @@ export interface MakeGenerateChildrenWithOpenAiParams< */ concurrency?: number; }; + openAiClient: OpenAI; /** Name for child type. @@ -77,6 +83,7 @@ export function makeGenerateChildrenWithOpenAi< response, filterNodes, childType, + openAiClient, }: MakeGenerateChildrenWithOpenAiParams< ParentNode, ChildNode @@ -91,7 +98,7 @@ export function makeGenerateChildrenWithOpenAi< const responseSchema = z.object({ items: z.array(response.schema), }); - const { openAiClient, ...clientConfig } = llmOptions; + const { ...clientConfig } = llmOptions; // Loop running tool calls to generate // exactly the correct amount of children. 
@@ -152,6 +159,7 @@ export function makeGenerateNChoiceChildrenWithOpenAi< response, filterNodes, childType, + openAiClient, }: Omit< MakeGenerateChildrenWithOpenAiParams, "response" @@ -159,20 +167,75 @@ export function makeGenerateNChoiceChildrenWithOpenAi< response: ResponseFunction; }): GenerateChildren { return async function generateNChoiceChildrenWithOpenAI( - parent: ParentNode, - llmOptions: LlmOptions, - numChildren: number + parent, + llmOptions, + numChildren ): Promise { const messages = await makePromptMessages(parent, numChildren); - const { openAiClient, ...clientConfig } = llmOptions; - - const completion = await getCompletions({ - openAiClient, - ...clientConfig, - messages, - response, - numCompletions: numChildren, - }); + const { ...clientConfig } = llmOptions; + let completion: OpenAI.Chat.Completions.ChatCompletion & { + _request_id?: string | null; + } = { + choices: [], + id: "", + created: 0, + model: "", + object: "chat.completion", + }; + if (llmOptions.model.includes("claude")) { + const defaultTemperatureVariation = 0.01; + const defaultMaxConcurrency = 1; + const defaultTemperature = llmOptions.temperature ?? 0.5; + const variedTemperatures = Array.from({ length: numChildren }).map( + (_, index) => { + // Offset by the temperature variation in a symmetrical manner + // E.g. if temperature is .5 and variation is .01, + // then the temperatures should be .5, .49, .51, .48, .52, etc. + // This is so we can still use the Braintrust cache. + const isEven = index % 2 === 0; + const offset = Math.floor(index / 2) + (isEven ? 0 : 1); + const claudeVariationOffset = + llmOptions.__claudeTemperatureVariation ?? + defaultTemperatureVariation; + const variation = isEven + ? -offset * claudeVariationOffset + : offset * claudeVariationOffset; + return Math.round((variation + defaultTemperature) * 100) / 100; + } + ); + const { results: choices } = await PromisePool.for(variedTemperatures) + .withConcurrency( + llmOptions.__claudeMaxConcurrency ?? 
defaultMaxConcurrency + ) + .handleError((error) => { + console.error("Error generating children", error); + }) + .process(async (variedTemperature) => { + clientConfig.temperature = variedTemperature; + const { + choices: [choice], + } = await getCompletions({ + openAiClient, + ...clientConfig, + messages, + response, + numCompletions: numChildren, + }); + return choice; + }); + completion.choices = choices; + } + // For other models where we can generate all completions at once + // using the N candidates generation pattern (see https://community.openai.com/t/how-does-n-parameter-work-in-chat-completions/288725/2) + else { + completion = await getCompletions({ + openAiClient, + ...clientConfig, + messages, + response, + numCompletions: numChildren, + }); + } let children: ChildNode["data"][] = completion.choices .map((choice) => choice.message.tool_calls?.[0].function.arguments) @@ -192,11 +255,14 @@ async function getCompletions({ messages, response, numCompletions, + __claudeMaxConcurrency, + __claudeTemperatureVariation, ...clientConfig -}: LlmOptions & { +}: GenerateChildrenLlmOptions & { messages: OpenAI.ChatCompletionMessageParam[]; numCompletions?: number; response: ResponseFunction; + openAiClient: OpenAI; }) { const completion = await openAiClient.chat.completions.create({ ...clientConfig, diff --git a/packages/datasets/src/treeGeneration/mongoDbDatabases/README.md b/packages/datasets/src/treeGeneration/mongoDbDatabases/README.md new file mode 100644 index 000000000..edc041321 --- /dev/null +++ b/packages/datasets/src/treeGeneration/mongoDbDatabases/README.md @@ -0,0 +1,16 @@ +# Generate MongoDB Databases + +This directory contains the logic for generating MongoDB databases. + +The generation pipeline should use the following flow: + +1. create a business + its use case + - maybe we help out by providing sector or something in prompt + - can get inspiration from https://www.mongodb.com/solutions/customer-case-studies or other docs. 
doing napkin math i see ~300 case studies +2. Generate schema + indexes + - prompt includes condensed version of schema design docs to push for best practices +3. Generate zod types to validate collection document schemas + - want to have this to guarantee consistency of any documents we input +4. have AI write script to input data into schema + - i think this is the hardest part + - i think using the Claude Code TypeScript SDK (https://docs.anthropic.com/en/docs/claude-code/sdk#typescript) could be promising here \ No newline at end of file diff --git a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/LlmOptions.ts b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/LlmOptions.ts index 61d03ae4e..e86269ca5 100644 --- a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/LlmOptions.ts +++ b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/LlmOptions.ts @@ -1,8 +1,6 @@ import { OpenAI } from "openai"; -export type LlmOptions = { - openAiClient: OpenAI; -} & Omit< +export type LlmOptions = Omit< OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming, "messages" | "tools" | "tool_choice" >; diff --git a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedCollectionSchema.ts b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedCollectionSchema.ts index 6f3435ad5..6bd85334f 100644 --- a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedCollectionSchema.ts +++ b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedCollectionSchema.ts @@ -4,6 +4,7 @@ import { z } from "zod"; import { getOpenAiFunctionResponse } from "./getOpenAiFunctionResponse"; import { prettyPrintMongoDbDocument } from "../prettyPrintMongoDbDocument"; import { LlmOptions } from "./LlmOptions"; +import { wrapTraced } from "braintrust"; const systemPrompt = `You are an expert MongoDB database architect. 
Your task is to analyze the provided database metadata and generate clear, concise descriptions and an annotated schema for the specified collection. The descriptions that you generate will be used in the prompt of a LLM for performing database-related tasks. @@ -57,36 +58,43 @@ interface GenerateAnnotatedCollectionSchemaParams { llm: LlmOptions; } -export async function generateAnnotatedCollectionSchema({ - collectionMetadata, - databaseMetadata, - llm: llmOptions, -}: GenerateAnnotatedCollectionSchemaParams): Promise { - const messages = [ - { - role: "system", - content: systemPrompt, - }, - { - role: "user", - content: `Analyze the following collection: '${ - collectionMetadata.collectionName - }'. +export const makeGenerateAnnotatedCollectionSchema = (openAiClient: OpenAI) => + wrapTraced( + async function generateAnnotatedCollectionSchema({ + collectionMetadata, + databaseMetadata, + llm: llmOptions, + }: GenerateAnnotatedCollectionSchemaParams): Promise { + const messages = [ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: `Analyze the following collection: '${ + collectionMetadata.collectionName + }'. 
Database metadata: ${prettyPrintMongoDbDocument(databaseMetadata)} Again, analyze the collection named '${collectionMetadata.collectionName}'.`, - }, - ] satisfies OpenAI.ChatCompletionMessageParam[]; + }, + ] satisfies OpenAI.ChatCompletionMessageParam[]; - const result = await getOpenAiFunctionResponse({ - messages, - llmOptions, - schema: makeDetailedCollectionDescriptionSchema(collectionMetadata), - functionName, - functionDescription, - }); + const result = await getOpenAiFunctionResponse({ + messages, + llmOptions, + schema: makeDetailedCollectionDescriptionSchema(collectionMetadata), + functionName, + functionDescription, + openAiClient, + }); - return result; -} + return result; + }, + { + name: "generateAnnotatedCollectionSchema", + } + ); diff --git a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedDatabaseInfo.ts b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedDatabaseInfo.ts index 6bd8573d0..3edbec668 100644 --- a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedDatabaseInfo.ts +++ b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateAnnotatedDatabaseInfo.ts @@ -1,9 +1,11 @@ import { MongoClient } from "mongodb"; import { z } from "zod"; import { LlmOptions } from "./LlmOptions"; -import { generateAnnotatedCollectionSchema } from "./generateAnnotatedCollectionSchema"; -import { generateHighLevelDbDescriptions } from "./generateHighLevelDbDescriptions"; +import { makeGenerateAnnotatedCollectionSchema } from "./generateAnnotatedCollectionSchema"; +import { makeGenerateHighLevelDbDescriptions } from "./generateHighLevelDbDescriptions"; import { getDatabaseMetadata } from "./getDatabaseMetadata"; +import { traced } from "braintrust"; +import OpenAI from "openai"; export const DatabaseInfoSchema = z.object({ name: z.string().describe("Name of the database"), @@ -41,7 +43,8 @@ export interface GenerateAnnotatedDatabaseInfoParams { numSamplesPerCollection?: 
number; }; latestDate?: Date; - llm: LlmOptions; + llmOptions: LlmOptions; + openAiClient: OpenAI; } /** @@ -50,7 +53,8 @@ export interface GenerateAnnotatedDatabaseInfoParams { export async function generateAnnotatedDatabaseInfo({ mongoDb: { mongoClient, databaseName, numSamplesPerCollection = 2 }, latestDate = new Date(), - llm, + llmOptions, + openAiClient, }: GenerateAnnotatedDatabaseInfoParams): Promise { // Get raw database metadata const databaseMetadata = await getDatabaseMetadata({ @@ -59,61 +63,70 @@ export async function generateAnnotatedDatabaseInfo({ numSamplesPerCollection, latestDate, }); + const generateHighLevelDbDescriptions = + makeGenerateHighLevelDbDescriptions(openAiClient); + const generateAnnotatedCollectionSchema = + makeGenerateAnnotatedCollectionSchema(openAiClient); + return traced( + async () => { + // Generate high-level database descriptions + const highLevelDescriptions = await generateHighLevelDbDescriptions( + databaseMetadata, + llmOptions + ); - // Generate high-level database descriptions - const highLevelDescriptions = await generateHighLevelDbDescriptions( - databaseMetadata, - llm - ); + // Create initial annotated database info + const annotatedDatabaseInfo: DatabaseInfo = { + name: databaseName, + description: highLevelDescriptions.databaseDescription, + latestDate, + collections: databaseMetadata.collections.map((collection, i) => ({ + name: collection.collectionName, + description: + highLevelDescriptions.collectionDescriptions[i].description, + schema: collection.schema, + examples: collection.exampleDocuments, + indexes: collection.indexes, + })), + }; - // Create initial annotated database info - const annotatedDatabaseInfo: DatabaseInfo = { - name: databaseName, - description: highLevelDescriptions.databaseDescription, - latestDate, - collections: databaseMetadata.collections.map((collection, i) => ({ - name: collection.collectionName, - description: highLevelDescriptions.collectionDescriptions[i].description, - 
schema: collection.schema, - examples: collection.exampleDocuments, - indexes: collection.indexes, - })), - }; + // Generate detailed schema descriptions for each collection + for (let i = 0; i < annotatedDatabaseInfo.collections.length; i++) { + const annotatedCollection = annotatedDatabaseInfo.collections[i]; + const collection = databaseMetadata.collections.find( + (c) => c.collectionName === annotatedCollection.name + ); + if (!collection) { + continue; + } - // Generate detailed schema descriptions for each collection - for (let i = 0; i < annotatedDatabaseInfo.collections.length; i++) { - const annotatedCollection = annotatedDatabaseInfo.collections[i]; - const collection = databaseMetadata.collections.find( - (c) => c.collectionName === annotatedCollection.name - ); - if (!collection) { - continue; - } + const { typeScriptSchema, indexDescriptions } = + await generateAnnotatedCollectionSchema({ + collectionMetadata: collection, + databaseMetadata, + llm: llmOptions, + }); - const { typeScriptSchema, indexDescriptions } = - await generateAnnotatedCollectionSchema({ - collectionMetadata: collection, - databaseMetadata, - llm, - }); + // Update the collection's schema with the annotated version + annotatedCollection.schema = typeScriptSchema; - // Update the collection's schema with the annotated version - annotatedCollection.schema = typeScriptSchema; + // Update the collection's indexes with the annotated version + for (let j = 0; j < indexDescriptions.length; j++) { + const indexDescription = indexDescriptions[j]; - // Update the collection's indexes with the annotated version - for (let j = 0; j < indexDescriptions.length; j++) { - const indexDescription = indexDescriptions[j]; + const collectionIndexDescription = annotatedCollection.indexes.find( + (index) => index.name === indexDescription.name + ); - const collectionIndexDescription = annotatedCollection.indexes.find( - (index) => index.name === indexDescription.name - ); + if 
(!collectionIndexDescription) { + continue; + } - if (!collectionIndexDescription) { - continue; + collectionIndexDescription.description = indexDescription.description; + } } - - collectionIndexDescription.description = indexDescription.description; - } - } - return annotatedDatabaseInfo; + return annotatedDatabaseInfo; + }, + { name: "generateAnnotatedDatabaseInfo" } + ); } diff --git a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateHighLevelDbDescriptions.ts b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateHighLevelDbDescriptions.ts index c10e7f8e0..a88711a4b 100644 --- a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateHighLevelDbDescriptions.ts +++ b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/generateHighLevelDbDescriptions.ts @@ -4,6 +4,7 @@ import { LlmOptions } from "./LlmOptions"; import { getOpenAiFunctionResponse } from "./getOpenAiFunctionResponse"; import { prettyPrintMongoDbDocument } from "../prettyPrintMongoDbDocument"; import { OpenAI } from "openai"; +import { wrapTraced } from "braintrust"; const systemPrompt = `You are an expert MongoDB database architect. Your task is to analyze the provided database metadata and generate clear, concise descriptions. @@ -51,29 +52,36 @@ function createHighLevelDbDescriptionsSchema( /** Get high-level descriptions of the database and its collections. 
*/ -export async function generateHighLevelDbDescriptions( - databaseMetadata: DatabaseMetadata, - llmOptions: LlmOptions -) { - const schema = createHighLevelDbDescriptionsSchema(databaseMetadata); +export const makeGenerateHighLevelDbDescriptions = (openAiClient: OpenAI) => + wrapTraced( + async function ( + databaseMetadata: DatabaseMetadata, + llmOptions: LlmOptions + ) { + const schema = createHighLevelDbDescriptionsSchema(databaseMetadata); - const messages = [ - { - role: "system", - content: systemPrompt, - }, - { - role: "user", - content: `Database information: + const messages = [ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: `Database information: ${prettyPrintMongoDbDocument(databaseMetadata)}`, + }, + ] satisfies OpenAI.ChatCompletionMessageParam[]; + return await getOpenAiFunctionResponse({ + messages, + llmOptions, + schema, + functionName, + functionDescription: + "Generate high-level descriptions of the database and its collections based on the provided metadata", + openAiClient, + }); }, - ] satisfies OpenAI.ChatCompletionMessageParam[]; - return await getOpenAiFunctionResponse({ - messages, - llmOptions, - schema, - functionName, - functionDescription: - "Generate high-level descriptions of the database and its collections based on the provided metadata", - }); -} + { + name: "generateHighLevelDbDescriptions", + } + ); diff --git a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/getOpenAiFunctionResponse.ts b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/getOpenAiFunctionResponse.ts index eebe7b2f3..59f089ebb 100644 --- a/packages/mongodb-rag-core/src/executeCode/databaseMetadata/getOpenAiFunctionResponse.ts +++ b/packages/mongodb-rag-core/src/executeCode/databaseMetadata/getOpenAiFunctionResponse.ts @@ -9,6 +9,7 @@ interface GetOpenAiFunctionResponseParams { schema: Schema; functionName: string; functionDescription?: string; + openAiClient: OpenAI; } export async function 
getOpenAiFunctionResponse({ @@ -17,11 +18,12 @@ export async function getOpenAiFunctionResponse({ schema, functionName, functionDescription, + openAiClient, }: GetOpenAiFunctionResponseParams): Promise> { const parameters = zodToJsonSchema(schema, { $refStrategy: "none", }); - const { openAiClient, ...createChatCompletionParams } = llmOptions; + const { ...createChatCompletionParams } = llmOptions; const res = await openAiClient.chat.completions.create({ messages, ...createChatCompletionParams, diff --git a/packages/mongodb-rag-core/src/models/models.ts b/packages/mongodb-rag-core/src/models/models.ts index 5b4af5dc9..630da1d6e 100644 --- a/packages/mongodb-rag-core/src/models/models.ts +++ b/packages/mongodb-rag-core/src/models/models.ts @@ -235,7 +235,7 @@ const allModels = [ deployment: "us.anthropic.claude-sonnet-4-20250514-v1:0", provider: "braintrust", developer: "Anthropic", - maxConcurrency: 10, + maxConcurrency: 5, authorized: true, }, { @@ -243,7 +243,23 @@ const allModels = [ deployment: "us.anthropic.claude-opus-4-20250514-v1:0", provider: "braintrust", developer: "Anthropic", - maxConcurrency: 10, + maxConcurrency: 5, + authorized: true, + }, + { + label: "anthropic/claude-sonnet-4", + deployment: "claude-sonnet-4-20250514", + provider: "braintrust", + developer: "Anthropic", + maxConcurrency: 5, + authorized: true, + }, + { + label: "anthropic/claude-opus-4", + deployment: "claude-opus-4-20250514", + provider: "braintrust", + developer: "Anthropic", + maxConcurrency: 5, authorized: true, }, { diff --git a/packages/scripts/src/profound/getAndProcessAnswers.ts b/packages/scripts/src/profound/getAndProcessAnswers.ts index 5cccc08c5..7ad60e57e 100644 --- a/packages/scripts/src/profound/getAndProcessAnswers.ts +++ b/packages/scripts/src/profound/getAndProcessAnswers.ts @@ -149,13 +149,13 @@ export const main = async (startDateArg?: string, endDateArg?: string) => { // get reference alignment scores for answers const endpointAndKey = await 
getOpenAiEndpointAndApiKey(model); + const openAiClient = new OpenAI(endpointAndKey); const config = { - openAiClient: new OpenAI(endpointAndKey), model: model.deployment, temperature: 0, label: model.label, }; - const referenceAlignmentFn = makeReferenceAlignment(config); + const referenceAlignmentFn = makeReferenceAlignment(openAiClient, config); const answerRecords: any[] = []; const { results, errors } = await PromisePool.for(answers) .withConcurrency(model.maxConcurrency ?? 5)