Claude NL to MQL #809

Open
wants to merge 17 commits into base: main
17 changes: 11 additions & 6 deletions packages/benchmarks/src/nlPromptResponse/metrics.ts
@@ -2,11 +2,14 @@ import { NlPromptResponseEvalScorer } from "./NlQuestionAnswerEval";
import { Factuality, Score, AnswerCorrectness } from "autoevals";
import { strict as assert } from "assert";
import { LlmOptions } from "mongodb-rag-core/executeCode";
import { OpenAI } from "mongodb-rag-core/openai";
import { openAi } from "mongodb-rag-core/langchain";

export const makeReferenceAlignment: (
openAiClient: OpenAI,
llmOptions: LlmOptions,
name_postfix?: string
) => NlPromptResponseEvalScorer = (llmOptions, name_postfix) =>
) => NlPromptResponseEvalScorer = (openAiClient, llmOptions, name_postfix) =>
async function ({ input, output, expected }) {
const { response } = output;
const { reference } = expected;
@@ -29,7 +32,7 @@ export const makeReferenceAlignment: (
// Note: need to do the funky casting here
// b/c of different `OpenAI` client typing
// that is not relevant here.
client: llmOptions.openAiClient as unknown as Parameters<
client: openAiClient as unknown as Parameters<
typeof Factuality
>[0]["client"],
model: llmOptions.model,
@@ -62,9 +65,10 @@ function inflateFactualityScore(score: number | null | undefined) {
}

export const makeAnswerCorrectness: (
openAiClient: OpenAI,
llmOptions: LlmOptions,
name_postfix?: string
) => NlPromptResponseEvalScorer = (llmOptions, name_postfix) =>
) => NlPromptResponseEvalScorer = (openAiClient, llmOptions, name_postfix) =>
async function ({ input, output, expected }) {
const { response } = output;
const { reference } = expected;
@@ -89,7 +93,7 @@ export const makeAnswerCorrectness: (
// b/c of different `OpenAI` client typing
// that is not relevant here.

client: llmOptions.openAiClient as unknown as Parameters<
client: openAiClient as unknown as Parameters<
typeof Factuality
>[0]["client"],
model: llmOptions.model,
@@ -101,11 +105,12 @@
};

export const makeReferenceAlignmentCouncil: (
openAiClient: OpenAI,
llmOptions: LlmOptions[]
) => NlPromptResponseEvalScorer = (llmOptions) => {
) => NlPromptResponseEvalScorer = (openAiClient, llmOptions) => {
assert(llmOptions.length > 0, "At least one LLM must be provided");
const factualityMetrics = llmOptions.map((llmOption) =>
makeReferenceAlignment(llmOption)
makeReferenceAlignment(openAiClient, llmOption)
);
return async function ({ input, output, expected }) {
const name = "ReferenceAlignmentCouncil";
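The scorer factories in metrics.ts now take the OpenAI client as an explicit first argument instead of reading it from `llmOptions`. A minimal calling sketch under the new signature; the client construction, model name, and import path are assumptions for illustration, not taken from this PR:

import { OpenAI } from "mongodb-rag-core/openai";
import { makeReferenceAlignment } from "./metrics"; // illustrative path

// One shared judge client; LlmOptions no longer carry openAiClient.
const openAiClient = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

const referenceAlignment = makeReferenceAlignment(
  openAiClient,                        // client is now a separate argument
  { model: "gpt-4o", temperature: 0 }, // plain LlmOptions
  "gpt-4o"                             // optional name postfix for the scorer
);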
@@ -4,16 +4,18 @@ import { NlPromptResponseEvalTask } from "./NlQuestionAnswerEval";
import { OpenAI } from "mongodb-rag-core/openai";

interface MakeNlPromptCompletionTaskParams {
openAiClient: OpenAI;
llmOptions: LlmOptions;
initialMessages?: OpenAI.Chat.ChatCompletionMessageParam[];
}

export function makeNlPromptCompletionTask({
openAiClient,
llmOptions,
initialMessages,
}: MakeNlPromptCompletionTaskParams): NlPromptResponseEvalTask {
return async function (input) {
const { openAiClient, ...llmConfig } = llmOptions;
const { ...llmConfig } = llmOptions;
const res = await openAiClient.chat.completions.create({
messages: [...(initialMessages ?? []), ...input.messages],
stream: false,
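The prompt-completion task factory gets the same split: the client is a dedicated parameter and `llmOptions` holds only completion settings. A sketch of the new call shape, with illustrative values:

// openAiClient: an OpenAI instance constructed elsewhere (see the sketch above).
const task = makeNlPromptCompletionTask({
  openAiClient,                                    // passed explicitly
  llmOptions: { model: "gpt-4o", temperature: 0 }, // no embedded client
  initialMessages: [
    { role: "system", content: "Answer questions about MongoDB." },
  ],
});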
@@ -60,21 +60,24 @@ export async function runNlPromptResponseBenchmark({
"At least one judge model must be configured in 'judgeModelsConfig'. Check your model labels in 'globalConfig.ts'."
);

const judgeClients = await Promise.all(
const judgeConfigs = await Promise.all(
judgeModelsConfig.map(async (m) => {
const endpointAndKey = await getOpenAiEndpointAndApiKey(m);
console.log(`Judge model: ${m.label}`);
return {
openAiClient: new OpenAI(endpointAndKey),
model: m.deployment,
temperature: 0,
label: m.label,
config: {
model: m.deployment,
temperature: 0,
label: m.label,
},
};
})
);
const judgeMetrics = judgeClients.map(({ label, ...config }) =>
makeReferenceAlignment(config, label)
);
const judgeMetrics = judgeConfigs.map(({ openAiClient, config }) => {
const { label, ...llmOptions } = config;
return makeReferenceAlignment(openAiClient, llmOptions, label);
});

await PromisePool.for(models)
.withConcurrency(maxConcurrentExperiments)
@@ -121,6 +124,7 @@ export async function runNlPromptResponseBenchmark({
...staticLlmOptions,
},
task: makeNlPromptCompletionTask({
openAiClient,
llmOptions,
initialMessages: [systemMessage],
}),
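After this refactor each judge entry pairs a client with a plain config object, and the label is peeled off before the remaining fields are forwarded as `LlmOptions`. Roughly, one entry of `judgeConfigs` has this shape (the type names here are illustrative):

import { OpenAI } from "mongodb-rag-core/openai";

interface JudgeConfig {
  openAiClient: OpenAI;
  config: {
    model: string;       // m.deployment
    temperature: number; // fixed at 0 for judging
    label: string;       // stripped off and used as the scorer name postfix
  };
}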
13 changes: 8 additions & 5 deletions packages/benchmarks/src/test/makeSampleLlmOptions.ts
@@ -2,14 +2,17 @@ import { assertEnvVars, BRAINTRUST_ENV_VARS } from "mongodb-rag-core";
import { LlmOptions } from "mongodb-rag-core/executeCode";
import { OpenAI } from "mongodb-rag-core/openai";

export const makeSampleLlmOptions = () => {
export const makeSampleLlm = () => {
const { BRAINTRUST_API_KEY, BRAINTRUST_ENDPOINT } =
assertEnvVars(BRAINTRUST_ENV_VARS);
return new OpenAI({
apiKey: BRAINTRUST_API_KEY,
baseURL: BRAINTRUST_ENDPOINT,
});
};

export const makeSampleLlmOptions = () => {
return {
openAiClient: new OpenAI({
apiKey: BRAINTRUST_API_KEY,
baseURL: BRAINTRUST_ENDPOINT,
}),
model: "gpt-4o",
temperature: 0,
} satisfies LlmOptions;
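Tests that previously got a bundled client from `makeSampleLlmOptions` now build the client and the options separately. A sketch of how the two helpers combine; the scorer import path is an assumption for illustration:

import { makeSampleLlm, makeSampleLlmOptions } from "./makeSampleLlmOptions";
import { makeReferenceAlignment } from "../nlPromptResponse/metrics"; // illustrative path

const openAiClient = makeSampleLlm();      // OpenAI client pointed at the Braintrust proxy
const llmOptions = makeSampleLlmOptions(); // { model: "gpt-4o", temperature: 0 }

const scorer = makeReferenceAlignment(openAiClient, llmOptions);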
@@ -0,0 +1,91 @@
import { assertEnvVars, BRAINTRUST_ENV_VARS } from "mongodb-rag-core";
import { TEXT_TO_DRIVER_ENV_VARS } from "../../../TextToDriverEnvVars";
import { ModelConfig, models } from "mongodb-rag-core/models";
import { strict as assert } from "assert";
import { LlmOptions } from "mongodb-rag-core/executeCode";
import {
SchemaStrategy,
SystemPromptStrategy,
} from "../../../generateDriverCode/languagePrompts/PromptStrategies";

export const DATASET_NAME = "atlas-sample-dataset-claude";

export const PROJECT_NAME = "natural-language-to-mongosh";

export const EXPERIMENT_BASE_NAME = "mongosh-benchmark-claude";

export const {
BRAINTRUST_API_KEY,
BRAINTRUST_ENDPOINT,
MONGODB_TEXT_TO_DRIVER_CONNECTION_URI,
} = assertEnvVars({
...TEXT_TO_DRIVER_ENV_VARS,
...BRAINTRUST_ENV_VARS,
});

export const schemaStrategies: SchemaStrategy[] = [
"annotated",
"interpreted",
"none",
];
export const systemPromptStrategies: SystemPromptStrategy[] = [
"default",
"chainOfThought",
"lazy",
];

export const experimentTypes = [
"toolCall",
"agentic",
"promptCompletion",
] as const;

export const fewShot = [true, false];

export interface Experiment {
model: (typeof MODELS)[number];
schemaStrategy: SchemaStrategy;
systemPromptStrategy: SystemPromptStrategy;
type: (typeof experimentTypes)[number];
fewShot?: boolean;
}

export const MAX_CONCURRENT_EXPERIMENTS = 2;

export const MAX_CONCURRENT_MODELS = 2;

export const MODELS: ModelConfig[] = (
[
// benchmark models
// "gpt-4.1-mini",
// "gemini-2.5-flash-preview-05-20",
// "gemini-2.5-pro-preview-05-06",
// "gpt-4.1-nano",
// "gpt-4.1",
// "o3",
// "o4-mini",
"anthropic/claude-opus-4",
"anthropic/claude-sonnet-4",
// "claude-opus-4",
// "claude-sonnet-4",
] satisfies (typeof models)[number]["label"][]
).map((label) => {
const model = models.find((m) => m.label === label);
assert(model, `Model ${label} not found`);
return model;
});

export function makeLlmOptions(model: ModelConfig): LlmOptions {
// o3-mini takes slightly different options: max_completion_tokens instead of max_tokens, and no temperature override
if (model.label === "o3-mini") {
return {
model: model.deployment,
max_completion_tokens: 3000,
};
}
return {
model: model.deployment,
temperature: 0,
max_tokens: 3000,
};
}
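For reference, the exported strategy arrays, experiment types, and `Experiment` interface can be combined into the full experiment grid as sketched below; the runner in the next file exercises only one point of that grid (annotated / default / promptCompletion). This enumeration is illustrative and not part of the PR:

import {
  MODELS,
  schemaStrategies,
  systemPromptStrategies,
  experimentTypes,
  Experiment,
} from "./config";

const grid: Experiment[] = MODELS.flatMap((model) =>
  schemaStrategies.flatMap((schemaStrategy) =>
    systemPromptStrategies.flatMap((systemPromptStrategy) =>
      experimentTypes.map((type) => ({
        model,
        schemaStrategy,
        systemPromptStrategy,
        type,
      }))
    )
  )
);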
@@ -0,0 +1,86 @@
import { makeTextToDriverEval } from "../../../TextToDriverEval";
import { loadTextToDriverBraintrustEvalCases } from "../../../loadBraintrustDatasets";
import {
ReasonableOutput,
SuccessfulExecution,
} from "../../../evaluationMetrics";
import { annotatedDbSchemas } from "../../../generateDriverCode/annotatedDbSchemas";
import { createOpenAI } from "@ai-sdk/openai";
import { wrapAISDKModel } from "mongodb-rag-core/braintrust";
import {
BRAINTRUST_API_KEY,
DATASET_NAME,
PROJECT_NAME,
MONGODB_TEXT_TO_DRIVER_CONNECTION_URI,
makeLlmOptions,
MAX_CONCURRENT_EXPERIMENTS,
MODELS,
EXPERIMENT_BASE_NAME,
Experiment,
} from "./config";
import PromisePool from "@supercharge/promise-pool";
import { makeGenerateMongoshCodePromptCompletionTask } from "../../../generateDriverCode/generateMongoshCodePromptCompletion";
import { getOpenAiEndpointAndApiKey } from "mongodb-rag-core/models";
import { makeExperimentName } from "../../../makeExperimentName";

async function main() {
await PromisePool.for(MODELS)
.withConcurrency(MAX_CONCURRENT_EXPERIMENTS)
.handleError((error, model) => {
console.error(
`Error running experiment for model ${model.label}:`,
JSON.stringify(error)
);
})
.process(async (model) => {
const llmOptions = makeLlmOptions(model);
const experiment: Experiment = {
model,
schemaStrategy: "annotated",
systemPromptStrategy: "default",
type: "promptCompletion",
};
const experimentName = makeExperimentName({
baseName: EXPERIMENT_BASE_NAME,
experimentType: experiment.type,
model: model.label,
systemPromptStrategy: experiment.systemPromptStrategy,
schemaStrategy: experiment.schemaStrategy,
});
console.log(`Running experiment: ${experimentName}`);

await makeTextToDriverEval({
apiKey: BRAINTRUST_API_KEY,
projectName: PROJECT_NAME,
experimentName,
data: loadTextToDriverBraintrustEvalCases({
apiKey: BRAINTRUST_API_KEY,
projectName: PROJECT_NAME,
datasetName: DATASET_NAME,
}),
maxConcurrency: model.maxConcurrency,

task: makeGenerateMongoshCodePromptCompletionTask({
uri: MONGODB_TEXT_TO_DRIVER_CONNECTION_URI,
databaseInfos: annotatedDbSchemas,
llmOptions,
systemPromptStrategy: experiment.systemPromptStrategy,
openai: wrapAISDKModel(
createOpenAI({
...(await getOpenAiEndpointAndApiKey(model)),
}).chat(llmOptions.model, {
structuredOutputs: true,
})
),
schemaStrategy: experiment.schemaStrategy,
}),
metadata: {
llmOptions,
...experiment,
},
scores: [SuccessfulExecution, ReasonableOutput],
});
});
}

main();
@@ -179,7 +179,12 @@ export async function loadTextToDriverBraintrustEvalCases({
apiKey,
dataset: datasetName,
}).fetchedData()
).map((d) => TextToDriverEvalCaseSchema.parse(d));
).map((d) => {
if (d.tags === null) {
d.tags = undefined;
}
return TextToDriverEvalCaseSchema.parse(d);
});

return dataset;
}
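The null-to-undefined coercion matters because Zod's `.optional()` admits `undefined` but rejects `null`, and fetched Braintrust rows can carry `tags: null`. A minimal illustration, assuming the eval-case schema declares `tags` as an optional string array:

import { z } from "zod";

const Tags = z.array(z.string()).optional();

Tags.parse(undefined); // ok — optional fields accept undefined
Tags.parse(null);      // throws ZodError — optional() does not admit null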