From 59fd84fd2cb9c8ef416407e9090ff989b2fd7b46 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Tue, 17 Sep 2024 11:53:35 +0200
Subject: [PATCH 1/9] Drafted support for PGVector

---
 src/vanna/pgvector/__init__.py |   1 +
 src/vanna/pgvector/pgvector.py | 250 +++++++++++++++++++++++++++++++++
 2 files changed, 251 insertions(+)
 create mode 100644 src/vanna/pgvector/__init__.py
 create mode 100644 src/vanna/pgvector/pgvector.py

diff --git a/src/vanna/pgvector/__init__.py b/src/vanna/pgvector/__init__.py
new file mode 100644
index 00000000..dd152a30
--- /dev/null
+++ b/src/vanna/pgvector/__init__.py
@@ -0,0 +1 @@
+from .pgvector import PG_VectorStore
diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py
new file mode 100644
index 00000000..672a2294
--- /dev/null
+++ b/src/vanna/pgvector/pgvector.py
@@ -0,0 +1,250 @@
+import ast
+import json
+import logging
+import uuid
+
+import pandas as pd
+from langchain_core.documents import Document
+from langchain_postgres.vectorstores import PGVector
+from sqlalchemy import create_engine, text
+
+from vanna import ValidationError
+from vanna.base.base import VannaBase
+from vanna.types import TrainingPlan, TrainingPlanItem
+
+
+class PG_VectorStore(VannaBase):
+    def __init__(self, config=None):
+        if not config or "connection_string" not in config:
+            raise ValueError(
+                "A valid 'config' dictionary with a 'connection_string' is required.")
+
+        VannaBase.__init__(self, config=config)
+
+        if config and "connection_string" in config:
+            self.connection_string = config.get("connection_string")
+            self.n_results = config.get("n_results", 10)
+
+        if config and "embedding_function" in config:
+            self.embedding_function = config.get("embedding_function")
+        else:
+            from langchain_core.embeddings import HuggingFaceEmbeddings
+            self.embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-l6-v2")
+
+        self.sql_vectorstore = PGVector(
+            embedding_function=self.embedding_function,
+            collection_name="sql",
+            connection_string=self.connection_string,
+        )
+        self.ddl_vectorstore = PGVector(
+            embedding_function=self.embedding_function,
+            collection_name="ddl",
+            connection_string=self.connection_string,
+        )
+        self.documentation_vectorstore = PGVector(
+            embedding_function=self.embedding_function,
+            collection_name="documentation",
+            connection_string=self.connection_string,
+        )
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        question_sql_json = json.dumps(
+            {
+                "question": question,
+                "sql": sql,
+            },
+            ensure_ascii=False,
+        )
+        _id = str(uuid.uuid4()) + "-sql"
+        createdat = kwargs.get("createdat")
+        doc = Document(
+            page_content=question_sql_json,
+            metadata={"id": _id, "createdat": createdat},
+        )
+        self.sql_vectorstore.add_documents([doc], ids=[doc.metadata["id"]])
+
+        return _id
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        _id = str(uuid.uuid4()) + "-ddl"
+        doc = Document(
+            page_content=ddl,
+            metadata={"id": _id},
+        )
+        self.ddl_vectorstore.add_documents([doc], ids=[doc.metadata["id"]])
+        return _id
+
+    def add_documentation(self, documentation: str, **kwargs) -> str:
+        _id = str(uuid.uuid4()) + "-doc"
+        doc = Document(
+            page_content=documentation,
+            metadata={"id": _id},
+        )
+        self.documentation_vectorstore.add_documents([doc], ids=[doc.metadata["id"]])
+        return _id
+
+    def get_collection(self, collection_name):
+        match collection_name:
+            case "sql":
+                return self.sql_vectorstore
+            case "ddl":
+                return self.ddl_vectorstore
+            case "documentation":
+                return self.documentation_vectorstore
+            case _:
+                raise ValueError("Specified collection does not exist.")
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        documents = self.sql_vectorstore.similarity_search(query=question, k=self.n_results)
+        return [ast.literal_eval(document.page_content) for document in documents]
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        documents = self.ddl_vectorstore.similarity_search(query=question, k=self.n_results)
+        return [document.page_content for document in documents]
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        documents = self.documentation_vectorstore.similarity_search(query=question, k=self.n_results)
+        return [document.page_content for document in documents]
+
+    def train(
+        self,
+        question: str | None = None,
+        sql: str | None = None,
+        ddl: str | None = None,
+        documentation: str | None = None,
+        plan: TrainingPlan | None = None,
+        createdat: str | None = None,
+    ):
+        if question and not sql:
+            raise ValidationError("Please provide a SQL query.")
+
+        if documentation:
+            logging.info(f"Adding documentation: {documentation}")
+            return self.add_documentation(documentation)
+
+        if sql and question:
+            return self.add_question_sql(question=question, sql=sql, createdat=createdat)
+
+        if ddl:
+            logging.info(f"Adding ddl: {ddl}")
+            return self.add_ddl(ddl)
+
+        if plan:
+            for item in plan._plan:
+                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
+                    self.add_ddl(item.item_value)
+                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
+                    self.add_documentation(item.item_value)
+                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
+                    self.add_question_sql(question=item.item_name, sql=item.item_value)
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        # Establishing the connection
+        engine = create_engine(self.connection_string)
+
+        # Querying the 'langchain_pg_embedding' table
+        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
+        df_embedding = pd.read_sql(query_embedding, engine)
+
+        # List to accumulate the processed rows
+        processed_rows = []
+
+        # Process each row in the DataFrame
+        for _, row in df_embedding.iterrows():
+            custom_id = row["cmetadata"]["id"]
+            document = row["document"]
+            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
+
+            if training_data_type == "sql":
+                # Convert the document string to a dictionary
+                try:
+                    doc_dict = ast.literal_eval(document)
+                    question = doc_dict.get("question")
+                    content = doc_dict.get("sql")
+                except (ValueError, SyntaxError):
+                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
+                    continue
+            elif training_data_type in ["documentation", "ddl"]:
+                question = None  # Default value for question
+                content = document
+            else:
+                # If the suffix is not recognized, skip this row
+                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
+                continue
+
+            # Append the processed data to the list
+            processed_rows.append(
+                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
+            )
+
+        # Create a DataFrame from the list of processed rows
+        df_processed = pd.DataFrame(processed_rows)
+
+        return df_processed
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        # Create the database engine
+        engine = create_engine(self.connection_string)
+
+        # SQL DELETE statement
+        delete_statement = text(
+            """
+            DELETE FROM langchain_pg_embedding
+            WHERE cmetadata ->> 'id' = :id
+            """
+        )
+
+        # Connect to the database and execute the delete statement
+        with engine.connect() as connection:
+            # Start a transaction
+            with connection.begin() as transaction:
+                try:
+                    result = connection.execute(delete_statement, {"id": id})
+                    # Commit the transaction if the delete was successful
+                    transaction.commit()
+                    # Check if any row was deleted and return True or False accordingly
+                    return result.rowcount > 0
+                except Exception as e:
+                    # Rollback the transaction in case of error
+                    logging.error(f"An error occurred: {e}")
+                    transaction.rollback()
+                    return False
+
+    def remove_collection(self, collection_name: str) -> bool:
+        engine = create_engine(self.connection_string)
+
+        # Determine the suffix to look for based on the collection name
+        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
+        suffix = suffix_map.get(collection_name)
+
+        if not suffix:
+            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
+            return False
+
+        # Parameterized SQL query to delete rows whose id carries the suffix
+        query = text(
+            """
+            DELETE FROM langchain_pg_embedding
+            WHERE cmetadata->>'id' LIKE :suffix
+            """
+        )
+
+        # Execute the deletion within a transaction block
+        with engine.connect() as connection:
+            with connection.begin() as transaction:
+                try:
+                    result = connection.execute(query, {"suffix": f"%{suffix}"})
+                    transaction.commit()  # Explicitly commit the transaction
+                    if result.rowcount > 0:
+                        logging.info(
+                            f"Deleted {result.rowcount} rows from "
+                            f"langchain_pg_embedding where collection is {collection_name}."
+                        )
+                        return True
+                    else:
+                        logging.info(f"No rows deleted for collection {collection_name}.")
+                        return False
+                except Exception as e:
+                    logging.error(f"An error occurred: {e}")
+                    transaction.rollback()  # Rollback in case of error
+                    return False

