Use AI-Mask within Worker #19

Open · wants to merge 9 commits into base: main
Showing changes from all commits
127 changes: 127 additions & 0 deletions app/lib/chat_models/ai_mask.ts
@@ -0,0 +1,127 @@
import {
SimpleChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
import { BaseMessage, AIMessageChunk } from "@langchain/core/messages";
import { AIMaskClient, ChatCompletionMessageParam } from "@ai-mask/sdk";
import { ChatGenerationChunk } from "@langchain/core/outputs";

export interface AIMaskInputs extends BaseChatModelParams {
modelId: string;
temperature?: number;
aiMaskClient?: AIMaskClient;
appName?: string;
}

export interface AIMaskCallOptions extends BaseLanguageModelCallOptions {}

function convertMessages(
messages: BaseMessage[],
): ChatCompletionMessageParam[] {
return messages.map((message) => {
let role: ChatCompletionMessageParam["role"],
content: ChatCompletionMessageParam["content"];
if (message._getType() === "human") {
role = "user";
} else if (message._getType() === "ai") {
role = "assistant";
} else if (message._getType() === "system") {
role = "system";
} else {
throw new Error(
`Unsupported message type for AIMask: ${message._getType()}`,
);
}
if (typeof message.content === "string") {
content = message.content;
} else {
throw new Error("unsupported content type");
}
return { role, content };
});
}

/**
* @example
* ```typescript
 * // Initialize the ChatAIMask model with the ID of the model to run.
* const model = new ChatAIMask({
* modelId: "Mistral-7B-Instruct-v0.2-q4f16_1",
* });
*
* // Call the model with a message and await the response.
* const response = await model.call([
* new HumanMessage({ content: "My name is John." }),
* ]);
*
* // Log the response to the console.
* console.log({ response });
*
* ```
*/
export class ChatAIMask extends SimpleChatModel<AIMaskCallOptions> {
static inputs: AIMaskInputs;

protected _aiMaskClient: AIMaskClient;

modelId: string;
temperature?: number;

static lc_name() {
return "ChatAIMask";
}

constructor(inputs: AIMaskInputs) {
super(inputs);

this._aiMaskClient =
inputs?.aiMaskClient ?? new AIMaskClient({ name: inputs?.appName });

this.modelId = inputs.modelId;
this.temperature = inputs.temperature;
}

_llmType() {
return "ai-mask";
}

async *_streamResponseChunks(
messages: BaseMessage[],
): AsyncGenerator<ChatGenerationChunk> {
const stream = await this._aiMaskClient.chat(
{
messages: convertMessages(messages),
temperature: this.temperature,
},
{
modelId: this.modelId,
stream: true,
},
);

for await (const chunk of stream) {
const text = chunk;
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({
content: text,
}),
});
}
}

async _call(messages: BaseMessage[]): Promise<string> {
const completion = await this._aiMaskClient.chat(
{
messages: convertMessages(messages),
temperature: this.temperature,
},
{
modelId: this.modelId,
},
);
return completion;
}
}
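
For reviewers who want to try this out, here is a minimal usage sketch for the new ChatAIMask wrapper. It assumes the AI-Mask extension is installed and that the model ID below (taken from the JSDoc example) is available; `.stream()` comes from LangChain's SimpleChatModel base class rather than from code in this PR.

```typescript
import { HumanMessage } from "@langchain/core/messages";
import { ChatAIMask } from "./lib/chat_models/ai_mask";

const model = new ChatAIMask({
  modelId: "Mistral-7B-Instruct-v0.2-q4f16_1",
  temperature: 0.1,
});

// Streaming is routed through _streamResponseChunks above,
// yielding one AIMessageChunk per chunk from the AI-Mask client.
const stream = await model.stream([
  new HumanMessage({ content: "Summarize the uploaded PDF in one sentence." }),
]);
for await (const chunk of stream) {
  console.log(chunk.content);
}
```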
108 changes: 108 additions & 0 deletions app/lib/embeddings/ai_mask.ts
@@ -0,0 +1,108 @@
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { AIMaskClient } from "@ai-mask/sdk";

export interface AIMaskEmbeddingsParams extends EmbeddingsParams {
/** Model name to use */
modelName: string;

/**
 * Timeout to use when making requests to the AI-Mask client.
*/
timeout?: number;

/**
* The maximum number of documents to embed in a single request.
*/
batchSize?: number;

/**
* Whether to strip new lines from the input text. This is recommended by
* OpenAI, but may not be suitable for all use cases.
*/
stripNewLines?: boolean;
aiMaskClient?: AIMaskClient;
appName?: string;
}

/**
* @example
* ```typescript
 * const model = new AIMaskEmbeddings({
* modelName: "Xenova/all-MiniLM-L6-v2",
* });
*
* // Embed a single query
* const res = await model.embedQuery(
* "What would be a good company name for a company that makes colorful socks?"
* );
* console.log({ res });
*
* // Embed multiple documents
* const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
* console.log({ documentRes });
* ```
*/
export class AIMaskEmbeddings
extends Embeddings
implements AIMaskEmbeddingsParams
{
modelName = "Xenova/all-MiniLM-L6-v2";

batchSize = 512;

stripNewLines = true;

timeout?: number;

protected _aiMaskClient: AIMaskClient;

constructor(fields?: Partial<AIMaskEmbeddingsParams>) {
super(fields ?? {});

this.modelName = fields?.modelName ?? this.modelName;
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
this.timeout = fields?.timeout;

this._aiMaskClient =
fields?.aiMaskClient ?? new AIMaskClient({ name: fields?.appName });
}

async embedDocuments(texts: string[]): Promise<number[][]> {
const batches = chunkArray(
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
this.batchSize,
);

const batchRequests = batches.map((batch) => this.runEmbedding(batch));
const batchResponses = await Promise.all(batchRequests);
const embeddings: number[][] = [];

for (let i = 0; i < batchResponses.length; i += 1) {
const batchResponse = batchResponses[i];
for (let j = 0; j < batchResponse.length; j += 1) {
embeddings.push(batchResponse[j]);
}
}

return embeddings;
}

async embedQuery(text: string): Promise<number[]> {
const data = await this.runEmbedding([
this.stripNewLines ? text.replace(/\n/g, " ") : text,
]);
return data[0];
}

private async runEmbedding(texts: string[]) {
return this.caller.call(async () => {
const output = await this._aiMaskClient.featureExtraction(
{ texts, pooling: "mean", normalize: true },
{ modelId: this.modelName },
);
return output;
});
}
}
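
As a quick sanity check, a sketch of how these embeddings plug into the Voy vector store, mirroring what embedPDF() does in the worker changes below; the document text and query are placeholders invented for illustration.

```typescript
import { Voy as VoyClient } from "voy-search";
import { VoyVectorStore } from "@langchain/community/vectorstores/voy";
import { Document } from "@langchain/core/documents";
import { AIMaskEmbeddings } from "./lib/embeddings/ai_mask";

const embeddings = new AIMaskEmbeddings({
  modelName: "Xenova/all-MiniLM-L6-v2",
});
const store = new VoyVectorStore(new VoyClient(), embeddings);

// addDocuments runs the texts through AIMaskClient.featureExtraction
// in batches of `batchSize` (512 by default).
await store.addDocuments([
  new Document({
    pageContent: "AI-Mask runs language models locally in the browser.",
  }),
]);

const results = await store.similaritySearch("Where do the models run?", 1);
console.log(results[0].pageContent);
```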
85 changes: 67 additions & 18 deletions app/worker.ts
@@ -1,6 +1,7 @@
import { ChatWindowMessage } from "@/schema/ChatWindowMessage";

import { Voy as VoyClient } from "voy-search";
import { AIMaskClient } from "@ai-mask/sdk";

import { createRetrievalChain } from "langchain/chains/retrieval";
import { createStuffDocumentsChain } from "langchain/chains/combine_documents";
@@ -9,6 +10,7 @@ import { createHistoryAwareRetriever } from "langchain/chains/history_aware_retr
import { WebPDFLoader } from "langchain/document_loaders/web/pdf";

import { HuggingFaceTransformersEmbeddings } from "@langchain/community/embeddings/hf_transformers";
import { VectorStore } from "@langchain/core/vectorstores";
import { VoyVectorStore } from "@langchain/community/vectorstores/voy";
import {
ChatPromptTemplate,
@@ -29,16 +31,14 @@ import { LangChainTracer } from "@langchain/core/tracers/tracer_langchain";
import { Client } from "langsmith";

import { ChatOllama } from "@langchain/community/chat_models/ollama";
import { ChatAIMask } from "./lib/chat_models/ai_mask";
import { AIMaskEmbeddings } from "./lib/embeddings/ai_mask";
import { ChatWebLLM } from "./lib/chat_models/webllm";

const embeddings = new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/all-MiniLM-L6-v2",
// Can use "nomic-ai/nomic-embed-text-v1" for more powerful but slower embeddings
// modelName: "nomic-ai/nomic-embed-text-v1",
});
let aiMaskClient: AIMaskClient;

const voyClient = new VoyClient();
const vectorstore = new VoyVectorStore(voyClient, embeddings);
let vectorstore: VectorStore;

const OLLAMA_RESPONSE_SYSTEM_TEMPLATE = `You are an experienced researcher, expert at interpreting and answering questions based on provided sources. Using the provided context, answer the user's question to the best of your ability using the resources provided.
Generate a concise answer for a given question based solely on the provided search results. You must only use information from the provided search results. Use an unbiased and journalistic tone. Combine search results together into a coherent answer. Do not repeat text.
@@ -54,7 +54,26 @@ const WEBLLM_RESPONSE_SYSTEM_TEMPLATE = `You are an experienced researcher, expe
Generate a concise answer for a given question based solely on the provided search results. You must only use information from the provided search results. Use an unbiased and journalistic tone. Combine search results together into a coherent answer. Do not repeat text, stay focused, and stop generating when you have answered the question.
If there is nothing in the context relevant to the question at hand, just say "Hmm, I'm not sure." Don't try to make up an answer.`;

const embedPDF = async (pdfBlob: Blob) => {
const embedPDF = async (pdfBlob: Blob, modelProvider: string) => {
if (modelProvider === "ai-mask") {
if (!aiMaskClient) {
throw new Error("AIMaskClient has not finished initializing");
}

const embeddingsAIMask = new AIMaskEmbeddings({
modelName: "Xenova/all-MiniLM-L6-v2",
aiMaskClient,
});
vectorstore = new VoyVectorStore(voyClient, embeddingsAIMask);
} else {
const embeddings = new HuggingFaceTransformersEmbeddings({
modelName: "Xenova/all-MiniLM-L6-v2",
// Can use "nomic-ai/nomic-embed-text-v1" for more powerful but slower embeddings
// modelName: "nomic-ai/nomic-embed-text-v1",
});
vectorstore = new VoyVectorStore(voyClient, embeddings);
}

const pdfLoader = new WebPDFLoader(pdfBlob, { parsedItemSeparator: " " });
const docs = await pdfLoader.load();

@@ -97,6 +116,9 @@ const queryVectorStore = async (
devModeTracer?: LangChainTracer;
},
) => {
if (!vectorstore) {
throw new Error("Vector store not initialized");
}
const text = messages[messages.length - 1].content;
const chatHistory = await _formatChatHistoryAsMessages(messages.slice(0, -1));

@@ -218,26 +240,49 @@ self.addEventListener("message", async (event: { data: any }) => {

if (event.data.pdf) {
try {
await embedPDF(event.data.pdf);
await embedPDF(event.data.pdf, event.data.modelProvider);
} catch (e: any) {
self.postMessage({
type: "error",
error: e.message,
});
throw e;
}
} else {
} else if (event.data.messages) {
const modelProvider = event.data.modelProvider;
const modelConfig = event.data.modelConfig;
let chatModel: BaseChatModel | LanguageModelLike =
modelProvider === "ollama"
? new ChatOllama(modelConfig)
: new ChatWebLLM(modelConfig);
if (modelProvider === "webllm") {
await (chatModel as ChatWebLLM).initialize((event) =>
self.postMessage({ type: "init_progress", data: event }),
);
chatModel = chatModel.bind({ stop: ["\nInstruct:", "Instruct:"] });
let chatModel: BaseChatModel | LanguageModelLike;
switch (modelProvider) {
case "ollama":
chatModel = new ChatOllama(modelConfig);
break;
case "web-llm":
chatModel = new ChatWebLLM(modelConfig);
await (chatModel as ChatWebLLM).initialize((event) =>
self.postMessage({ type: "init_progress", data: event }),
);
chatModel = chatModel.bind({ stop: ["\nInstruct:", "Instruct:"] });
break;
case "ai-mask":
if (!aiMaskClient) {
self.postMessage({
type: "error",
error: "AIMaskClient has not finished inititializing",
});
return;
}
chatModel = new ChatAIMask({
...modelConfig,
aiMaskClient,
});
chatModel = chatModel.bind({ stop: ["\nInstruct:", "Instruct:"] });
break;
default:
self.postMessage({
type: "error",
error: "Invalid model provider",
});
throw new Error("Invalid model provider");
}
try {
await queryVectorStore(event.data.messages, {
@@ -262,3 +307,7 @@ self.addEventListener("message", async (event: { data: any }) => {
data: "OK",
});
});

(async () => {
aiMaskClient = await AIMaskClient.getWorkerClient();
})();

Review comment (Owner): See below, could this just trigger on some initialization event, then send an initializationComplete once this await returns or it times out?
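
One way the worker could implement that suggestion, sketched under the assumption that the page sends an explicit init message; the message names and the 10-second timeout are invented for illustration, not part of the PR.

```typescript
self.addEventListener("message", async (event: { data: any }) => {
  if (event.data.type !== "init") return;
  try {
    // Race the AI-Mask handshake against a timeout so the page
    // is never left waiting forever.
    aiMaskClient = await Promise.race([
      AIMaskClient.getWorkerClient(),
      new Promise<never>((_, reject) =>
        setTimeout(() => reject(new Error("AI-Mask init timed out")), 10_000),
      ),
    ]);
    self.postMessage({ type: "initializationComplete" });
  } catch (e: any) {
    self.postMessage({ type: "error", error: e.message });
  }
});
```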