
Commit

cleaned up convo feature - breaking change on wizard.ask() -> wizard.chat() and .messages attribute

grantbuster committed Oct 20, 2023
1 parent 7525dd0 commit 23eae3f

Showing 4 changed files with 100 additions and 48 deletions.
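
For downstream users, a minimal migration sketch of the breaking change (hypothetical caller code; wizard stands in for an existing EnergyWizard instance):

# Before this commit (old API):
#   out = wizard.ask("What time is it?", conversational=True)
#   history = wizard.chat_messages
# After this commit (new API):
out = wizard.chat("What time is it?", convo=True)
history = wizard.messages
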
26 changes: 20 additions & 6 deletions elm/base.py
@@ -50,12 +50,26 @@ def __init__(self, model=None):
         """
         self.model = model or self.DEFAULT_MODEL
         self.api_queue = None
-        self.chat_messages = []
+        self.messages = []
         self.clear()
 
+    @property
+    def all_messages_txt(self):
+        """Get a string printout of the full conversation with the LLM
+
+        Returns
+        -------
+        str
+        """
+        messages = [f"{msg['role'].upper()}: {msg['content']}"
+                    for msg in self.messages]
+        messages = '\n\n'.join(messages)
+        return messages
+
     def clear(self):
-        """Clear chat"""
-        self.chat_messages = [{"role": "system", "content": self.MODEL_ROLE}]
+        """Clear chat history and reduce messages to just the initial model
+        role message."""
+        self.messages = [{"role": "system", "content": self.MODEL_ROLE}]
 
     @staticmethod
     async def call_api(url, headers, request_json):
@@ -166,18 +180,18 @@ def chat(self, query, temperature=0):
             Model response
         """
 
-        self.chat_messages.append({"role": "user", "content": query})
+        self.messages.append({"role": "user", "content": query})
 
         kwargs = dict(model=self.model,
-                      messages=self.chat_messages,
+                      messages=self.messages,
                       temperature=temperature,
                       stream=False)
         if 'azure' in str(openai.api_type).lower():
             kwargs['engine'] = self.model
 
         response = openai.ChatCompletion.create(**kwargs)
         response = response["choices"][0]["message"]["content"]
-        self.chat_messages.append({'role': 'assistant', 'content': response})
+        self.messages.append({'role': 'assistant', 'content': response})
 
         return response

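As a standalone sketch of the bookkeeping above (the OpenAI call is stubbed out and the system prompt is made up; the real one comes from MODEL_ROLE):

# chat() appends the user query, calls the API, then appends the reply
messages = [{"role": "system", "content": "You are a helpful assistant."}]
messages.append({"role": "user", "content": "What time is it?"})
messages.append({"role": "assistant", "content": "It is noon."})  # stubbed reply

# all_messages_txt joins role-tagged messages with blank lines
print('\n\n'.join(f"{m['role'].upper()}: {m['content']}" for m in messages))
# SYSTEM: You are a helpful assistant.
#
# USER: What time is it?
#
# ASSISTANT: It is noon.
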
7 changes: 2 additions & 5 deletions elm/tree.py
@@ -80,7 +80,7 @@ def messages(self):
         -------
         list
         """
-        return self.api.chat_messages
+        return self.api.messages
 
     @property
     def all_messages_txt(self):
@@ -90,10 +90,7 @@ def all_messages_txt(self):
         -------
         str
         """
-        messages = [f"{msg['role'].upper()}: {msg['content']}"
-                    for msg in self.messages]
-        messages = '\n\n'.join(messages)
-        return messages
+        return self.api.all_messages_txt
 
     @property
     def history(self):
59 changes: 29 additions & 30 deletions elm/wizard.py
@@ -126,14 +126,8 @@ def rank_strings(self, query, top_n=100):
 
         return strings, scores, best
 
-    def _make_convo_query(self):
-        query = [f"{msg['role'].upper()}: {msg['content']}"
-                 for msg in self.chat_messages]
-        query = '\n\n'.join(query)
-        return query
-
     def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
-                       conversational=False):
+                       convo=False):
         """Engineer a query for GPT using the corpus of information
 
         Parameters
@@ -146,10 +140,10 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
             New text added to the engineered query must contain at least this
             much new information. This helps prevent (for example) the table of
             contents being added multiple times.
-        conversational : bool
-            Flag to ask query with conversation history. Call
-            EnergyWizard.clear() to reset the chat.
+        convo : bool
+            Flag to perform semantic search with full conversation history
+            (True) or just the single query (False). Call EnergyWizard.clear()
+            to reset the chat history.
 
         Returns
         -------
         message : str
@@ -160,8 +154,13 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
             returned here
         """
 
-        if conversational:
-            query = self._make_convo_query()
+        self.messages.append({"role": "user", "content": query})
+
+        if convo:
+            # [1:] to not include the system role in the semantic search
+            query = [f"{msg['role'].upper()}: {msg['content']}"
+                     for msg in self.messages[1:]]
+            query = '\n\n'.join(query)
 
         token_budget = token_budget or self.token_budget
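To make the new convo path concrete, a sketch of the search string it builds (the message contents here are made up; the slicing and formatting mirror the diff above):

messages = [{"role": "system", "content": "<MODEL_ROLE>"},
            {"role": "user", "content": "What time is it?"},
            {"role": "assistant", "content": "It is noon."},
            {"role": "user", "content": "How about now?"}]

# messages[1:] drops the system role, so only the user/assistant turns
# feed the semantic search
query = '\n\n'.join(f"{m['role'].upper()}: {m['content']}"
                    for m in messages[1:])
# 'USER: What time is it?\n\nASSISTANT: It is noon.\n\nUSER: How about now?'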

@@ -212,16 +211,16 @@ def make_ref_list(self, idx):
 
         return ref_list
 
-    def ask(self, query,
-            debug=True,
-            stream=True,
-            temperature=0,
-            conversational=False,
-            token_budget=None,
-            new_info_threshold=0.7,
-            print_references=False):
-        """Answers a query using GPT and a dataframe of relevant texts and
-        embeddings.
+    def chat(self, query,
+             debug=True,
+             stream=True,
+             temperature=0,
+             convo=False,
+             token_budget=None,
+             new_info_threshold=0.7,
+             print_references=False):
+        """Answers a query by doing a semantic search of relevant text with
+        embeddings and then sending engineered query to the LLM.
 
         Parameters
         ----------
@@ -233,9 +232,10 @@ def ask(self, query,
             GPT model temperature, a measure of response entropy from 0 to 1. 0
             is more reliable and nearly deterministic; 1 will give the model
             more creative freedom and may not return as factual of results.
-        conversational : bool
-            Flag to ask query with conversation history. Call
-            EnergyWizard.clear() to reset the chat.
+        convo : bool
+            Flag to perform semantic search with full conversation history
+            (True) or just the single query (False). Call EnergyWizard.clear()
+            to reset the chat history.
         token_budget : int
             Option to override the class init token budget.
         new_info_threshold : float
@@ -258,10 +258,9 @@ def ask(self, query,
             engineered prompt is returned here
         """
 
-        self.chat_messages.append({"role": "user", "content": query})
         out = self.engineer_query(query, token_budget=token_budget,
                                   new_info_threshold=new_info_threshold,
-                                  conversational=conversational)
+                                  convo=convo)
         query, references = out
 
         messages = [{"role": "system", "content": self.MODEL_ROLE},
@@ -292,8 +291,8 @@ def ask(self, query,
                   'support its answer:')
             print(' - ' + '\n - '.join(references))
 
-        self.chat_messages.append({'role': 'assistant',
-                                   'content': response_message})
+        self.messages.append({'role': 'assistant',
+                              'content': response_message})
 
         if debug:
             return response_message, query, references
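A usage sketch of the renamed entry point (assumes wizard is an EnergyWizard already built from an embedded corpus, as in the tests below):

out = wizard.chat("What time is it?", debug=True, stream=False, convo=True,
                  print_references=True)
# with debug=True, chat() returns the response plus the engineered
# query and the list of references used to support the answer
response_message, engineered_query, references = out
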
56 changes: 49 additions & 7 deletions tests/test_wizard.py
@@ -40,12 +40,8 @@ def create(*args, **kwargs):  # pylint: disable=unused-argument
         return response
 
 
-def test_chunk_and_embed(mocker):
-    """Simple text to embedding test
-
-    Note that embedding api is mocked here and not actually tested.
-    """
-
+def make_corpus(mocker):
+    """Make a text corpus with embeddings for the wizard."""
     mocker.patch.object(elm.embed.ChunkAndEmbed, "call_api", MockClass.call)
     mocker.patch.object(elm.wizard.EnergyWizard, "get_embedding",
                         MockClass.get_embedding)
@@ -58,14 +54,60 @@ def test_chunk_and_embed(mocker):
     for i, emb in enumerate(embeddings):
         corpus.append({'text': ce0.text_chunks[i], 'embedding': emb,
                        'ref': 'source0'})
+    return corpus
+
+
+def test_chunk_and_embed(mocker):
+    """Simple text to embedding test
+
+    Note that embedding api is mocked here and not actually tested.
+    """
+
+    corpus = make_corpus(mocker)
     wizard = EnergyWizard(pd.DataFrame(corpus), token_budget=1000,
                           ref_col='ref')
 
     question = 'What time is it?'
-    out = wizard.ask(question, debug=True, stream=False, print_references=True)
+    out = wizard.chat(question, debug=True, stream=False,
+                      print_references=True)
     msg, query, ref = out
 
     assert msg == 'hello!'
     assert query.startswith(EnergyWizard.MODEL_INSTRUCTION)
     assert query.endswith(question)
     assert 'source0' in ref
+
+
+def test_convo_query(mocker):
+    """Query with multiple messages
+
+    Note that embedding api is mocked here and not actually tested.
+    """
+
+    corpus = make_corpus(mocker)
+    wizard = EnergyWizard(pd.DataFrame(corpus), token_budget=1000,
+                          ref_col='ref')
+
+    question1 = 'What time is it?'
+    question2 = 'How about now?'
+
+    query = wizard.chat(question1, debug=True, stream=False, convo=True,
+                        print_references=True)[1]
+    assert question1 in query
+    assert question2 not in query
+    assert len(wizard.messages) == 3
+
+    query = wizard.chat(question2, debug=True, stream=False, convo=True,
+                        print_references=True)[1]
+    assert question1 in query
+    assert question2 in query
+    assert len(wizard.messages) == 5
+
+    wizard.clear()
+    assert len(wizard.messages) == 1
+
+    query = wizard.chat(question2, debug=True, stream=False, convo=True,
+                        print_references=True)[1]
+    assert question1 not in query
+    assert question2 in query
+    assert len(wizard.messages) == 3

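The message-count assertions follow directly from the bookkeeping in elm/base.py: clear() leaves only the system message, and every chat() exchange appends one user and one assistant message.

# after clear():               [system]                                    -> len 1
# after one chat() exchange:   [system, user, assistant]                   -> len 3
# after a second exchange:     [system, user, assistant, user, assistant]  -> len 5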