diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py index 059e7f8..6a45bd8 100644 --- a/custom_components/extended_openai_conversation/config_flow.py +++ b/custom_components/extended_openai_conversation/config_flow.py @@ -76,6 +76,12 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """ api_key = data[CONF_API_KEY] base_url = data.get(CONF_BASE_URL) + + if base_url == DEFAULT_CONF_BASE_URL: + # Do not set base_url if using OpenAI for case of OpenAI's base_url change + base_url = None + data.pop(CONF_BASE_URL) + await validate_authentication(hass=hass, api_key=api_key, base_url=base_url) diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index 3343c58..87cd67b 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -5,8 +5,8 @@ import time from bs4 import BeautifulSoup from typing import Any -from functools import partial -import openai +from homeassistant.helpers.aiohttp_client import async_get_clientsession +from openai.error import AuthenticationError from homeassistant.components import automation, rest, scrape from homeassistant.components.automation.config import _async_validate_config_item @@ -100,22 +100,20 @@ def _get_rest_data(hass, rest_config, arguments): async def validate_authentication( - hass: HomeAssistant, api_key: str, base_url: str or None + hass: HomeAssistant, api_key: str, base_url: str ) -> None: - if base_url and base_url != DEFAULT_CONF_BASE_URL: - await openai.ChatCompletion.acreate( - api_key=api_key, - api_base=base_url, - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Hi"}], - ) - return - - await hass.async_add_executor_job( - partial( - openai.Engine.list, api_key=api_key, api_base=base_url, request_timeout=10 - ) + if not base_url: + base_url = DEFAULT_CONF_BASE_URL + session = async_get_clientsession(hass) + response = await session.get( + f"{base_url}/models", + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10, ) + if response.status == 401: + raise AuthenticationError() + + response.raise_for_status() class FunctionExecutor(ABC):