From f1b8872cb34545e04162da0ca88c1173a9f33b4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Tue, 17 Sep 2024 11:53:46 +0200
Subject: [PATCH 2/9] Included PGVector deps in pyproject

---
 pyproject.toml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 772a20f3..e21d8efc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
 snowflake = ["snowflake-connector-python"]
 duckdb = ["duckdb"]
 google = ["google-generativeai", "google-cloud-aiplatform"]
-all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common"]
+all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "langchain-postgres"]
 test = ["tox"]
 chromadb = ["chromadb"]
 openai = ["openai"]
@@ -53,3 +53,4 @@ milvus = ["pymilvus[model]"]
 bedrock = ["boto3", "botocore"]
 weaviate = ["weaviate-client"]
 azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
+pgvector = ["langchain-postgres>=0.0.12"]

From a72f9c197c7007222cc64de6452b3a84a6553245 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Wed, 18 Sep 2024 12:50:39 +0200
Subject: [PATCH 3/9] Fixed PGVector construction to work with latest version; added missing methods to class

---
 src/vanna/pgvector/pgvector.py | 37 ++++++++++++++++++++++++++--------
 1 file changed, 29 insertions(+), 8 deletions(-)

diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py
index 672a2294..fb6e0bbd 100644
--- a/src/vanna/pgvector/pgvector.py
+++ b/src/vanna/pgvector/pgvector.py
@@ -28,23 +28,23 @@ def __init__(self, config=None):
         if config and "embedding_function" in config:
             self.embedding_function = config.get("embedding_function")
         else:
