Merge pull request #1451 from sfc-gh-alherrera/dspy-snowflake
refactor(dspy): updating snowflake LM and RM implementations
arnavsinghvi11 committed Sep 18, 2024
2 parents 75de77f + 3548c1a commit 0bdbeb0
Showing 4 changed files with 120 additions and 137 deletions.
17 changes: 12 additions & 5 deletions docs/api/language_model_clients/Snowflake.md
@@ -9,6 +9,7 @@ sidebar_position: 12
```python
import dspy
import os
from snowflake.snowpark import Session

connection_parameters = {

@@ -20,25 +21,31 @@ connection_parameters = {
"database": os.getenv('SNOWFLAKE_DATABASE'),
"schema": os.getenv('SNOWFLAKE_SCHEMA')}

lm = dspy.Snowflake(model="mixtral-8x7b",credentials=connection_parameters)
# Establish connection to Snowflake
snowpark = Session.builder.configs(connection_parameters).create()

# Initialize Snowflake Cortex LM
lm = dspy.Snowflake(session=snowpark,model="mixtral-8x7b")
```
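
`os.getenv` returns `None` for an unset variable, so a missing setting would otherwise only surface when the session is created. A minimal sketch of a pre-flight check follows. It assumes a hypothetical helper, `load_connection_parameters`, that is not part of DSPy or Snowpark, and it lists only the two keys visible in this hunk:

```python
import os


def load_connection_parameters(env_names, environ=None):
    """Build a Snowpark-style config dict from environment variables.

    Hypothetical helper, not part of DSPy or Snowpark. `env_names` maps each
    config key to the environment variable that should supply its value.
    """
    environ = os.environ if environ is None else environ
    params = {key: environ.get(var) for key, var in env_names.items()}
    # Collect every unset (or empty) variable so the error lists them all at once.
    missing = sorted(var for key, var in env_names.items() if not params[key])
    if missing:
        raise KeyError("Missing environment variables: " + ", ".join(missing))
    return params


# Only the keys shown in the snippet above; extend the mapping with the
# remaining connection settings used in your own configuration.
ENV_NAMES = {
    "database": "SNOWFLAKE_DATABASE",
    "schema": "SNOWFLAKE_SCHEMA",
}
```

The returned dict can then be passed to `Session.builder.configs(...)` in the same way as `connection_parameters` in the example above.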

### Constructor

The constructor inherits from the base class `LM` and verifies the `credentials` for using Snowflake API.
The constructor inherits from the base class `LM` and verifies the `session` for using Snowflake API.

```python
class Snowflake(LM):
def __init__(
self,
self,
session,
model,
credentials,
**kwargs):
```

**Parameters:**

- `session` (_object_): Snowflake connection object enabled by the [snowflake snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/snowpark/api/snowflake.snowpark.Session)
- `model` (_str_): model hosted by [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability).
- `credentials` (_dict_): connection parameters required to initialize a [snowflake snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session)
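
Per the `dsp/modules/snowflake.py` change in this commit, the constructor no longer builds its own session. It sets `query_tag` on the session it is given and keeps that object as its client. A minimal sketch of that behavior, using `types.SimpleNamespace` as a stand-in for a real Snowpark `Session` (the stand-in and the function name `tag_session` are assumptions made for illustration):

```python
from types import SimpleNamespace

# Tag value as it appears in the commit's _init_cortex method.
DSPY_QUERY_TAG = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}


def tag_session(snowflake_session):
    """Set the DSPy query tag on the caller's session and return that same object."""
    snowflake_session.query_tag = DSPY_QUERY_TAG
    return snowflake_session


# Stand-in for snowflake.snowpark.Session; a real session needs live credentials.
fake_session = SimpleNamespace(query_tag=None)
client = tag_session(fake_session)
```

Because the session is modified in place rather than copied, a session that is shared with other code will carry the same tag.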

### Methods

55 changes: 28 additions & 27 deletions docs/api/retrieval_model_clients/SnowflakeRM.md
@@ -6,53 +6,52 @@ sidebar_position: 9

### Constructor

Initialize an instance of the `SnowflakeRM` class, with the option to use `e5-base-v2` or `snowflake-arctic-embed-m` embeddings or any other Snowflake Cortex supported embeddings model.
Initialize an instance of the `SnowflakeRM` class, which lets users leverage the Cortex Search service for hybrid retrieval. Before using it, ensure the Cortex Search service is configured as outlined in the [Cortex Search overview](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview#overview).

```python
SnowflakeRM(
snowflake_table_name: str,
snowflake_credentials: dict,
snowflake_session: object,
cortex_search_service: str,
snowflake_database: str,
snowflake_schema: str,
retrieval_columns: list,
search_filter: dict = None,
k: int = 3,
embeddings_field: str,
embeddings_text_field:str,
embeddings_model: str = "e5-base-v2",
)
```

**Parameters:**

- `snowflake_table_name (str)`: The name of the Snowflake table containing embeddings.
- `snowflake_credentials (dict)`: The connection parameters needed to initialize a Snowflake Snowpark Session.
- `snowflake_session (object)`: Snowflake Snowpark session for connecting to Snowflake.
- `cortex_search_service (str)`: The name of the Cortex Search service to be used.
- `snowflake_database (str)`: The name of the Snowflake database to be used with the Cortex Search service.
- `snowflake_schema (str)`: The name of the Snowflake schema to be used with the Cortex Search service.
- `retrieval_columns (list)`: A list of columns to return for each relevant result in the response.
- `search_filter (dict, optional)`: Optional filter object used for filtering results based on data in the ATTRIBUTES columns. See [Filter syntax](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/query-cortex-search-service#filter-syntax)
- `k (int, optional)`: The number of top passages to retrieve. Defaults to 3.
- `embeddings_field (str)`: The name of the column in the Snowflake table containing the embeddings.
- `embeddings_text_field (str)`: The name of the column in the Snowflake table containing the passages.
- `embeddings_model (str)`: The model to be used to convert text to embeddings

### Methods

#### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction`
#### `forward(self, query_or_queries: Union[str, list[str]], response_columns: list[str], filters: dict = None, k: Optional[int] = None) -> dspy.Prediction`

Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated via the default `e5-base-v2` model or the specified `embedding_model`.
Query the Cortex Search service to retrieve the top k relevant results given a query.

**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.
- `k` (_Optional[int]_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]`
- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"long_text": str}]`
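
The parameter and return descriptions above can be illustrated without a live Cortex Search service. The sketch below is an assumption about the general shape of such a method, not the `SnowflakeRM` implementation: `search_fn` is a hypothetical callable standing in for the service, `CHUNK` is a made-up column name, and plain dicts stand in for `dotdict` passages. It shows a single query being normalized to a list, `k` falling back to the constructor default, and each hit being reduced to the `{"long_text": str}` schema.

```python
from typing import Callable, Optional, Union


class SketchRetriever:
    """Illustrative stand-in; not the real SnowflakeRM."""

    def __init__(self, search_fn: Callable[[str, list, int], list], k: int = 3):
        # Hypothetical callable: (query, columns, limit) -> list of row dicts.
        self.search_fn = search_fn
        self.k = k

    def forward(
        self,
        query_or_queries: Union[str, list],
        response_columns: list,
        k: Optional[int] = None,
    ) -> list:
        # Accept either one query string or a list of them.
        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
        # Fall back to the value set at construction when k is not given.
        limit = self.k if k is None else k
        passages = []
        for query in queries:
            for row in self.search_fn(query, response_columns, limit):
                # Joining the requested columns into one text field is an assumption.
                text = " ".join(str(row[col]) for col in response_columns)
                passages.append({"long_text": text})
        return passages


def fake_search(query, columns, limit):
    """Return `limit` canned rows so the sketch runs offline."""
    return [{"CHUNK": f"{query} #{i}"} for i in range(limit)]


retriever = SketchRetriever(fake_search, k=2)
default_hits = retriever.forward("life", response_columns=["CHUNK"])
override_hits = retriever.forward(["a", "b"], response_columns=["CHUNK"], k=1)
```

Here `default_hits` holds two passages because `k` was left unset, while `override_hits` holds one passage per query because `k=1` was passed explicitly.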

### Quickstart

To support passage retrieval, it assumes that a Snowflake table has been created and populated with the passages in a column `embeddings_text_field` and the embeddings in another column `embeddings_field`

SnowflakeRM uses `e5-base-v2` embeddings model by default or any Snowflake Cortex supported embeddings model.

#### Default OpenAI Embeddings
To support passage retrieval from a Snowflake table with this integration, a Cortex Search endpoint must be configured.

```python
from dspy.retrieve.snowflake_rm import SnowflakeRM
from snowflake.snowpark import Session
import os

connection_parameters = {
@@ -65,14 +64,16 @@ connection_parameters = {
"database": os.getenv('SNOWFLAKE_DATABASE'),
"schema": os.getenv('SNOWFLAKE_SCHEMA')}

retriever_model = SnowflakeRM(
snowflake_table_name="<YOUR_SNOWFLAKE_TABLE_NAME>",
snowflake_credentials=connection_parameters,
embeddings_field="<YOUR_EMBEDDINGS_COLUMN_NAME>",
embeddings_text_field= "<YOUR_PASSAGE_COLUMN_NAME>"
)
# Establish connection to Snowflake
snowpark = Session.builder.configs(connection_parameters).create()

snowflake_retriever = SnowflakeRM(snowflake_session=snowpark,
snowflake_database="<YOUR_SNOWFLAKE_DATABASE_NAME>",
snowflake_schema="<YOUR_SNOWFLAKE_SCHEMA_NAME>",
cortex_search_service="<YOUR_CORTEX_SEARCH_SERVICE_NAME>",
k = 5)

results = retriever_model("Explore the meaning of life", k=5)
results = snowflake_retriever("Explore the meaning of life",response_columns=["<NAME_OF_COLUMN_CONTAINING_TEXT>"])

for result in results:
print("Document:", result.long_text, "\n")
54 changes: 16 additions & 38 deletions dsp/modules/snowflake.py
@@ -3,15 +3,12 @@
from typing import Any

import backoff
from pydantic_core import PydanticCustomError

from dsp.modules.lm import LM
from dsp.utils.settings import settings

try:
from snowflake.snowpark import Session
from snowflake.snowpark import functions as snow_func

except ImportError:
pass

@@ -35,53 +32,34 @@ def giveup_hdlr(details) -> bool:
class Snowflake(LM):
"""Wrapper around Snowflake's CortexAPI.
Currently supported models include 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b',
Supported models include 'llama3.1-70b','llama3.1-405b','snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b',
'llama2-70b-chat','mistral-7b','gemma-7b','llama3-8b','llama3-70b','reka-core'.
"""

def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
def __init__(self, session: object, model: str = "mixtral-8x7b", **kwargs):
"""Parameters
----------
session:
Snowflake Snowpark session for accessing Snowflake Cortex service.
Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session
model : str
Which pre-trained model from Snowflake to use?
Which pre-trained model from Snowflake to use.
Choices are 'snowflake-arctic','mistral-large','reka-flash','mixtral-8x7b','llama2-70b-chat','mistral-7b','gemma-7b'
Full list of supported models is available here: https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#complete
credentials: dict
Snowflake credentials required to initialize the session.
Full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session
**kwargs: dict
Additional arguments to pass to the API provider.
"""
super().__init__(model)

self.client = self._init_cortex(snowflake_session=session)
self.model = model
cortex_models = [
"llama3-8b",
"llama3-70b",
"reka-core",
"snowflake-arctic",
"mistral-large",
"reka-flash",
"mixtral-8x7b",
"llama2-70b-chat",
"mistral-7b",
"gemma-7b",
]

if model in cortex_models:
self.available_args = {
"max_tokens",
"temperature",
"top_p",
}
else:
raise PydanticCustomError(
"model",
'model name is not valid, got "{model_name}"',
)
self.available_args = {
"max_tokens",
"temperature",
"top_p",
}

self.client = self._init_cortex(credentials=credentials)
self.provider = "Snowflake"
self.history: list[dict[str, Any]] = []
self.kwargs = {
@@ -94,11 +72,11 @@ def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
}

@classmethod
def _init_cortex(cls, credentials: dict) -> None:
session = Session.builder.configs(credentials).create()
session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}
def _init_cortex(cls, snowflake_session) -> None:
# session = Session.builder.configs(credentials).create()
snowflake_session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}

return session
return snowflake_session

def _prepare_params(
self,