add summarize app
MichaelClifford committed Dec 21, 2023
1 parent 7ad05b8 commit a2beb1e
Showing 4 changed files with 99 additions and 11 deletions.
43 changes: 32 additions & 11 deletions src/chat.py
@@ -3,28 +3,33 @@

 class Chat:
 
-    n_ctx = 2048
-
-    def __init__(self) -> None:
+    def __init__(self, n_ctx=2048) -> None:
         self.chat_history = [
             {"role": "system", "content": """You are a helpful assistant that is comfortable speaking
                 with C level executives in a professional setting."""},
         ]
         self.llm = Llama(model_path=os.getenv("MODEL_FILE",
                                               "llama-2-7b-chat.Q5_K_S.gguf"),
-                         n_ctx=Chat.n_ctx,
+                         n_ctx=n_ctx,
                          n_gpu_layers=-1,
-                         n_batch=Chat.n_ctx,
+                         n_batch=n_ctx,
                          f16_kv=True,
                          stream=True,)
 
+        self.n_ctx = n_ctx
+
 
     def reset_system_prompt(self, prompt=None):
         if not prompt:
-            self.chat_history = []
+            self.chat_history[0] = {"role":"system", "content":""}
         else:
-            self.chat_history = [{"role":"system",
-                                  "content": prompt}]
-        print(self.chat_history)
+            self.chat_history[0] = {"role":"system",
+                                    "content": prompt}
+        print(self.chat_history[0])
 
 
+    def clear_history(self):
+        self.chat_history = [self.chat_history[0]]
 
 
     def count_tokens(self, messages):
         num_extra_tokens = len(self.chat_history) * 6 # accounts for tokens outside of "content"
@@ -34,7 +39,7 @@ def count_tokens(self, messages):


     def clip_history(self, prompt):
-        context_length = Chat.n_ctx
+        context_length = self.n_ctx
         prompt_length = len(self.llm.tokenize(bytes(prompt["content"], "utf-8")))
         history_length = self.count_tokens(self.chat_history)
         input_length = prompt_length + history_length
@@ -60,3 +65,19 @@ def ask(self, prompt, history):
             reply += token["content"]
             yield reply
         self.chat_history.append({"role":"assistant","content":reply})
+
+    def summarize(self, prompt, history):
+        self.reset_system_prompt("""You are a summarizing agent.
+        You only respond in bullet points.
+        Your only job is to summarize your inputs and provide the most concise possible output.
+        Do not add any information that does not come directly from the user prompt.
+        Limit your response to a maximum of 5 bullet points.
+        It's fine to have less than 5 bullet points"""
+        )
+
+        prompt = {"role":"user","content": prompt}
+        self.chat_history.append(prompt)
+        chat_response = self.llm.create_chat_completion(self.chat_history)
+        self.clear_history()
+        return chat_response["choices"][0]["message"]["content"]
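
For a quick test of the new `summarize` method outside of Gradio, something like the following should work (a minimal sketch; it assumes the model file named by `MODEL_FILE`, or the default `llama-2-7b-chat.Q5_K_S.gguf`, is present in the working directory):

```python
from chat import Chat

# Mirror summary_app.py: use a larger context window for summarization
chat = Chat(n_ctx=4096)

# `history` is required by the signature but unused by summarize
print(chat.summarize("Six teams entered the hackathon. Only one can win.", history=[]))
```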

8 changes: 8 additions & 0 deletions src/summary_app.py
@@ -0,0 +1,8 @@
import gradio as gr
from chat import Chat

if __name__ == "__main__":

    chat = Chat(n_ctx=4096)
    demo = gr.ChatInterface(chat.summarize)
    demo.launch(server_name="0.0.0.0")
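
`gr.ChatInterface` calls the function it wraps with `(message, history)`, which is why `summarize` takes a `history` argument it never reads. A stand-in sketch of the same wiring (the `fake_summarize` function is hypothetical, only to show the contract):

```python
import gradio as gr

def fake_summarize(message, history):
    # Same (message, history) signature gr.ChatInterface expects;
    # history is ignored here, as it is in Chat.summarize
    return "• " + message[:60]

demo = gr.ChatInterface(fake_summarize)
demo.launch(server_name="0.0.0.0")
```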
48 changes: 48 additions & 0 deletions summarizer/README.md
@@ -0,0 +1,48 @@
# Text Summarizer Application

### Download model(s)

This example assumes the model you want to use has already been downloaded to your host machine.

The two models we have tested and recommend for this example are Llama2 and Mistral. Download whichever of the GGUF variants you'd like to use:

* Llama2 - https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/tree/main
* Mistral - https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/tree/main

_For a full list of supported model variants, please see the "Supported models" section of the [llama.cpp repository](https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#description)._

```bash
wget https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf
```
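
If you would rather use Mistral, the download follows the same pattern (the exact filename depends on which quantization you pick from the Mistral page above; `Q4_K_M` here is only an example):

```bash
wget https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf
```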

### Build the image

Run the build from the root of the repository; the Containerfile copies `requirements.txt` and the `src/` directory from the build context.

```bash
podman build -t summarizer . -f summarizer/arm/Containerfile --build-arg=MODEL_FILE=llama-2-7b-chat.Q5_K_S.gguf
```

### Run the image

```bash
podman run -it -p 7860:7860 summarizer
```
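
The model named by the `MODEL_FILE` build argument is baked into the image at build time. To try a different model without rebuilding, one option is to mount it into the container and override `MODEL_FILE` at run time (a sketch; the host path and model filename are placeholders):

```bash
podman run -it -p 7860:7860 \
    -v /path/to/models:/locallm/models \
    -e MODEL_FILE=models/mistral-7b-instruct-v0.1.Q4_K_M.gguf \
    summarizer
```
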
### Interact with the app

```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860")
result = client.predict("""
It's Hackathon day.
All the developers are excited to work on interesting problems.
There are six teams total, but only one can take home the grand prize.
The first team to solve Artificial General Intelligence wins!"""
)
print(result)
```

The response should look something like:

```
Sure, here is a summary of the input in bullet points:
• Hackathon day
• Developers excited to work on interesting problems
• Six teams participating
• Grand prize for the first team to solve Artificial General Intelligence
• Excitement and competition among the teams
```
11 changes: 11 additions & 0 deletions summarizer/arm/Containerfile
@@ -0,0 +1,11 @@
FROM registry.access.redhat.com/ubi9/python-39:1-158
WORKDIR /locallm
COPY requirements.txt /locallm/requirements.txt
RUN pip install --upgrade pip
RUN pip install --no-cache-dir --upgrade -r /locallm/requirements.txt
ARG MODEL_FILE=llama-2-7b-chat.Q5_K_S.gguf
ENV MODEL_FILE=${MODEL_FILE}
COPY ${MODEL_FILE} /locallm/
COPY src/ /locallm
RUN printenv | grep MODEL_FILE
ENTRYPOINT [ "python", "summary_app.py" ]
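
Since the Containerfile both copies the model into `/locallm` and prints `MODEL_FILE` during the build, a quick way to double-check what ended up in the image (assuming the `summarizer` tag from the build step above):

```bash
podman run --rm --entrypoint /bin/ls summarizer -lh /locallm
```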
