diff --git a/docs/api/language_model_clients/Snowflake.md b/docs/api/language_model_clients/Snowflake.md index 7ffc49af3..ac3002821 100644 --- a/docs/api/language_model_clients/Snowflake.md +++ b/docs/api/language_model_clients/Snowflake.md @@ -9,6 +9,7 @@ sidebar_position: 12 ```python import dspy import os +from snowflake.snowpark import Session connection_parameters = { @@ -20,25 +21,31 @@ connection_parameters = { "database": os.getenv('SNOWFLAKE_DATABASE'), "schema": os.getenv('SNOWFLAKE_SCHEMA')} -lm = dspy.Snowflake(model="mixtral-8x7b",credentials=connection_parameters) +# Establish connection to Snowflake +snowpark = Session.builder.configs(connection_parameters).create() + +# Initialize Snowflake Cortex LM +lm = dspy.Snowflake(session=snowpark,model="mixtral-8x7b") ``` ### Constructor -The constructor inherits from the base class `LM` and verifies the `credentials` for using Snowflake API. +The constructor inherits from the base class `LM` and verifies the `session` for using Snowflake API. ```python class Snowflake(LM): def __init__( - self, + self, + session, model, - credentials, **kwargs): ``` **Parameters:** + +- `session` (_object_): Snowflake connection object enabled by the [snowflake snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/snowpark/api/snowflake.snowpark.Session) - `model` (_str_): model hosted by [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability). 
-- `credentials` (_dict_): connection parameters required to initialize a [snowflake snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session) + snowpark/reference/python/latest/api/snowflake.snowpark.Session) ### Methods diff --git a/docs/api/retrieval_model_clients/SnowflakeRM.md b/docs/api/retrieval_model_clients/SnowflakeRM.md index c0f4d89bc..ac502d728 100644 --- a/docs/api/retrieval_model_clients/SnowflakeRM.md +++ b/docs/api/retrieval_model_clients/SnowflakeRM.md @@ -6,53 +6,52 @@ sidebar_position: 9 ### Constructor -Initialize an instance of the `SnowflakeRM` class, with the option to use `e5-base-v2` or `snowflake-arctic-embed-m` embeddings or any other Snowflake Cortex supported embeddings model. +Initialize an instance of the `SnowflakeRM` class, which enables users to leverage the Cortex Search service for hybrid retrieval. Before using it, ensure the Cortex Search service is configured as outlined in the [Cortex Search overview](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview#overview). ```python SnowflakeRM( - snowflake_table_name: str, - snowflake_credentials: dict, + snowflake_session: object, + cortex_search_service: str, + snowflake_database: str, + snowflake_schema: str, + retrieval_columns: list, + search_filter: dict = None, k: int = 3, - embeddings_field: str, - embeddings_text_field:str, - embeddings_model: str = "e5-base-v2", ) ``` **Parameters:** -- `snowflake_table_name (str)`: The name of the Snowflake table containing embeddings. -- `snowflake_credentials (dict)`: The connection parameters needed to initialize a Snowflake Snowpark Session. +- `snowflake_session (object)`: Snowflake Snowpark session for connecting to Snowflake. +- `cortex_search_service (str)`: The name of the Cortex Search service to be used. +- `snowflake_database (str)`: The name of the Snowflake database to be used with the Cortex Search service. 
+- `snowflake_schema (str)`: The name of the Snowflake schema to be used with the Cortex Search service. +- `retrieval_columns (list)`: A list of columns to return for each relevant result in the response. +- `search_filter (dict, optional)`: Optional filter object used for filtering results based on data in the ATTRIBUTES columns. See [Filter syntax](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/query-cortex-search-service#filter-syntax). - `k (int, optional)`: The number of top passages to retrieve. Defaults to 3. -- `embeddings_field (str)`: The name of the column in the Snowflake table containing the embeddings. -- `embeddings_text_field (str)`: The name of the column in the Snowflake table containing the passages. -- `embeddings_model (str)`: The model to be used to convert text to embeddings ### Methods -#### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction` +#### `forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction` -Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated via the default `e5-base-v2` model or the specified `embedding_model`. +Query the Cortex Search service to retrieve the top `k` relevant results for each query, using the `retrieval_columns` and `search_filter` set during initialization. **Parameters:** - `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for. -- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization. +- `k` (_Optional[int]_): The number of results to retrieve. If not specified, defaults to the value set during initialization. 
**Returns:** -- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]` +- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"long_text": str}]` ### Quickstart -To support passage retrieval, it assumes that a Snowflake table has been created and populated with the passages in a column `embeddings_text_field` and the embeddings in another column `embeddings_field` - -SnowflakeRM uses `e5-base-v2` embeddings model by default or any Snowflake Cortex supported embeddings model. - -#### Default OpenAI Embeddings +To retrieve passages from a Snowflake table with this integration, a Cortex Search service must already be configured. ```python from dspy.retrieve.snowflake_rm import SnowflakeRM +from snowflake.snowpark import Session import os connection_parameters = { @@ -65,14 +64,17 @@ connection_parameters = { "database": os.getenv('SNOWFLAKE_DATABASE'), "schema": os.getenv('SNOWFLAKE_SCHEMA')} -retriever_model = SnowflakeRM( - snowflake_table_name="", - snowflake_credentials=connection_parameters, - embeddings_field="", - embeddings_text_field= "" - ) +# Establish connection to Snowflake +snowpark = Session.builder.configs(connection_parameters).create() + +snowflake_retriever = SnowflakeRM(snowflake_session=snowpark, + snowflake_database="", + snowflake_schema="", + cortex_search_service="", + retrieval_columns=[""], + k=5) -results = retriever_model("Explore the meaning of life", k=5) +results = snowflake_retriever("Explore the meaning of life") for result in results: print("Document:", result.long_text, "\n") diff --git a/dsp/modules/snowflake.py b/dsp/modules/snowflake.py index 51b6a2312..cd01a76ff 100644 --- a/dsp/modules/snowflake.py +++ b/dsp/modules/snowflake.py @@ -3,15 +3,12 @@ from typing import Any import backoff -from pydantic_core import PydanticCustomError from dsp.modules.lm import LM from 
dsp.utils.settings import settings try: - from snowflake.snowpark import Session from snowflake.snowpark import functions as snow_func - except ImportError: pass @@ -35,53 +32,34 @@ def giveup_hdlr(details) -> bool: class Snowflake(LM): """Wrapper around Snowflake's CortexAPI. - Currently supported models include 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b', + Supported models include 'llama3.1-70b','llama3.1-405b','snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b', 'llama2-70b-chat','mistral-7b','gemma-7b','llama3-8b','llama3-70b','reka-core'. """ - def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs): + def __init__(self, session: object, model: str = "mixtral-8x7b", **kwargs): """Parameters ---------- + session: + Snowflake Snowpark session for accessing Snowflake Cortex service. + Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session model : str - Which pre-trained model from Snowflake to use? + Which pre-trained model from Snowflake to use. Choices are 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b','llama2-70b-chat','mistral-7b','gemma-7b' Full list of supported models is available here: https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#complete - credentials: dict - Snowflake credentials required to initialize the session. - Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session **kwargs: dict Additional arguments to pass to the API provider. 
""" super().__init__(model) + self.client = self._init_cortex(snowflake_session=session) self.model = model - cortex_models = [ - "llama3-8b", - "llama3-70b", - "reka-core", - "snowflake-arctic", - "mistral-large", - "reka-flash", - "mixtral-8x7b", - "llama2-70b-chat", - "mistral-7b", - "gemma-7b", - ] - - if model in cortex_models: - self.available_args = { - "max_tokens", - "temperature", - "top_p", - } - else: - raise PydanticCustomError( - "model", - 'model name is not valid, got "{model_name}"', - ) + self.available_args = { + "max_tokens", + "temperature", + "top_p", + } - self.client = self._init_cortex(credentials=credentials) self.provider = "Snowflake" self.history: list[dict[str, Any]] = [] self.kwargs = { @@ -94,11 +72,11 @@ def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs): } @classmethod - def _init_cortex(cls, credentials: dict) -> None: - session = Session.builder.configs(credentials).create() - session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}} + def _init_cortex(cls, snowflake_session) -> None: + # session = Session.builder.configs(credentials).create() + snowflake_session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}} - return session + return snowflake_session def _prepare_params( self, diff --git a/dspy/retrieve/snowflake_rm.py b/dspy/retrieve/snowflake_rm.py index 40aac2b59..bef3f7f03 100644 --- a/dspy/retrieve/snowflake_rm.py +++ b/dspy/retrieve/snowflake_rm.py @@ -1,14 +1,11 @@ +import json from typing import Optional, Union import dspy from dsp.utils import dotdict try: - from snowflake.snowpark import Session - from snowflake.snowpark import functions as snow_fn - from snowflake.snowpark.functions import col, function, lit - from snowflake.snowpark.types import VectorType - + from snowflake.core import Root except ImportError: raise ImportError( "The snowflake-snowpark-python library is required to use SnowflakeRM. 
Install it with dspy-ai[snowflake]", @@ -16,39 +13,43 @@ class SnowflakeRM(dspy.Retrieve): - """A retrieval module that uses Weaviate to return the top passages for a given query. + """A retrieval module that uses Snowflake's Cortex Search service to return the top relevant passages for a given query. + + Assumes that a Snowflake Cortex Search endpoint has been configured by the user. - Assumes that a Snowflake table has been created and populated with the following payload: - - content: The text of the passage + For more information on configuring the Cortex Search service, visit: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview Args: - snowflake_credentials: connection parameters for initializing Snowflake client. - snowflake_table_name (str): The name of the Snowflake table containing document embeddings. - embeddings_field (str): The field in the Snowflake table with the content embeddings - embeddings_text_field (str): The field in the Snowflake table with the content. + snowflake_session (object): Snowflake Snowpark session for accessing the service. + cortex_search_service (str): Name of the Cortex Search service to be used. + snowflake_database (str): The name of the Snowflake database that contains the Cortex Search service. + snowflake_schema (str): The name of the Snowflake schema that contains the Cortex Search service. + retrieval_columns (list): A list of columns to return for each relevant result in the response. These columns must be included in the source query for the service. + search_filter (dict): A filter object for filtering results based on data in the ATTRIBUTES columns. See Filter syntax. k (int, optional): The default number of top passages to retrieve. Defaults to 3. 
""" def __init__( self, - snowflake_table_name: str, - snowflake_credentials: dict, + snowflake_session: object, + cortex_search_service: str, + snowflake_database: str, + snowflake_schema: str, + retrieval_columns: list, + search_filter=None, k: int = 3, - embeddings_field: str = "chunk_vec", - embeddings_text_field: str = "chunk", - embeddings_model: str = "e5-base-v2", ): - self.snowflake_table_name = snowflake_table_name - self.embeddings_field = embeddings_field - self.embeddings_text_field = embeddings_text_field - self.embeddings_model = embeddings_model - self.client = self._init_cortex(credentials=snowflake_credentials) - super().__init__(k=k) + self.k = k + self.cortex_search_service_name = cortex_search_service + self.retrieval_columns = retrieval_columns + self.search_filter = search_filter + self.client = self._fetch_cortex_service( + snowflake_session, snowflake_database, snowflake_schema, cortex_search_service + ) def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction: - """Search Snowflake document embeddings table for self.k top passages for query. - + """Query Cortex Search endpoint for top relevant passages. Args: query_or_queries (Union[str, List[str]]): The query or queries to search for. k (Optional[int]): The number of top passages to retrieve. Defaults to self.k. @@ -56,61 +57,57 @@ def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = No Returns: dspy.Prediction: An object containing the retrieved passages. 
""" - k = k if k is not None else self.k + k = self.k if k is None else k queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [q for q in queries if q] passages = [] - for query in queries: - query_embeddings = self._get_embeddings(query) - top_k_chunks = self._top_k_similar_chunks(query_embeddings, k) + for cortex_query in queries: + response_chunks = self._query_cortex_search( + cortex_search_service=self.client, + query=cortex_query, + columns=self.retrieval_columns, + filter=self.search_filter, + k=k, + ) - passages.extend(dotdict({"long_text": passage[0]}) for passage in top_k_chunks) + if len(self.retrieval_columns) == 1: + passages.extend( + dotdict({"long_text": passage[self.retrieval_columns[0]]}) for passage in response_chunks["results"] + ) + else: + passages.extend(dotdict({"long_text": str(passage)}) for passage in response_chunks["results"]) return passages - def _top_k_similar_chunks(self, query_embeddings, k): - """Search Snowflake table for self.k top passages for query. + def _fetch_cortex_service(self, snowpark_session, snowflake_database, snowflake_schema, cortex_search_service_name): + """Fetch the Cortex Search service to be used""" + snowpark_session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}} + root = Root(snowpark_session) + + # fetch service + search_service = ( + root.databases[snowflake_database] + .schemas[snowflake_schema] + .cortex_search_services[cortex_search_service_name] + ) + + return search_service + + def _query_cortex_search(self, cortex_search_service, query, columns, filter, k): + """Query Cortex Search endpoint for top relevant passages . Args: - query_embeddings(List[float]]): the embeddings for the query of interest - doc_table + cortex_search_service (object): cortex search service for querying + query (str): The query or queries to search for. 
+ columns: A list of columns to return for each relevant result in the response. These columns must be included in the source query for the service. + filter: A filter object for filtering results based on data in the ATTRIBUTES columns. See Filter syntax. k (Optional[int]): The number of top passages to retrieve. Defaults to self.k. Returns: dspy.Prediction: An object containing the retrieved passages. """ - doc_table_value = self.embeddings_field - doc_table_key = self.embeddings_text_field - - doc_embeddings = self.client.table(self.snowflake_table_name) - cosine_similarity = function("vector_cosine_similarity") - - top_k = ( - doc_embeddings.select( - doc_table_value, - doc_table_key, - cosine_similarity( - doc_embeddings.col(doc_table_value), - lit(query_embeddings).cast(VectorType(float, len(query_embeddings))), - ).as_("dist"), - ) - .sort("dist", ascending=False) - .limit(k) - ) - - return top_k.select(doc_table_key).to_pandas().values - - @classmethod - def _init_cortex(cls, credentials: dict) -> None: - session = Session.builder.configs(credentials).create() - session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}} - - return session - - def _get_embeddings(self, query: str) -> list[float]: - # create embeddings for the query - embed = snow_fn.builtin("snowflake.cortex.embed_text_768") - cortex_embed_args = embed(snow_fn.lit(self.embeddings_model), snow_fn.lit(query)) + # query service + resp = cortex_search_service.search(query=query, columns=columns, filter=filter, limit=k) - return self.client.range(1).withColumn("complete_cal", cortex_embed_args).collect()[0].COMPLETE_CAL + return json.loads(resp.to_json())
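
The passage-shaping branch that `SnowflakeRM.forward` gains in this patch can be checked without a Snowflake connection. What follows is a minimal sketch and not part of the diff: `build_passages` and `sample_response` are hypothetical names, plain dicts stand in for `dotdict`, and the response shape (a top-level `"results"` list of row dicts) is assumed from how `forward` indexes the parsed JSON.

```python
from typing import Any


def build_passages(response: dict[str, Any], retrieval_columns: list[str]) -> list[dict[str, str]]:
    """Mirror the branch in forward: one column yields its value, several yield str(row)."""
    results = response["results"]
    if len(retrieval_columns) == 1:
        column = retrieval_columns[0]
        return [{"long_text": row[column]} for row in results]
    # With several columns the whole row is stringified into long_text.
    return [{"long_text": str(row)} for row in results]


# Hypothetical payload, shaped like the dict returned by _query_cortex_search.
sample_response = {
    "results": [
        {"chunk": "first passage", "source": "a.md"},
        {"chunk": "second passage", "source": "b.md"},
    ]
}

single = build_passages(sample_response, ["chunk"])
multi = build_passages(sample_response, ["chunk", "source"])
```

With a single retrieval column, `long_text` carries the raw column value; with several, callers receive the stringified row and would need to parse it themselves if they want individual fields.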