-            from langchain_core.embeddings import HuggingFaceEmbeddings
-            self.embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-l6-v2")
+            from langchain_huggingface import HuggingFaceEmbeddings
+            self.embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
         self.sql_vectorstore = PGVector(
-            embedding_function=self.embedding_function,
+            embeddings=self.embedding_function,
             collection_name="sql",
-            connection_string=self.connection_string,
+            connection=self.connection_string,
         )
         self.ddl_vectorstore = PGVector(
-            embedding_function=self.embedding_function,
+            embeddings=self.embedding_function,
             collection_name="ddl",
-            connection_string=self.connection_string,
+            connection=self.connection_string,
         )
         self.documentation_vectorstore = PGVector(
-            embedding_function=self.embedding_function,
+            embeddings=self.embedding_function,
             collection_name="documentation",
-            connection_string=self.connection_string,
+            connection=self.connection_string,
         )
 
     def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
@@ -248,3 +248,24 @@ def remove_collection(self, collection_name: str) -> bool:
                     logging.error(f"An error occurred: {e}")
                     transaction.rollback()  # Rollback in case of error
                     return False
+
+
+    def assistant_message(self, *args, **kwargs):
+        # Implement the method
+        pass
+
+    def generate_embedding(self, *args, **kwargs):
+        # Implement the method
+        pass
+
+    def submit_prompt(self, *args, **kwargs):
+        # Implement the method
+        pass
+
+    def system_message(self, *args, **kwargs):
+        # Implement the method
+        pass
+
+    def user_message(self, *args, **kwargs):
+        # Implement the method
+        pass

From 07b481356df003f6e6f9f9f1ca32dba8f8516890 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Wed, 18 Sep 2024 12:50:55 +0200
Subject: [PATCH 4/9] Updated import test

---
 tests/test_imports.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/test_imports.py b/tests/test_imports.py
index 0efd180f..dc3509be 100644
--- a/tests/test_imports.py
+++ b/tests/test_imports.py
@@ -2,6 +2,7 @@
 
 def test_regular_imports():
     from vanna.anthropic.anthropic_chat import Anthropic_Chat
+    from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore
     from vanna.base.base import VannaBase
     from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
     from vanna.hf.hf import Hf
@@ -13,16 +14,17 @@ def test_regular_imports():
     from vanna.openai.openai_chat import OpenAI_Chat
     from vanna.openai.openai_embeddings import OpenAI_Embeddings
     from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore
+    from vanna.pgvector.pgvector import PG_VectorStore
     from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore
     from vanna.remote import VannaDefault
     from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
     from vanna.weaviate.weaviate_vector import WeaviateDatabase
     from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
     from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings
-    from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore
 
 def test_shortcut_imports():
     from vanna.anthropic import Anthropic_Chat
+    from vanna.azuresearch import AzureAISearch_VectorStore
     from vanna.base import VannaBase
     from vanna.chromadb import ChromaDB_VectorStore
     from vanna.hf import Hf
@@ -32,9 +34,9 @@ def test_shortcut_imports():
     from vanna.ollama import Ollama
     from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
     from vanna.opensearch import OpenSearch_VectorStore
+    from vanna.pgvector import PG_VectorStore
     from vanna.pinecone import PineconeDB_VectorStore
     from vanna.vannadb import VannaDB_VectorStore
     from vanna.vllm import Vllm
     from vanna.weaviate import WeaviateDatabase
     from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings
-    from vanna.azuresearch import AzureAISearch_VectorStore
\ No newline at end of file

From 9efb82969eb2150fc4d21081754cbcc8b1339734 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Wed, 18 Sep 2024 12:51:08 +0200
Subject: [PATCH 5/9] Added simple pgvector test script

---
 tests/test_pgvector.py | 6 ++++++
 1 file changed, 6 insertions(+)
 create mode 100644 tests/test_pgvector.py

