Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Few Shot Chat Prompt #8038

Merged
merged 20 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion libs/langchain/langchain/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
NGramOverlapExampleSelector,
SemanticSimilarityExampleSelector,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
)
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.loading import load_prompt
from langchain.prompts.pipeline import PipelinePromptTemplate
Expand All @@ -42,4 +45,5 @@
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
]
303 changes: 268 additions & 35 deletions libs/langchain/langchain/prompts/few_shot.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional
from __future__ import annotations

from pydantic import Extra, root_validator
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Extra, root_validator

from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
)
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.prompt_template import BasePromptTemplate


class FewShotPromptTemplate(StringPromptTemplate):
class _FewShotPromptTemplateMixin(BaseModel):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does mixin need to be BaseModel (seems like everything that uses it already is)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was giving some weird metaclass-resolution errors without BaseModel, but I'll double-check whether that is still the case.

"""Prompt template that contains few shot examples."""

@property
def lc_serializable(self) -> bool:
return False

examples: Optional[List[dict]] = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super convinced this abstraction is useful for static examples fwiw

"""Examples to format into the prompt.
Either this or example_selector should be provided."""
Expand All @@ -27,26 +28,11 @@ def lc_serializable(self) -> bool:
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""

example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""

suffix: str
"""A prompt template string to put after the examples."""

input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""

prefix: str = ""
"""A prompt template string to put before the examples."""

template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
class Config:
"""Configuration for this pydantic object."""

validate_template: bool = True
"""Whether or not to try validating the template."""
extra = Extra.forbid
arbitrary_types_allowed = True

@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
Expand All @@ -65,6 +51,58 @@ def check_examples_and_selector(cls, values: Dict) -> Dict:

return values

def _get_examples(self, **kwargs: Any) -> List[dict]:
    """Resolve the examples that will be formatted into the prompt.

    Static ``examples`` take precedence; otherwise the configured
    ``example_selector`` is consulted with the given inputs.

    Args:
        **kwargs: Keyword arguments forwarded to the example selector.

    Returns:
        List of example dicts.

    Raises:
        ValueError: If neither ``examples`` nor ``example_selector`` is set.
    """
    static_examples = self.examples
    if static_examples is not None:
        return static_examples
    selector = self.example_selector
    if selector is not None:
        return selector.select_examples(kwargs)
    raise ValueError(
        "One of 'examples' and 'example_selector' should be provided"
    )


class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Prompt template that contains few shot examples."""

@property
def lc_serializable(self) -> bool:
"""Return whether the prompt template is lc_serializable.

