Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callbacks refactor #49

Merged
merged 7 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/LangChain.Samples.Prompts/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
HumanMessagePromptTemplate.FromTemplate("{text}")
});

var chainB = new LlmChain(new LlmChainInput(chat, chatPrompt));
var chainB = new LlmChain(new LlmChainInput(chat, chatPrompt)
{
Verbose = true
});

var resultB = await chainB.CallAsync(new ChainValues(new Dictionary<string, object>(3)
{
Expand Down
17 changes: 11 additions & 6 deletions examples/LangChain.Samples.SequentialChain/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

var chainOne = new LlmChain(new LlmChainInput(llm, firstPrompt)
{
Verbose = true,
OutputKey = "company_name"
});

Expand All @@ -20,16 +21,20 @@

var chainTwo = new LlmChain(new LlmChainInput(llm, secondPrompt));

var overallChain = new SequentialChain(new SequentialChainInput(new []
{
chainOne,
chainTwo
}, new []{"product"}));
var overallChain = new SequentialChain(new SequentialChainInput(
new[]
{
chainOne,
chainTwo
},
new[] { "product" },
new[] { "company_name", "text" }
));

var result = await overallChain.CallAsync(new ChainValues(new Dictionary<string, object>(1)
{
{ "product", "colourful socks" }
}));

Console.WriteLine(result.Value["text"]);
Console.WriteLine("Test");
Console.WriteLine("SequentialChain sample finished.");
128 changes: 74 additions & 54 deletions src/libs/LangChain.Core/Base/BaseCallbackHandler.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Docstore;
using LangChain.LLMS;
using LangChain.Providers;
using LangChain.Retrievers;
using LangChain.Schema;

namespace LangChain.Base;
Expand All @@ -7,11 +11,36 @@ namespace LangChain.Base;
public abstract class BaseCallbackHandler : IBaseCallbackHandler
{
/// <inheritdoc />
public string Name { get; protected set; }
public abstract string Name { get; }

public bool IgnoreLlm { get; set; }
public bool IgnoreRetry { get; set; }
public bool IgnoreChain { get; set; }
public bool IgnoreAgent { get; set; }
public bool IgnoreRetriever { get; set; }
public bool IgnoreChatModel { get; set; }

/// <summary>
///
/// </summary>
/// <param name="input"></param>
protected BaseCallbackHandler(IBaseCallbackHandlerInput input)
{
input = input ?? throw new ArgumentNullException(nameof(input));

IgnoreLlm = input.IgnoreLlm;
IgnoreRetry = input.IgnoreRetry;
IgnoreChain = input.IgnoreChain;
IgnoreAgent = input.IgnoreAgent;
IgnoreRetriever = input.IgnoreRetriever;
IgnoreChatModel = input.IgnoreChatModel;
}

/// <inheritdoc />
public abstract Task HandleLlmStartAsync(BaseLlm llm, string[] prompts, string runId, string? parentRunId = null,
Dictionary<string, object>? extraParams = null);
public abstract Task HandleLlmStartAsync(
BaseLlm llm, string[] prompts, string runId, string? parentRunId = null,
List<string>? tags = null, Dictionary<string, object>? metadata = null,
string name = null, Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null);
Expand All @@ -23,20 +52,42 @@ public abstract Task HandleLlmStartAsync(BaseLlm llm, string[] prompts, string r
public abstract Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleChatModelStartAsync(Dictionary<string, object> llm, List<List<object>> messages, string runId, string? parentRunId = null,
public abstract Task HandleChatModelStartAsync(BaseLlm llm, List<List<Message>> messages, string runId,
string? parentRunId = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleChainStartAsync(Dictionary<string, object> chain, Dictionary<string, object> inputs, string runId, string? parentRunId = null);
public abstract Task HandleChainStartAsync(IChain chain, Dictionary<string, object> inputs,
string runId, string? parentRunId = null,
List<string>? tags = null,
Dictionary<string, object>? metadata = null,
string runType = null,
string name = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleChainErrorAsync(Exception err, string runId, string? parentRunId = null);
public abstract Task HandleChainErrorAsync(
Exception err, string runId,
Dictionary<string, object>? inputs = null,
string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleChainEndAsync(Dictionary<string, object> outputs, string runId, string? parentRunId = null);
public abstract Task HandleChainEndAsync(
Dictionary<string, object>? inputs,
Dictionary<string, object> outputs,
string runId,
string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleToolStartAsync(Dictionary<string, object> tool, string input, string runId, string? parentRunId = null);
public abstract Task HandleToolStartAsync(
Dictionary<string, object> tool,
string input, string runId,
string? parentRunId = null,
List<string>? tags = null,
Dictionary<string, object>? metadata = null,
string runType = null,
string name = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null);
Expand All @@ -54,55 +105,24 @@ public abstract Task HandleChatModelStartAsync(Dictionary<string, object> llm, L
public abstract Task HandleAgentEndAsync(Dictionary<string, object> action, string runId, string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleRetrieverStartAsync(string query, string runId, string? parentRunId);
public abstract Task HandleRetrieverStartAsync(
BaseRetriever retriever,
string query,
string runId,
string? parentRunId,
List<string>? tags = null,
Dictionary<string, object>? metadata = null,
string? runType = null,
string? name = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleRetrieverEndAsync(string query, string runId, string? parentRunId);
public abstract Task HandleRetrieverEndAsync(
string query,
List<Document> documents,
string runId,
string? parentRunId);

/// <inheritdoc />
public abstract Task HandleRetrieverErrorAsync(Exception error, string query, string runId, string? parentRunId);

/// <summary>
///
/// </summary>
public bool IgnoreLlm { get; set; }

/// <summary>
///
/// </summary>
public bool IgnoreChain { get; set; }

/// <summary>
///
/// </summary>
public bool IgnoreAgent { get; set; }

public bool IgnoreRetriever { get; set; }

/// <summary>
///
/// </summary>
protected BaseCallbackHandler()
{
Name = Guid.NewGuid().ToString();
}

/// <summary>
///
/// </summary>
/// <param name="input"></param>
protected BaseCallbackHandler(IBaseCallbackHandlerInput input) : this()
{
input = input ?? throw new ArgumentNullException(nameof(input));

IgnoreLlm = input.IgnoreLlm;
IgnoreChain = input.IgnoreChain;
IgnoreAgent = input.IgnoreAgent;
}

/// <summary>
///
/// </summary>
/// <returns></returns>
public abstract IBaseCallbackHandler Copy();
}
48 changes: 45 additions & 3 deletions src/libs/LangChain.Core/Base/BaseChain.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains;
using LangChain.Schema;

Expand All @@ -9,7 +10,7 @@ namespace LangChain.Base;
using LoadValues = Dictionary<string, object>;

/// <inheritdoc />
public abstract class BaseChain : IChain
public abstract class BaseChain(IChainInputs fields) : IChain
{
const string RunKey = "__run";

Expand Down Expand Up @@ -57,7 +58,7 @@ public abstract class BaseChain : IChain

throw new Exception("Return values have multiple keys, 'run' only supported when one key currently");
}

/// <summary>
/// Run the chain using a simple input/output.
/// </summary>
Expand All @@ -83,8 +84,49 @@ public virtual async Task<string> Run(Dictionary<string, object> input)
/// Execute the chain, using the values provided.
/// </summary>
/// <param name="values">The <see cref="ChainValues"/> to use.</param>
/// <param name="callbacks"></param>
/// <param name="tags"></param>
/// <param name="metadata"></param>
/// <returns></returns>
public async Task<IChainValues> CallAsync(
IChainValues values,
ICallbacks? callbacks = null,
List<string>? tags = null,
Dictionary<string, object>? metadata = null)
{
var callbackManager = await CallbackManager.Configure(
callbacks,
fields.Callbacks,
fields.Verbose,
tags,
fields.Tags,
metadata,
fields.Metadata);

var runManager = await callbackManager.HandleChainStart(this, values);

try
{
var result = await CallAsync(values, runManager);

await runManager.HandleChainEndAsync(values, result);

return result;
}
catch (Exception e)
{
await runManager.HandleChainErrorAsync(e, values);
throw;
}
}

/// <summary>
/// Execute the chain, using the values provided.
/// </summary>
/// <param name="values">The <see cref="ChainValues"/> to use.</param>
/// <param name="runManager"></param>
/// <returns></returns>
public abstract Task<IChainValues> CallAsync(IChainValues values);
protected abstract Task<IChainValues> CallAsync(IChainValues values, CallbackManagerForChainRun? runManager);

/// <summary>
///
Expand Down
37 changes: 37 additions & 0 deletions src/libs/LangChain.Core/Base/BaseChainInput.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using LangChain.Callback;

namespace LangChain.Base;

public interface IBaseChainInput
{
/// <summary>
/// Optional list of callback handlers (or callback manager). Defaults to None.
/// Callback handlers are called throughout the lifecycle of a call to a chain,
/// starting with on_chain_start, ending with on_chain_end or on_chain_error.
/// Each custom chain can optionally call additional callback methods, see Callback docs
/// for full details.
/// </summary>
public ICallbacks? Callbacks { get; set; }

/// <summary>
/// Whether or not run in verbose mode. In verbose mode, some intermediate logs
/// will be printed to the console.
/// </summary>
public bool Verbose { get; set; }

/// <summary>
/// Optional list of tags associated with the chain. Defaults to None.
/// These tags will be associated with each call to this chain,
/// and passed as arguments to the handlers defined in `callbacks`.
/// You can use these to eg identify a specific instance of a chain with its use case.
/// </summary>
public List<string> Tags { get; set; }

/// <summary>
/// Optional metadata associated with the chain. Defaults to None.
/// This metadata will be associated with each call to this chain,
/// and passed as arguments to the handlers defined in `callbacks`.
/// You can use these to eg identify a specific instance of a chain with its use case.
/// </summary>
public Dictionary<string, object> Metadata { get; set; }
}
6 changes: 2 additions & 4 deletions src/libs/LangChain.Core/Base/BaseLangChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ namespace LangChain.Base;
/// <inheritdoc />
public abstract class BaseLangChain : IBaseLangChainParams
{
private const bool DefaultVerbosity = false;

/// <summary>
///
/// </summary>
public bool? Verbose { get; set; }
public bool Verbose { get; set; }

/// <summary>
///
Expand All @@ -18,6 +16,6 @@ protected BaseLangChain(IBaseLangChainParams parameters)
{
parameters = parameters ?? throw new ArgumentNullException(nameof(parameters));

Verbose = parameters.Verbose ?? DefaultVerbosity;
Verbose = parameters.Verbose;
}
}
9 changes: 4 additions & 5 deletions src/libs/LangChain.Core/Base/ChainInputs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ namespace LangChain.Base;
/// <inheritdoc />
public class ChainInputs : IChainInputs
{
/// <inheritdoc />
public CallbackManager? CallbackManager { get; set; }

/// <inheritdoc />
public bool? Verbose { get; set; }
public ICallbacks? Callbacks { get; set; }
public List<string> Tags { get; set; }
public Dictionary<string, object> Metadata { get; set; }
public bool Verbose { get; set; }
}
4 changes: 1 addition & 3 deletions src/libs/LangChain.Core/Base/Handler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ namespace LangChain.Base;
/// <inheritdoc />
public abstract class Handler : BaseCallbackHandler
{
/// <inheritdoc />
public override IBaseCallbackHandler Copy()
protected Handler(IBaseCallbackHandlerInput input) : base(input)
{
throw new NotImplementedException();
}
}
Loading
Loading