diff --git a/tests/test_pgvector.py b/tests/test_pgvector.py
new file mode 100644
index 00000000..2be6cea5
--- /dev/null
+++ b/tests/test_pgvector.py
@@ -0,0 +1,6 @@
+from vanna.pgvector import PG_VectorStore
+
+
+def test_pgvector():
+    pgclient = PG_VectorStore(config={"connection_string": ""})
+    assert pgclient is not None

From 3aafae16a011894ee7dc8917853c3a2ad9ce2537 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Wed, 18 Sep 2024 15:35:02 +0200
Subject: [PATCH 6/9] Updated test script

---
 tests/test_pgvector.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/tests/test_pgvector.py b/tests/test_pgvector.py
index 2be6cea5..46d6d257 100644
--- a/tests/test_pgvector.py
+++ b/tests/test_pgvector.py
@@ -1,6 +1,26 @@
+import os
+
+from dotenv import load_dotenv
+
 from vanna.pgvector import PG_VectorStore
 
+load_dotenv()
+
+
+def get_vanna_connection_string():
+    server = os.environ.get("PG_SERVER")
+    driver = "psycopg"
+    port = 5432
+    database = os.environ.get("PG_DATABASE")
+    username = os.environ.get("PG_USERNAME")
+    password = os.environ.get("PG_PASSWORD")
+
+    return f"postgresql+{driver}://{username}:{password}@{server}:{port}/{database}"
+
 
 def test_pgvector():
-    pgclient = PG_VectorStore(config={"connection_string": ""})
+    connection_string = get_vanna_connection_string()
+    print("Connection string:")
+    print(connection_string)
+    pgclient = PG_VectorStore(config={"connection_string": connection_string})
     assert pgclient is not None

From a0b955ccb3ef90873a4001466b9f79f93b08948f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?=
Date: Wed, 18 Sep 2024 16:03:04 +0200
Subject: [PATCH 7/9] pgvector class construction works; can connect to PG DB

---
 tests/test_pgvector.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_pgvector.py b/tests/test_pgvector.py
index 46d6d257..4bc1dea9 100644
--- a/tests/test_pgvector.py
+++ b/tests/test_pgvector.py
@@ -10,7 +10,7 @@
 def get_vanna_connection_string():
     server = os.environ.get("PG_SERVER")
     driver = "psycopg"
"psycopg" - port = 5432 + port = 5434 database = os.environ.get("PG_DATABASE") username = os.environ.get("PG_USERNAME") password = os.environ.get("PG_PASSWORD") @@ -20,7 +20,5 @@ def get_vanna_connection_string(): def test_pgvector(): connection_string = get_vanna_connection_string() - print("Connection string:") - print(connection_string) pgclient = PG_VectorStore(config={"connection_string": connection_string}) assert pgclient is not None From 8dfeb8e3ddc79154d020539685c77aba449bb6a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Wed, 18 Sep 2024 16:04:22 +0200 Subject: [PATCH 8/9] Implemented message methods --- src/vanna/pgvector/pgvector.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index fb6e0bbd..b9cc2940 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -249,11 +249,6 @@ def remove_collection(self, collection_name: str) -> bool: transaction.rollback() # Rollback in case of error return False - - def assistant_message(self, *args, **kwargs): - # Implement the method - pass - def generate_embedding(self, *args, **kwargs): # Implement the method pass @@ -262,10 +257,11 @@ def submit_prompt(self, *args, **kwargs): # Implement the method pass - def system_message(self, *args, **kwargs): - # Implement the method - pass + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} - def user_message(self, *args, **kwargs): - # Implement the method - pass + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} From 374f6bd8aa9da2642cdb7e373653de073eb5cf88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Wed, 18 Sep 2024 16:14:00 +0200 Subject: [PATCH 9/9] Fix imports --- src/vanna/pgvector/pgvector.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index b9cc2940..cf0c2a23 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -8,9 +8,9 @@ from langchain_postgres.vectorstores import PGVector from sqlalchemy import create_engine, text -from vanna import ValidationError -from vanna.base.base import VannaBase -from vanna.types import TrainingPlan, TrainingPlanItem +from .. import ValidationError +from ..base import VannaBase +from ..types import TrainingPlan, TrainingPlanItem class PG_VectorStore(VannaBase): @@ -250,11 +250,9 @@ def remove_collection(self, collection_name: str) -> bool: return False def generate_embedding(self, *args, **kwargs): - # Implement the method pass def submit_prompt(self, *args, **kwargs): - # Implement the method pass def system_message(self, message: str) -> any: