Add sqlite function #32

Merged
merged 2 commits into from
Nov 19, 2023
102 changes: 102 additions & 0 deletions README.md
@@ -448,6 +448,108 @@ When using [ytube_music_player](https://github.com/KoljaWindeler/ytube_music_pla

<img width="300" alt="Screenshot 2023-11-02 8:40:36 PM" src="https://github.com/jekalmin/extended_openai_conversation/assets/2917984/648efef8-40d1-45d2-b3f9-9bac4a36c517">

### 7. sqlite
#### 7-1. Let model generate a query
- Without examples, the model tends to generate a query that reads only from the "states" table, as shown below.
```
Question: When did bedroom light turn on?
Query(generated by gpt-3.5): SELECT * FROM states WHERE entity_id = 'input_boolean.livingroom_light_2' AND state = 'on' ORDER BY last_changed DESC LIMIT 1
```
- Since "entity_id" is stored in the "states_meta" table, the function description needs to include example questions and queries.
- This approach is flexible, but not secure.

```yaml
- spec:
    name: query_histories_from_db
    description: >-
      Use this function to query histories from Home Assistant SQLite database.
      Example:
        Question: When did bedroom light turn on?
        Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated_ts FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'light.bedroom' AND s.state = 'on' AND s.state != old.state ORDER BY s.last_updated_ts DESC LIMIT 1
        Question: Was livingroom light on at 9 am?
        Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated, s.state FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'switch.livingroom' AND s.state != old.state AND datetime(s.last_updated_ts, 'unixepoch', 'localtime') < '2023-11-17 09:00:00' ORDER BY s.last_updated_ts DESC LIMIT 1
    parameters:
      type: object
      properties:
        query:
          type: string
          description: A fully formed SQL query.
  function:
    type: sqlite
```

Get last changed date time of state | Get state at specific time
--|--
<img width="300" alt="Screenshot 2023-11-19 5:32:56 PM" src="https://github.com/jekalmin/extended_openai_conversation/assets/2917984/5a25db59-f66c-4dfd-9e7b-ae6982ed3cd2"> |<img width="300" alt="Screenshot 2023-11-19 5:32:30 PM" src="https://github.com/jekalmin/extended_openai_conversation/assets/2917984/51faaa26-3294-4f96-b115-c71b268b708e">


**FAQ**
1. Can gpt modify or delete data?
> No. The connection is opened in read-only mode, so queries can only fetch data.
2. Can gpt query data of entities that are not exposed?
> Yes. It is hard to validate whether a query uses only exposed entities.
3. Query uses UTC time. Is there any way to adjust the timezone?
> Yes. Set the "TZ" environment variable to your [region](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) (e.g. `Asia/Seoul`). <br/>
Or replace 'localtime' with an hour offset (e.g. `datetime(s.last_updated_ts, 'unixepoch', '+9 hours')`).
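
The first and third answers can be checked outside Home Assistant with a small sketch. It assumes a throwaway database file and a simplified, hypothetical `states` table (not the real recorder schema); only the `file:...?mode=ro` URI and the `'+9 hours'` modifier come from this README.

```python
import os
import sqlite3
import tempfile


def open_read_only(db_path):
    """Open a SQLite file through a URI with mode=ro, as the FAQ describes."""
    return sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)


# Build a throwaway database (hypothetical, simplified table).
db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
rw = sqlite3.connect(db_path)
rw.execute("CREATE TABLE states (state TEXT, last_updated_ts REAL)")
rw.execute("INSERT INTO states VALUES ('on', 1700203600.0)")
rw.commit()
rw.close()

ro = open_read_only(db_path)

# A write through the read-only connection is rejected by SQLite.
try:
    ro.execute("DELETE FROM states")
    write_blocked = False
except sqlite3.OperationalError:
    write_blocked = True

# A fixed offset can stand in for 'localtime' when TZ is not set.
shifted = ro.execute(
    "SELECT datetime(last_updated_ts, 'unixepoch', '+9 hours') FROM states"
).fetchone()[0]
ro.close()

print(write_blocked, shifted)
```

Read-only mode stops writes, but it does not limit which rows a query can read, which is why the second answer is "Yes".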


#### 7-2. Let model generate a query (with minimum validation)
- To check that at least one "entity_id" of an exposed entity appears in the query, use "is_exposed_entity_in_query" together with "raise".
- This approach is flexible, but still not fully secure.
```yaml
- spec:
    name: query_histories_from_db
    description: >-
      Use this function to query histories from Home Assistant SQLite database.
      Example:
        Question: When did bedroom light turn on?
        Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated_ts FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'light.bedroom' AND s.state = 'on' AND s.state != old.state ORDER BY s.last_updated_ts DESC LIMIT 1
        Question: Was livingroom light on at 9 am?
        Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated, s.state FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'switch.livingroom' AND s.state != old.state AND datetime(s.last_updated_ts, 'unixepoch', 'localtime') < '2023-11-17 09:00:00' ORDER BY s.last_updated_ts DESC LIMIT 1
    parameters:
      type: object
      properties:
        query:
          type: string
          description: A fully formed SQL query.
  function:
    type: sqlite
    query: >-
      {%- if is_exposed_entity_in_query(query) -%}
        {{ query }}
      {%- else -%}
        {{ raise("entity_id should be exposed.") }}
      {%- endif -%}
```
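
Why this counts only as minimum validation can be seen from a standalone copy of the check. The function body below mirrors `is_exposed_entity_in_query` from `helpers.py` in this PR; the entity ids and queries are made up for illustration.

```python
def is_exposed_entity_in_query(query, exposed_entities):
    """Return True if any exposed entity_id appears in the query as a quoted literal."""
    quoted_ids = [f"'{e['entity_id']}'" for e in exposed_entities]
    return any(quoted_id in query for quoted_id in quoted_ids)


# Hypothetical exposure list: only the bedroom light is exposed.
exposed = [{"entity_id": "light.bedroom"}]

only_exposed = "SELECT * FROM states_meta WHERE entity_id = 'light.bedroom'"
only_hidden = "SELECT * FROM states_meta WHERE entity_id = 'lock.front_door'"
mixed = (
    "SELECT * FROM states_meta "
    "WHERE entity_id IN ('light.bedroom', 'lock.front_door')"
)

accepts_exposed = is_exposed_entity_in_query(only_exposed, exposed)
accepts_hidden = is_exposed_entity_in_query(only_hidden, exposed)
accepts_mixed = is_exposed_entity_in_query(mixed, exposed)

print(accepts_exposed, accepts_hidden, accepts_mixed)
```

The check is a substring match, so a query that mentions one exposed entity passes even if it also reads entities that are not exposed.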

#### 7-3. User defines a query, and model passes entity_id
- Use a user-defined query that has been verified in advance. The model passes only the requested entity_id to fetch data from the database.
- This approach is secure, but less flexible.
```yaml
- spec:
    name: get_last_updated_time_of_entity
    description: >
      Use this function to get last updated time of entity
    parameters:
      type: object
      properties:
        entity_id:
          type: string
          description: The target entity
  function:
    type: sqlite
    query: >-
      {%- if is_exposed(entity_id) -%}
        SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') as last_updated_ts
        FROM states s
          INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id
          INNER JOIN states old ON s.old_state_id = old.state_id
        WHERE sm.entity_id = '{{entity_id}}' AND s.state != old.state ORDER BY s.last_updated_ts DESC LIMIT 1
      {%- else -%}
        {{ raise("entity_id should be exposed.") }}
      {%- endif -%}
```
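
The same idea, a fixed query where only the entity id varies, can be sketched in plain Python. This is not how the PR does it: the PR renders the id into the SQL through a template, while the sketch below swaps in SQLite parameter binding (`?`). The two-table schema is a hypothetical, cut-down stand-in for the recorder tables, and `last_changed` is an illustrative name.

```python
import sqlite3

# Fixed, pre-verified query; only the entity_id is supplied at call time.
LAST_CHANGED_SQL = """
SELECT datetime(s.last_updated_ts, 'unixepoch') AS last_updated_ts
FROM states s
  INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id
  INNER JOIN states old ON s.old_state_id = old.state_id
WHERE sm.entity_id = ? AND s.state != old.state
ORDER BY s.last_updated_ts DESC LIMIT 1
"""


def last_changed(conn, entity_id, exposed_entities):
    """Reject non-exposed ids, then run the fixed query with a bound parameter."""
    if not any(e["entity_id"] == entity_id for e in exposed_entities):
        raise ValueError("entity_id should be exposed.")
    row = conn.execute(LAST_CHANGED_SQL, (entity_id,)).fetchone()
    return row[0] if row else None


# Hypothetical minimal schema and data, just enough to run the query.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE states_meta (metadata_id INTEGER PRIMARY KEY, entity_id TEXT);
    CREATE TABLE states (
        state_id INTEGER PRIMARY KEY, metadata_id INTEGER, state TEXT,
        old_state_id INTEGER, last_updated_ts REAL
    );
    INSERT INTO states_meta VALUES (1, 'light.bedroom');
    INSERT INTO states VALUES (1, 1, 'off', NULL, 1700200000.0);
    INSERT INTO states VALUES (2, 1, 'on', 1, 1700203600.0);
    """
)

exposed = [{"entity_id": "light.bedroom"}]
print(last_changed(conn, "light.bedroom", exposed))
```

Because the model controls only one value and that value is checked against the exposure list first, the set of reachable rows is fixed by whoever wrote the query.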

## Practical Usage
See more practical [examples](https://github.com/jekalmin/extended_openai_conversation/tree/main/examples).

97 changes: 85 additions & 12 deletions custom_components/extended_openai_conversation/helpers.py
@@ -3,10 +3,12 @@
import os
import yaml
import time
import sqlite3
from bs4 import BeautifulSoup
from typing import Any
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from openai.error import AuthenticationError
from urllib import parse

from homeassistant.components import automation, rest, scrape
from homeassistant.components.automation.config import _async_validate_config_item
@@ -22,7 +24,7 @@
CONF_ATTRIBUTE,
)
from homeassistant.config import AUTOMATION_CONFIG_PATH
from homeassistant.components import conversation
from homeassistant.components import conversation, recorder
from homeassistant.core import HomeAssistant
from homeassistant.helpers.template import Template
from homeassistant.helpers.script import (
@@ -128,7 +130,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
"""execute function"""


@@ -143,7 +145,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
name = function["name"]
if name == "execute_service":
return await self.execute_service(
@@ -163,7 +165,7 @@ async def execute_service(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
result = []
for service_argument in arguments.get("list", []):
domain = service_argument["domain"]
@@ -198,7 +200,7 @@ async def execute_service(
_LOGGER.error(e)
result.append(False)

return str(result)
return result

async def add_automation(
self,
@@ -207,7 +209,7 @@ async def add_automation(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
automation_config = yaml.safe_load(arguments["automation_config"])
config = {"id": str(round(time.time() * 1000))}
if isinstance(automation_config, list):
@@ -252,7 +254,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
script = Script(
hass,
function["sequence"],
@@ -279,7 +281,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
return Template(function["value_template"], hass).async_render(
arguments,
parse_result=False,
@@ -297,7 +299,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
config = function
rest_data = _get_rest_data(hass, config, arguments)

@@ -324,7 +326,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
config = function
rest_data = _get_rest_data(hass, config, arguments)
coordinator = scrape.coordinator.ScrapeCoordinator(
@@ -408,7 +410,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
config = function
sequence = config["sequence"]

@@ -420,11 +422,81 @@ async def execute(

response_variable = executor_config.get("response_variable")
if response_variable:
arguments[response_variable] = str(result)
arguments[response_variable] = result

return result


class SqliteFunctionExecutor(FunctionExecutor):
    def __init__(self) -> None:
        """initialize sqlite function"""

    def is_exposed(self, entity_id, exposed_entities) -> bool:
        return any(
            exposed_entity["entity_id"] == entity_id
            for exposed_entity in exposed_entities
        )

    def is_exposed_entity_in_query(self, query: str, exposed_entities) -> bool:
        exposed_entity_ids = list(
            map(lambda e: f"'{e['entity_id']}'", exposed_entities)
        )
        return any(
            exposed_entity_id in query for exposed_entity_id in exposed_entity_ids
        )

    def raise_error(self, msg="Unexpected error occurred."):
        raise HomeAssistantError(msg)

    def get_default_db_url(self, hass: HomeAssistant) -> str:
        db_file_path = os.path.join(hass.config.config_dir, recorder.DEFAULT_DB_FILE)
        return f"file:{db_file_path}?mode=ro"

    def set_url_read_only(self, url: str) -> str:
        scheme, netloc, path, query_string, fragment = parse.urlsplit(url)
        query_params = parse.parse_qs(query_string)

        query_params["mode"] = ["ro"]
        new_query_string = parse.urlencode(query_params, doseq=True)

        return parse.urlunsplit((scheme, netloc, path, new_query_string, fragment))

    async def execute(
        self,
        hass: HomeAssistant,
        function,
        arguments,
        user_input: conversation.ConversationInput,
        exposed_entities,
    ):
        db_url = self.set_url_read_only(
            function.get("db_url", self.get_default_db_url(hass))
        )
        query = function.get("query", "{{query}}")

        template_arguments = {
            "is_exposed": lambda e: self.is_exposed(e, exposed_entities),
            "is_exposed_entity_in_query": lambda q: self.is_exposed_entity_in_query(
                q, exposed_entities
            ),
            "exposed_entities": exposed_entities,
            "raise": self.raise_error,
        }
        template_arguments.update(arguments)

        q = Template(query, hass).async_render(template_arguments)
        _LOGGER.info("Rendered query: %s", q)
        with sqlite3.connect(db_url, uri=True) as conn:
            cursor = conn.execute(q)
            names = [description[0] for description in cursor.description]
            rows = cursor.fetchall()
            result = []
            for row in rows:
                for name, val in zip(names, row):
                    result.append({name: val})
            return result
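
Review note on the result shape: the loop at the end of `execute` appends one single-key dict per column value, so a two-column result with two rows yields four entries. The standalone sketch below reproduces that loop next to a hypothetical per-row alternative for comparison. `rows_as_column_dicts`, `rows_as_row_dicts` and table `t` are illustrative names, not part of this PR.

```python
import sqlite3


def rows_as_column_dicts(cursor):
    """Mirror the loop in the diff: one single-key dict per column value."""
    names = [description[0] for description in cursor.description]
    result = []
    for row in cursor.fetchall():
        for name, val in zip(names, row):
            result.append({name: val})
    return result


def rows_as_row_dicts(cursor):
    """Hypothetical alternative: one dict per row, keyed by column name."""
    names = [description[0] for description in cursor.description]
    return [dict(zip(names, row)) for row in cursor.fetchall()]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (last_updated TEXT, state TEXT)")
conn.executemany(
    "INSERT INTO t VALUES (?, ?)",
    [("2023-11-17 08:30:00", "on"), ("2023-11-17 08:45:00", "off")],
)

sql = "SELECT last_updated, state FROM t ORDER BY last_updated"
flat = rows_as_column_dicts(conn.execute(sql))
nested = rows_as_row_dicts(conn.execute(sql))

print(len(flat), len(nested))
```

With `LIMIT 1` queries like the README examples the difference is small, but for multi-row results the flat form loses the link between values that belong to the same row.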


FUNCTION_EXECUTORS: dict[str, FunctionExecutor] = {
"predefined": NativeFunctionExecutor(),
"native": NativeFunctionExecutor(),
@@ -433,4 +505,5 @@ async def execute(
"rest": RestFunctionExecutor(),
"scrape": ScrapeFunctionExecutor(),
"composite": CompositeFunctionExecutor(),
"sqlite": SqliteFunctionExecutor(),
}