Skip to content

Commit

Permalink
Refresh "system" message when messages are truncated.
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin authored and jekalmin committed Jan 30, 2024
1 parent fe84681 commit eac3013
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def supported_languages(self) -> list[str] | Literal["*"]:
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
exposed_entities = self.get_exposed_entities()

if user_input.conversation_id in self.history:
Expand All @@ -169,8 +168,8 @@ async def async_process(
conversation_id = ulid.ulid()
user_input.conversation_id = conversation_id
try:
prompt = self._async_generate_prompt(
raw_prompt, exposed_entities, user_input
system_message = self._generate_system_message(
exposed_entities, user_input
)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
Expand All @@ -182,7 +181,7 @@ async def async_process(
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = [{"role": "system", "content": prompt}]
messages = [system_message]
user_message = {"role": "user", "content": user_input.text}
if self.entry.options.get(CONF_ATTACH_USERNAME, DEFAULT_ATTACH_USERNAME):
user = await self.hass.auth.async_get_user(user_input.context.user_id)
Expand Down Expand Up @@ -223,6 +222,13 @@ async def async_process(
response=intent_response, conversation_id=conversation_id
)

def _generate_system_message(
self, exposed_entities, user_input: conversation.ConversationInput
):
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt = self._async_generate_prompt(raw_prompt, exposed_entities, user_input)
return {"role": "system", "content": prompt}

def _async_generate_prompt(
self,
raw_prompt: str,
Expand Down Expand Up @@ -283,7 +289,9 @@ def get_functions(self):
except:
raise FunctionLoadFailed()

async def truncate_message_history(self, messages):
async def truncate_message_history(
self, messages, exposed_entities, user_input: conversation.ConversationInput
):
"""Truncate message history."""
strategy = self.entry.options.get(
CONF_CONTEXT_TRUNCATE_STRATEGY, DEFAULT_CONTEXT_TRUNCATE_STRATEGY
Expand All @@ -298,6 +306,10 @@ async def truncate_message_history(self, messages):

if last_user_message_index is not None:
del messages[1:last_user_message_index]
# refresh system prompt when all messages are deleted
messages[0] = self._generate_system_message(
exposed_entities, user_input
)

async def query(
self,
Expand Down Expand Up @@ -348,7 +360,7 @@ async def query(
_LOGGER.info("Response %s", response.model_dump(exclude_none=True))

if response.usage.total_tokens > context_threshold:
await self.truncate_message_history(messages)
await self.truncate_message_history(messages, exposed_entities, user_input)

choice: Choice = response.choices[0]
message = choice.message
Expand Down

0 comments on commit eac3013

Please sign in to comment.