diff --git a/src/vanna/chromadb/chromadb_vector.py b/src/vanna/chromadb/chromadb_vector.py index 5c44b650..032cb0cc 100644 --- a/src/vanna/chromadb/chromadb_vector.py +++ b/src/vanna/chromadb/chromadb_vector.py @@ -1,7 +1,6 @@ import json -from typing import List import uuid -from abc import abstractmethod +from typing import List import chromadb import pandas as pd @@ -20,13 +19,26 @@ def __init__(self, config=None): if config is not None: path = config.get("path", ".") self.embedding_function = config.get("embedding_function", default_ef) + curr_client = config.get("client", "persistent") else: path = "." self.embedding_function = default_ef + curr_client = "persistent" # defaults to persistent storage + + if curr_client == "persistent": + self.chroma_client = chromadb.PersistentClient( + path=path, settings=Settings(anonymized_telemetry=False) + ) + elif curr_client == "in-memory": + self.chroma_client = chromadb.EphemeralClient( + settings=Settings(anonymized_telemetry=False) + ) + elif isinstance(curr_client, chromadb.api.client.Client): + # allow providing client directly + self.chroma_client = curr_client + else: + raise ValueError(f"Unsupported client was set in config: {curr_client}") - self.chroma_client = chromadb.PersistentClient( - path=path, settings=Settings(anonymized_telemetry=False) - ) self.documentation_collection = self.chroma_client.get_or_create_collection( name="documentation", embedding_function=self.embedding_function ) @@ -196,7 +208,8 @@ def _extract_documents(query_results) -> list: query_results (pd.DataFrame): The dataframe to use. Returns: - List[str] or None: The extracted documents, or an empty list or single document if an error occurred. + List[str] or None: The extracted documents, or an empty list or + single document if an error occurred. """ if query_results is None: return []