Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ConversationalRetrievalChain #85

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/libs/LangChain.Core/Base/BaseChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ public abstract class BaseChain(IChainInputs fields) : IChain
/// Run the chain using a simple input/output.
/// </summary>
/// <param name="input">The dict input to use to execute the chain.</param>
/// <param name="callbacks">
/// Callbacks to use for this chain run. These will be called in
/// addition to callbacks passed to the chain during construction, but only
/// these runtime callbacks will propagate to calls to other objects.
/// </param>
/// <returns>A text value containing the result of the chain.</returns>
public virtual async Task<string> Run(Dictionary<string, object> input)
public virtual async Task<string> Run(Dictionary<string, object> input, ICallbacks? callbacks = null)
{
var keysLengthDifferent = InputKeys.Length != input.Count;

Expand All @@ -73,7 +78,7 @@ public virtual async Task<string> Run(Dictionary<string, object> input)
throw new ArgumentException($"Chain {ChainType()} expects {InputKeys.Length} but, received {input.Count}");
}

var returnValues = await CallAsync(new ChainValues(input));
var returnValues = await CallAsync(new ChainValues(input), callbacks);

var returnValue = returnValues.Value.FirstOrDefault(kv => kv.Key == OutputKeys[0]).Value;

Expand All @@ -88,7 +93,8 @@ public virtual async Task<string> Run(Dictionary<string, object> input)
/// <param name="tags"></param>
/// <param name="metadata"></param>
/// <returns></returns>
public async Task<IChainValues> CallAsync(IChainValues values,
public async Task<IChainValues> CallAsync(
IChainValues values,
ICallbacks? callbacks = null,
IReadOnlyList<string>? tags = null,
IReadOnlyDictionary<string, object>? metadata = null)
Expand Down
5 changes: 5 additions & 0 deletions src/libs/LangChain.Core/Callback/ICallbacks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ namespace LangChain.Callback;

public interface ICallbacks;

public static class ManagerCallbacksExtensions
{
public static ManagerCallbacks ToCallbacks(this ParentRunManager source) => new ManagerCallbacks(source.GetChild());
}

public record ManagerCallbacks(CallbackManager Value) : ICallbacks;

public record HandlersCallbacks(List<BaseCallbackHandler> Value) : ICallbacks;
5 changes: 4 additions & 1 deletion src/libs/LangChain.Core/Chains/Base/IChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ public interface IChain
{
string[] InputKeys { get; }
string[] OutputKeys { get; }


Task<string?> Run(string input);
Task<string> Run(Dictionary<string, object> input, ICallbacks? callbacks = null);

Task<IChainValues> CallAsync(IChainValues values,
ICallbacks? callbacks = null,
IReadOnlyList<string>? tags = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ namespace LangChain.Chains.CombineDocuments;
/// </summary>
public class StuffDocumentsChain : BaseCombineDocumentsChain
{
private readonly ILlmChain _llmChain;
public readonly ILlmChain LlmChain;
private readonly BasePromptTemplate _documentPrompt;
private readonly string _documentVariableName;
private readonly string _documentSeparator = "\n\n";
private readonly string _documentSeparator;

public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
{
_llmChain = input.LlmChain;
LlmChain = input.LlmChain;
_documentPrompt = input.DocumentPrompt;
_documentSeparator = input.DocumentSeparator;

var llmChainVariables = _llmChain.Prompt.InputVariables;
var llmChainVariables = LlmChain.Prompt.InputVariables;

if (input.DocumentVariableName == null)
{
Expand All @@ -50,7 +50,7 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
}

public override string[] InputKeys =>
base.InputKeys.Concat(_llmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray();
base.InputKeys.Concat(LlmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray();

public override string ChainType() => "stuff_documents_chain";

Expand All @@ -59,17 +59,17 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
IReadOnlyDictionary<string, object> otherKeys)
{
var inputs = await GetInputs(docs, otherKeys);
var predict = await _llmChain.Predict(new ChainValues(inputs.Value));
var predict = await LlmChain.Predict(new ChainValues(inputs.Value));

return (predict.ToString() ?? string.Empty, new Dictionary<string, object>());
}

public override async Task<int?> PromptLength(IReadOnlyList<Document> docs, IReadOnlyDictionary<string, object> otherKeys)
{
if (_llmChain.Llm is ISupportsCountTokens supportsCountTokens)
if (LlmChain.Llm is ISupportsCountTokens supportsCountTokens)
{
var inputs = await GetInputs(docs, otherKeys);
var prompt = await _llmChain.Prompt.FormatPromptValue(inputs);
var prompt = await LlmChain.Prompt.FormatPromptValue(inputs);

return supportsCountTokens.CountTokens(prompt.ToString());
}
Expand All @@ -84,7 +84,7 @@ private async Task<InputValues> GetInputs(IReadOnlyList<Document> docs, IReadOnl
var inputs = new Dictionary<string, object>();
foreach (var kv in otherKeys)
{
if (_llmChain.Prompt.InputVariables.Contains(kv.Key))
if (LlmChain.Prompt.InputVariables.Contains(kv.Key))
{
inputs[kv.Key] = kv.Value;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using LangChain.Abstractions.Schema;
using LangChain.Base;
using LangChain.Callback;
using LangChain.Common;
using LangChain.Docstore;
using LangChain.Providers;
using LangChain.Schema;

namespace LangChain.Chains.ConversationalRetrieval;

/// <summary>
/// Chain for chatting with an index.
/// </summary>
public abstract class BaseConversationalRetrievalChain(BaseConversationalRetrievalChainInput fields) : BaseChain(fields)
{
/// <summary> Chain input fields </summary>
private readonly BaseConversationalRetrievalChainInput _fields = fields;

public override string[] InputKeys => new[] { "question", "chat_history" };

public override string[] OutputKeys
{
get
{
var outputKeys = new List<string> { _fields.OutputKey };
if (_fields.ReturnSourceDocuments)
{
outputKeys.Add("source_documents");
}

if (_fields.ReturnGeneratedQuestion)
{
outputKeys.Add("generated_question");
}

return outputKeys.ToArray();
}
}

protected override async Task<IChainValues> CallAsync(IChainValues values, CallbackManagerForChainRun? runManager)
{
runManager ??= BaseRunManager.GetNoopManager<CallbackManagerForChainRun>();

var question = values.Value["question"].ToString();

var getChatHistory = _fields.GetChatHistory;
var chatHistoryStr = getChatHistory(values.Value["chat_history"] as List<Message>);

string? newQuestion;
if (chatHistoryStr != null)
{
var callbacks = runManager.GetChild();
newQuestion = await _fields.QuestionGenerator.Run(
new Dictionary<string, object>
{
["question"] = question,
["chat_history"] = chatHistoryStr
},
callbacks: new ManagerCallbacks(callbacks));
}
else
{
newQuestion = question;
}

var docs = await GetDocsAsync(newQuestion, values.Value);
var newInputs = new Dictionary<string, object>
{
["chat_history"] = chatHistoryStr,
["input_documents"] = docs
};

if (_fields.RephraseQuestion)
{
newInputs["question"] = newQuestion;
}

newInputs.TryAddKeyValues(values.Value);

var answer = await _fields.CombineDocsChain.Run(
input: newInputs,
callbacks: new ManagerCallbacks(runManager.GetChild()));

var output = new Dictionary<string, object>
{
[_fields.OutputKey] = answer
};

if (_fields.ReturnSourceDocuments)
{
output["source_documents"] = docs;
}

if (_fields.ReturnGeneratedQuestion)
{
output["generated_question"] = newQuestion;
}

return new ChainValues(output);
}

/// <summary>
/// Get docs.
/// </summary>
/// <param name="question"></param>
/// <param name="inputs"></param>
/// <param name="runManager"></param>
/// <returns></returns>
protected abstract Task<List<Document>> GetDocsAsync(
string question,
Dictionary<string, object> inputs,
CallbackManagerForChainRun? runManager = null);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using LangChain.Base;
using LangChain.Chains.CombineDocuments;
using LangChain.Chains.LLM;
using LangChain.Providers;

namespace LangChain.Chains.ConversationalRetrieval;

public class BaseConversationalRetrievalChainInput(
BaseCombineDocumentsChain combineDocsChain,
ILlmChain questionGenerator)
: ChainInputs
{
/// <summary>
/// The chain used to combine any retrieved documents.
/// </summary>
public BaseCombineDocumentsChain CombineDocsChain { get; } = combineDocsChain;

/// <summary>
/// The chain used to generate a new question for the sake of retrieval.
///
/// This chain will take in the current question (with variable `question`)
/// and any chat history (with variable `chat_history`) and will produce
/// a new standalone question to be used later on.
/// </summary>
public ILlmChain QuestionGenerator { get; } = questionGenerator;

/// <summary>
/// The output key to return the final answer of this chain in.
/// </summary>
public string OutputKey { get; set; } = "answer";

/// <summary>
/// Whether or not to pass the new generated question to the combine_docs_chain.
///
/// If True, will pass the new generated question along.
/// If False, will only use the new generated question for retrieval and pass the
/// original question along to the <see cref="CombineDocsChain"/>.
/// </summary>
public bool RephraseQuestion { get; set; } = true;

/// <summary>
/// Return the retrieved source documents as part of the final result.
/// </summary>
public bool ReturnSourceDocuments { get; set; }

/// <summary>
/// Return the generated question as part of the final result.
/// </summary>
public bool ReturnGeneratedQuestion { get; set; }

/// <summary>
/// An optional function to get a string of the chat history.
/// If None is provided, will use a default.
/// </summary>
public Func<IReadOnlyList<Message>, string?> GetChatHistory { get; set; } =
ChatTurnTypeHelper.GetChatHistory;

/// <summary>
/// If specified, the chain will return a fixed response if no docs
/// are found for the question.
/// </summary>
public string? ResponseIfNoDocsFound { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System.Text;
using LangChain.Providers;

namespace LangChain.Chains.ConversationalRetrieval;

public static class ChatTurnTypeHelper
{
public static string GetChatHistory(IReadOnlyList<Message> chatHistory)
{
var buffer = new StringBuilder();

foreach (var message in chatHistory)
{
var rolePrefix = message.Role switch
{
MessageRole.Human => "Human: ",
MessageRole.Ai => "Assistant: ",
_ => $"{message.Role}: "
};

buffer.AppendLine($"{rolePrefix}{message.Content}");

Check warning on line 21 in src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'ChatTurnTypeHelper.GetChatHistory(IReadOnlyList<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)

Check warning on line 21 in src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'ChatTurnTypeHelper.GetChatHistory(IReadOnlyList<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)
}

return buffer.ToString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using LangChain.Callback;
using LangChain.Chains.CombineDocuments;
using LangChain.Docstore;
using LangChain.Providers;

namespace LangChain.Chains.ConversationalRetrieval;

/// <summary>
/// Chain for having a conversation based on retrieved documents.
///
/// This chain takes in chat history (a list of messages) and new questions,
/// and then returns an answer to that question.
/// The algorithm for this chain consists of three parts:
///
/// 1. Use the chat history and the new question to create a "standalone question".
/// This is done so that this question can be passed into the retrieval step to fetch
/// relevant documents. If only the new question was passed in, then relevant context
/// may be lacking. If the whole conversation was passed into retrieval, there may
/// be unnecessary information there that would distract from retrieval.
///
/// 2. This new question is passed to the retriever and relevant documents are
/// returned.
///
/// 3. The retrieved documents are passed to an LLM along with either the new question
/// (default behavior) or the original question and chat history to generate a final
/// response.
/// </summary>
public class ConversationalRetrievalChain(ConversationalRetrievalChainInput fields)
: BaseConversationalRetrievalChain(fields)
{
private readonly ConversationalRetrievalChainInput _fields = fields;

public override string ChainType() => "conversational_retrieval";

protected override async Task<List<Document>> GetDocsAsync(
string question,
Dictionary<string, object> inputs,
CallbackManagerForChainRun? runManager = null)
{
var docs = await _fields.Retriever.GetRelevantDocumentsAsync(
question,
callbacks: runManager?.ToCallbacks());

return ReduceTokensBelowLimit(docs);
}

public List<Document> ReduceTokensBelowLimit(IEnumerable<Document> docs)
{
var docsList = docs.ToList();
var numDocs = docsList.Count;

if (_fields.MaxTokensLimit != null &&
_fields.CombineDocsChain is StuffDocumentsChain stuffDocumentsChain &&
stuffDocumentsChain.LlmChain.Llm is ISupportsCountTokens counter)
{
var tokens = docsList.Select(doc => counter.CountTokens(doc.PageContent)).ToArray();
var tokenCount = tokens.Sum();

while (tokenCount > _fields.MaxTokensLimit)
{
numDocs -= 1;
tokenCount -= tokens[numDocs];
}
}

return docsList.Take(numDocs).ToList();
}
}
Loading