Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added system prompt for openai #2145

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions bertopic/representation/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""
Provide the topic name directly without any explanation."""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you make this change? In my experience, by providing the model with a prefix, there is no need to mention that it should provide the topic name without any explanation.


DEFAULT_SYSTEM_PROMPT = "You are designated as an assistant that identify and extract high-level topics from texts."


class Cohere(BaseRepresentation):
Expand All @@ -51,6 +53,8 @@ class Cohere(BaseRepresentation):
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
system_prompt: The system prompt to be used in the model. If no system prompt is given,
`self.default_system_prompt_` is used instead.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
nr_docs: The number of documents to pass to OpenAI if a prompt
Expand Down Expand Up @@ -107,8 +111,9 @@ class Cohere(BaseRepresentation):
def __init__(
self,
client,
model: str = "xlarge",
model: str = "command-r",
prompt: str = None,
system_prompt: str = None,
delay_in_seconds: float = None,
nr_docs: int = 4,
diversity: float = None,
Expand All @@ -118,7 +123,9 @@ def __init__(
self.client = client
self.model = model
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
self.delay_in_seconds = delay_in_seconds
self.nr_docs = nr_docs
self.diversity = diversity
Expand Down Expand Up @@ -160,14 +167,14 @@ def extract_topics(
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)

request = self.client.generate(
request = self.client.chat(
model=self.model,
prompt=prompt,
preamble=self.system_prompt,
message=prompt,
max_tokens=50,
num_generations=1,
stop_sequences=["\n"],
)
label = request.generations[0].text.strip()
label = request.text.strip().replace("Topic name: ", "")
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

return updated_topics
Expand Down
13 changes: 12 additions & 1 deletion bertopic/representation/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
topic: <topic label>
"""

DEFAULT_SYSTEM_PROMPT = "You are designated as an assistant that identify and extract high-level topics from texts."


class OpenAI(BaseRepresentation):
r"""Using the OpenAI API to generate topic labels based
Expand All @@ -73,6 +75,8 @@ class OpenAI(BaseRepresentation):
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
system_prompt: The system prompt to be used in the model. If no system prompt is given,
`self.default_system_prompt_` is used instead.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
exponential_backoff: Retry requests with a random exponential backoff.
Expand Down Expand Up @@ -144,6 +148,7 @@ def __init__(
client,
model: str = "text-embedding-3-small",
prompt: str = None,
system_prompt: str = None,
generator_kwargs: Mapping[str, Any] = {},
delay_in_seconds: float = None,
exponential_backoff: bool = False,
Expand All @@ -161,7 +166,13 @@ def __init__(
else:
self.prompt = prompt

if chat and system_prompt is None:
self.system_prompt = DEFAULT_SYSTEM_PROMPT
else:
self.system_prompt = system_prompt

self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
self.delay_in_seconds = delay_in_seconds
self.exponential_backoff = exponential_backoff
self.chat = chat
Expand Down Expand Up @@ -216,7 +227,7 @@ def extract_topics(

if self.chat:
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt},
]
kwargs = {
Expand Down