Few Shot Chat Prompt #8038
Merged
Changes from 13 commits (20 commits total):
5c427ea tmp (hinthornw)
c212f80 Add few shot chat prompt (hinthornw)
959f691 typing (hinthornw)
47ee39f config (hinthornw)
916472f Merge branch 'master' into wfh/few_shot_prompt (hinthornw)
e1b2a2c Merge branch 'master' into wfh/few_shot_prompt (hinthornw)
17a15ea Merge branch 'master' into wfh/few_shot_prompt (hinthornw)
30a4429 add surrounding (hinthornw)
e8aad37 format and add example (hinthornw)
bb4db9f Merge branch 'master' into wfh/few_shot_prompt (hinthornw)
8759679 lint (hinthornw)
3d6527a x (eyurtsev)
5a92aa1 Update libs/langchain/langchain/prompts/few_shot.py (hinthornw)
8eee924 merge (hinthornw)
aa39b59 merge (hinthornw)
06d726f Merge branch 'master' into wfh/few_shot_prompt (hinthornw)
afd70e3 ergonomics (hinthornw)
a93ecf9 test (hinthornw)
60a3262 Merge branch 'master' into wfh/few_shot_prompt (hinthornw)
d5d7af2 update (hinthornw)
@@ -1,24 +1,25 @@
"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional
from __future__ import annotations

from pydantic import Extra, root_validator
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Extra, root_validator

from langchain.prompts.base import (
    DEFAULT_FORMATTER_MAPPING,
    StringPromptTemplate,
    check_valid_template,
)
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.prompt_template import BasePromptTemplate


class FewShotPromptTemplate(StringPromptTemplate):
class _FewShotPromptTemplateMixin(BaseModel):
    """Prompt template that contains few shot examples."""

    @property
    def lc_serializable(self) -> bool:
        return False

    examples: Optional[List[dict]] = None

[Review comment] I'm not super convinced this abstraction is useful for static examples, fwiw.

    """Examples to format into the prompt.
    Either this or example_selector should be provided."""

@@ -27,26 +28,11 @@ def lc_serializable(self) -> bool:
    """ExampleSelector to choose the examples to format into the prompt.
    Either this or examples should be provided."""

    example_prompt: PromptTemplate
    """PromptTemplate used to format an individual example."""

    suffix: str
    """A prompt template string to put after the examples."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    example_separator: str = "\n\n"
    """String separator used to join the prefix, the examples, and suffix."""

    prefix: str = ""
    """A prompt template string to put before the examples."""

    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
    class Config:
        """Configuration for this pydantic object."""

    validate_template: bool = True
    """Whether or not to try validating the template."""
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def check_examples_and_selector(cls, values: Dict) -> Dict:

@@ -65,6 +51,58 @@ def check_examples_and_selector(cls, values: Dict) -> Dict:

        return values

    def _get_examples(self, **kwargs: Any) -> List[dict]:
        """Get the examples to use for formatting the prompt.

        Args:
            **kwargs: Keyword arguments to be passed to the example selector.

        Returns:
            List of examples.
        """
        if self.examples is not None:
            return self.examples
        elif self.example_selector is not None:
            return self.example_selector.select_examples(kwargs)
        else:
            raise ValueError(
                "One of 'examples' and 'example_selector' should be provided"
            )


class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
    """Prompt template that contains few shot examples."""

    @property
    def lc_serializable(self) -> bool:
        """Return whether the prompt template is lc_serializable.

        Returns:
            Boolean indicating whether the prompt template is lc_serializable.
        """
        return False

    validate_template: bool = True
    """Whether or not to try validating the template."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    example_prompt: PromptTemplate
    """PromptTemplate used to format an individual example."""

    suffix: str
    """A prompt template string to put after the examples."""

    example_separator: str = "\n\n"
    """String separator used to join the prefix, the examples, and suffix."""

    prefix: str = ""
    """A prompt template string to put before the examples."""

    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

    @root_validator()
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that prefix, suffix, and input variables are consistent."""

@@ -82,19 +120,11 @@ class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def _get_examples(self, **kwargs: Any) -> List[dict]:
        if self.examples is not None:
            return self.examples
        elif self.example_selector is not None:
            return self.example_selector.select_examples(kwargs)
        else:
            raise ValueError

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.
            **kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

@@ -132,3 +162,206 @@ def dict(self, **kwargs: Any) -> Dict:
        if self.example_selector:
            raise ValueError("Saving an example selector is not currently supported")
        return super().dict(**kwargs)


class FewShotChatMessagePromptTemplate(
    BaseChatPromptTemplate, _FewShotPromptTemplateMixin
):
    """Chat prompt template that supports few-shot examples.

    The high-level structure produced by this prompt template is a list of messages
    consisting of prefix message(s), example message(s), and suffix message(s).

    This structure enables creating a conversation with intermediate examples like:

        System: You are a helpful AI Assistant
        Human: What is 2+2?
        AI: 4
        Human: What is 2+3?
        AI: 5
        Human: What is 4+4?

    This prompt template can be used to generate a fixed list of examples or else
    to dynamically select examples based on the input.

    Examples:

        Prompt template with a fixed list of examples (matching the sample
        conversation above):

        .. code-block:: python

            from langchain.prompts import (
                ChatPromptTemplate,
                HumanMessagePromptTemplate,
                AIMessagePromptTemplate,
            )
            from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate
            from langchain.schema import SystemMessage

            examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
            ]

            # This is a prompt template used to format each individual example.
            example_prompt = ChatPromptTemplate.from_messages(
                [
                    HumanMessagePromptTemplate.from_template("{input}"),
                    AIMessagePromptTemplate.from_template("{output}"),
                ]
            )

            few_shot_prompt = FewShotChatMessagePromptTemplate(
                input_variables=["input"],
                prefix=[SystemMessage(content="You are a helpful AI Assistant")],
                example_prompt=example_prompt,
                examples=examples,
                suffix=[HumanMessagePromptTemplate.from_template("{input}")],
            )

            few_shot_prompt.format(input="What is 4+4?")

        Prompt template with dynamically selected examples:

        .. code-block:: python

            from langchain.prompts import SemanticSimilarityExampleSelector
            from langchain.embeddings import OpenAIEmbeddings
            from langchain.vectorstores import Chroma

            examples = [
                {"input": "2+2", "output": "4"},
                {"input": "2+3", "output": "5"},
                {"input": "2+4", "output": "6"},
                # ...
            ]

            to_vectorize = [
                " ".join(example.values())
                for example in examples
            ]
            embeddings = OpenAIEmbeddings()
            vectorstore = Chroma.from_texts(
                to_vectorize, embeddings, metadatas=examples
            )
            example_selector = SemanticSimilarityExampleSelector(
                vectorstore=vectorstore
            )

            from langchain.schema import SystemMessage
            from langchain.prompts import (
                AIMessagePromptTemplate,
                ChatPromptTemplate,
                HumanMessagePromptTemplate,
            )
            from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate

            # Define how each example will be formatted.
            # In this case, each example will become 2 messages:
            # 1 human, and 1 AI
            example_prompt = ChatPromptTemplate.from_messages(
                [
                    HumanMessagePromptTemplate.from_template("{input}"),
                    AIMessagePromptTemplate.from_template("{output}"),
                ]
            )

            # Define the overall prompt.
            few_shot_prompt = FewShotChatMessagePromptTemplate(
                input_variables=["input"],
                prefix=[SystemMessage(content="You are a helpful AI Assistant")],
                example_selector=example_selector,
                example_prompt=example_prompt,
                suffix=[HumanMessagePromptTemplate.from_template("{input}")],
            )
            # Show the prompt
            print(few_shot_prompt.format_messages(input="What's 3+3?"))

            # Use within an LLMChain
            from langchain.chains import LLMChain
            from langchain.chat_models import ChatOpenAI
            chain = LLMChain(
                llm=ChatOpenAI(),
                prompt=few_shot_prompt,
            )
            chain({"input": "What's 3+3?"})
    """

    @property
    def lc_serializable(self) -> bool:
        """Return whether the prompt template is lc_serializable.

        Returns:
            Boolean indicating whether the prompt template is lc_serializable.
        """
        return False

    prefix: List[
        Union[BaseMessagePromptTemplate, BaseChatPromptTemplate, BaseMessage]
    ] = []
    """The class to format the prefix."""
    example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
    """The class to format each example."""
    suffix: List[
        Union[BaseMessagePromptTemplate, BaseChatPromptTemplate, BaseMessage]
    ] = []
    """The class to format the suffix."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format kwargs into a list of messages.

        Args:
            **kwargs: keyword arguments to use for filling in templates in messages.

        Returns:
            A list of formatted messages with all template variables filled in.
        """
        # Get the examples to use.
        examples = self._get_examples(**kwargs)
        examples = [
            {k: e[k] for k in self.example_prompt.input_variables} for e in examples
        ]
        # Format prefix examples
        prefix_messages = [
            message
            for template in self.prefix
            for message in (
                template.format_messages(**kwargs)
                if isinstance(template, (BasePromptTemplate, BaseMessagePromptTemplate))
                else [template]
            )
        ]
        # Format the examples.
        messages = [
            message
            for example in examples
            for message in self.example_prompt.format_messages(**example)
        ]
        # Format suffix examples
        suffix_messages = [
            message
            for template in self.suffix
            for message in (
                template.format_messages(**kwargs)
                if isinstance(template, (BasePromptTemplate, BaseMessagePromptTemplate))
                else [template]
            )
        ]
        return prefix_messages + messages + suffix_messages

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with inputs, generating a string.

        Use this method to generate a string representation of a prompt consisting
        of chat messages.

        Useful for feeding into a string-based completion language model or debugging.

        Args:
            **kwargs: keyword arguments to use for formatting.

        Returns:
            A string representation of the prompt.
        """
        messages = self.format_messages(**kwargs)
        return get_buffer_string(messages)
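The shared `_get_examples` helper in the diff prefers the static `examples` list and falls back to the `example_selector`, erroring if neither is set. A minimal dependency-free sketch of that precedence logic (the `select` callable stands in for `example_selector.select_examples` and is purely illustrative):

```python
# Sketch of the _get_examples precedence: static examples win, then the
# selector, else an error. `select` is a stand-in, not the langchain API.
from typing import Any, Callable, Dict, List, Optional


def get_examples(
    examples: Optional[List[dict]],
    select: Optional[Callable[[Dict[str, Any]], List[dict]]],
    **kwargs: Any,
) -> List[dict]:
    if examples is not None:
        return examples
    elif select is not None:
        # The selector receives the formatting kwargs as its query.
        return select(kwargs)
    else:
        raise ValueError(
            "One of 'examples' and 'example_selector' should be provided"
        )


static = get_examples([{"input": "2+2", "output": "4"}], None)
dynamic = get_examples(None, lambda kw: [{"input": kw["input"]}], input="hi")
```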
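The `format_messages` implementation in the diff flattens prefix templates, formatted examples, and suffix templates into one message list. A minimal sketch of that assembly pattern, with plain `(role, text)` tuples standing in for langchain message objects (all names here are illustrative, not the library's API):

```python
# Sketch of the prefix/examples/suffix flattening in format_messages.
# (role, text) tuples stand in for BaseMessage objects.
from typing import Any, Dict, List, Tuple

Message = Tuple[str, str]


def format_example(example: Dict[str, str]) -> List[Message]:
    # Each example expands to one human and one AI message.
    return [("human", example["input"]), ("ai", example["output"])]


def assemble(
    prefix: List[Message],
    examples: List[Dict[str, str]],
    suffix_template: str,
    **kwargs: Any,
) -> List[Message]:
    # Static prefix messages pass through unchanged; examples are
    # flattened in order; the suffix is formatted from the call's kwargs.
    example_messages = [m for ex in examples for m in format_example(ex)]
    suffix_messages = [("human", suffix_template.format(**kwargs))]
    return prefix + example_messages + suffix_messages


messages = assemble(
    prefix=[("system", "You are a helpful AI Assistant")],
    examples=[{"input": "2+2", "output": "4"}, {"input": "2+3", "output": "5"}],
    suffix_template="{input}",
    input="What is 4+4?",
)
```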
Review comment: does the mixin need to be a BaseModel? (seems like everything that uses it already is)

Reply: It was giving some weird metaclass resolution errors without it, but I'll double-check that that's still the case.
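For context on the reply above: the generic Python failure mode being alluded to is a metaclass conflict between base classes. A minimal standalone demonstration (deliberately not pydantic-specific; it only shows the class of `TypeError` that giving both bases a common metaclass, e.g. by making the mixin a BaseModel, sidesteps):

```python
# Two unrelated metaclasses on the bases force Python to reject the subclass:
# the derived class's metaclass must be a (non-strict) subclass of the
# metaclasses of all its bases.

class MetaA(type):
    pass


class MetaB(type):
    pass


class A(metaclass=MetaA):
    pass


class B(metaclass=MetaB):
    pass


def try_combine() -> str:
    try:
        class C(A, B):  # no single most-derived metaclass exists
            pass
        return "ok"
    except TypeError:
        return "metaclass conflict"


result = try_combine()
```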