Returns:
Boolean indicating whether the prompt template is lc_serializable.
"""
return False

validate_template: bool = True
"""Whether or not to try validating the template."""

input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""

suffix: str
"""A prompt template string to put after the examples."""

example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""

prefix: str = ""
"""A prompt template string to put before the examples."""

template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
Expand All @@ -82,19 +120,11 @@ class Config:
extra = Extra.forbid
arbitrary_types_allowed = True

def _get_examples(self, **kwargs: Any) -> List[dict]:
    """Get the examples to use for formatting the prompt.

    Args:
        **kwargs: Keyword arguments to be passed to the example selector.

    Returns:
        List of examples.

    Raises:
        ValueError: If neither ``examples`` nor ``example_selector`` is set.
    """
    if self.examples is not None:
        return self.examples
    elif self.example_selector is not None:
        return self.example_selector.select_examples(kwargs)
    else:
        # Was a bare `raise ValueError`; give callers an actionable message
        # consistent with the validation error used elsewhere in this module.
        raise ValueError(
            "One of 'examples' and 'example_selector' should be provided"
        )

def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.

Args:
kwargs: Any arguments to be passed to the prompt template.
**kwargs: Any arguments to be passed to the prompt template.

Returns:
A formatted string.
Expand Down Expand Up @@ -132,3 +162,206 @@ def dict(self, **kwargs: Any) -> Dict:
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)


class FewShotChatMessagePromptTemplate(
    BaseChatPromptTemplate, _FewShotPromptTemplateMixin
):
    """Chat prompt template that supports few-shot examples.

    The high-level structure produced by this prompt template is a list of
    messages consisting of prefix message(s), example message(s), and suffix
    message(s).

    This structure enables creating a conversation with intermediate examples
    like:

        System: You are a helpful AI Assistant
        Human: What is 2+2?
        AI: 4
        Human: What is 2+3?
        AI: 5
        Human: What is 4+4?

    This prompt template can be used to generate a fixed list of examples or
    else to dynamically select examples based on the input.

    Examples:

        Prompt template with a fixed list of examples (matching the sample
        conversation above):

        .. code-block:: python

            from langchain.prompts import (
                ChatPromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
            )
            from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate
            from langchain.schema import SystemMessage

            examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
            ]

            # This is a prompt template used to format each individual example.
            example_prompt = ChatPromptTemplate.from_messages(
                [
                    HumanMessagePromptTemplate.from_template("{input}"),
                    AIMessagePromptTemplate.from_template("{output}"),
                ]
            )

            few_shot_prompt = FewShotChatMessagePromptTemplate(
                input_variables=["input"],
                prefix=[SystemMessage(content="You are a helpful AI Assistant")],
                example_prompt=example_prompt,
                examples=examples,
                suffix=[HumanMessagePromptTemplate.from_template("{input}")],
            )

            few_shot_prompt.format(input="What is 4+4?")

        Prompt template with dynamically selected examples:

        .. code-block:: python

            from langchain.prompts import SemanticSimilarityExampleSelector
            from langchain.embeddings import OpenAIEmbeddings
            from langchain.vectorstores import Chroma

            examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
                {"input": "2+4", "output": "6"},
                # ...
            ]

            to_vectorize = [
                " ".join(example.values())
                for example in examples
            ]
            embeddings = OpenAIEmbeddings()
            vectorstore = Chroma.from_texts(
                to_vectorize, embeddings, metadatas=examples
            )
            example_selector = SemanticSimilarityExampleSelector(
                vectorstore=vectorstore
            )

            from langchain.schema import SystemMessage
            from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
            from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate

            # Define how each example will be formatted.
            # In this case, each example will become 2 messages:
            # 1 human, and 1 AI
            example_prompt = ChatPromptTemplate.from_messages(
                [
                    HumanMessagePromptTemplate.from_template("{input}"),
                    AIMessagePromptTemplate.from_template("{output}"),
                ]
            )

            # Define the overall prompt.
            few_shot_prompt = FewShotChatMessagePromptTemplate(
                input_variables=["input"],
                prefix=[SystemMessage(content="You are a helpful AI Assistant")],
                example_selector=example_selector,
                example_prompt=example_prompt,
                suffix=[HumanMessagePromptTemplate.from_template("{input}")],
            )
            # Show the prompt
            print(few_shot_prompt.format_messages(input="What's 3+3?"))

            # Use within an LLMChain
            from langchain.chains import LLMChain
            from langchain.chat_models import ChatOpenAI
            chain = LLMChain(
                llm=ChatOpenAI(),
                prompt=few_shot_prompt,
            )
            chain({"input": "What's 3+3?"})
    """

    @property
    def lc_serializable(self) -> bool:
        """Return whether the prompt template is lc_serializable.

        Returns:
            Boolean indicating whether the prompt template is lc_serializable.
        """
        return False

    # Templates (or literal messages) emitted before the examples.
    prefix: List[
        Union[BaseMessagePromptTemplate, BaseChatPromptTemplate, BaseMessage]
    ] = []
    """The class to format the prefix."""
    # Template applied to each selected example dict.
    example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
    """The class to format each example."""
    # Templates (or literal messages) emitted after the examples.
    suffix: List[
        Union[BaseMessagePromptTemplate, BaseChatPromptTemplate, BaseMessage]
    ] = []
    """The class to format the suffix."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format kwargs into a list of messages.

        Args:
            **kwargs: keyword arguments to use for filling in templates in messages.

        Returns:
            A list of formatted messages with all template variables filled in.
        """

        def _render(
            parts: List[
                Union[BaseMessagePromptTemplate, BaseChatPromptTemplate, BaseMessage]
            ]
        ) -> List[BaseMessage]:
            # Templates are formatted with the caller's kwargs; literal
            # BaseMessage instances are passed through unchanged.
            rendered: List[BaseMessage] = []
            for part in parts:
                if isinstance(
                    part, (BasePromptTemplate, BaseMessagePromptTemplate)
                ):
                    rendered.extend(part.format_messages(**kwargs))
                else:
                    rendered.append(part)
            return rendered

        # Select the examples and keep only the variables the example
        # prompt actually consumes.
        wanted = self.example_prompt.input_variables
        selected = [
            {var: example[var] for var in wanted}
            for example in self._get_examples(**kwargs)
        ]

        # Format each example into its message(s).
        example_messages: List[BaseMessage] = []
        for example in selected:
            example_messages.extend(self.example_prompt.format_messages(**example))

        return _render(self.prefix) + example_messages + _render(self.suffix)

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with inputs generating a string.

        Use this method to generate a string representation of a prompt consisting
        of chat messages.

        Useful for feeding into a string based completion language model or debugging.

        Args:
            **kwargs: keyword arguments to use for formatting.

        Returns:
            A string representation of the prompt
        """
        return get_buffer_string(self.format_messages(**kwargs))
Loading