From 24c759600f080881a4f238d66f6c0506ebdddbee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 23 Jul 2024 09:58:42 +0800 Subject: [PATCH 1/4] feat(callbacks): Add a `StackHandler` to support multiple Callbacks --- callbacks/stack.go | 126 +++++++++++++++++++++++++++++++++++++ callbacks/stack_test.go | 136 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 callbacks/stack.go create mode 100644 callbacks/stack_test.go diff --git a/callbacks/stack.go b/callbacks/stack.go new file mode 100644 index 000000000..a85fb55eb --- /dev/null +++ b/callbacks/stack.go @@ -0,0 +1,126 @@ +package callbacks + +import ( + "context" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" +) + +type StackHandler struct { + handlers []Handler +} + +var _ Handler = (*StackHandler)(nil) + +// NewStackHandler creates a new stack handler with the given handlers. +// The handlers will be called in the order they are provided. +// +// Example: +// +// h := NewStackHandler( +// &SimpleHandler{}, +// &LogHandler{}, +// &MyCustomHandler{}, +// ) +// +// h.HandleText(ctx, "Hello, world!") +func NewStackHandler(handlers ...Handler) StackHandler { + return StackHandler{handlers: handlers} +} + +func (s *StackHandler) HandleText(ctx context.Context, text string) { + for _, h := range s.handlers { + h.HandleText(ctx, text) + } +} + +func (s *StackHandler) HandleLLMStart(ctx context.Context, prompts []string) { + for _, h := range s.handlers { + h.HandleLLMStart(ctx, prompts) + } +} + +func (s *StackHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) { + for _, h := range s.handlers { + h.HandleLLMGenerateContentStart(ctx, ms) + } +} + +func (s *StackHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) { + for _, h := range s.handlers { + h.HandleLLMGenerateContentEnd(ctx, res) + } +} + +func (s *StackHandler) HandleLLMError(ctx context.Context, err error) { + for _, h := range s.handlers { + h.HandleLLMError(ctx, err) + } +} + +func (s *StackHandler) HandleChainStart(ctx context.Context, inputs map[string]any) { + for _, h := range s.handlers { + h.HandleChainStart(ctx, inputs) + } +} + +func (s *StackHandler) HandleChainEnd(ctx context.Context, outputs map[string]any) { + for _, h := range s.handlers { + h.HandleChainEnd(ctx, outputs) + } +} + +func (s *StackHandler) HandleChainError(ctx context.Context, err error) { + for _, h := range s.handlers { + h.HandleChainError(ctx, err) + } +} + +func (s *StackHandler) HandleToolStart(ctx context.Context, input string) { + for _, h := range s.handlers { + h.HandleToolStart(ctx, input) + } +} + +func (s *StackHandler) HandleToolEnd(ctx context.Context, output string) { + for _, h := range s.handlers { + h.HandleToolEnd(ctx, output) + } +} + +func (s *StackHandler) HandleToolError(ctx context.Context, err error) { + for _, h := range s.handlers { + h.HandleToolError(ctx, err) + } +} + +func (s *StackHandler) HandleAgentAction(ctx context.Context, action schema.AgentAction) { + for _, h := range s.handlers { + h.HandleAgentAction(ctx, action) + } +} + +func (s *StackHandler) HandleAgentFinish(ctx context.Context, finish schema.AgentFinish) { + for _, h := range s.handlers { + h.HandleAgentFinish(ctx, finish) + } +} + +func (s *StackHandler) HandleRetrieverStart(ctx context.Context, query string) { + for _, h := range s.handlers { + h.HandleRetrieverStart(ctx, query) + } +} + +func (s *StackHandler) HandleRetrieverEnd(ctx context.Context, query string, documents []schema.Document) { + for _, h := range s.handlers { + h.HandleRetrieverEnd(ctx, query, documents) + } +} + +func (s *StackHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) { + for _, h := range s.handlers { + h.HandleStreamingFunc(ctx, chunk) + } +} diff --git a/callbacks/stack_test.go b/callbacks/stack_test.go new file mode 100644 index 000000000..18a433a7d --- /dev/null +++ b/callbacks/stack_test.go @@ -0,0 +1,136 @@ +package callbacks + +import ( + "context" + "fmt" + "testing" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" +) + +type MyCustomHandler struct { + name string + ch chan string +} + +var _ Handler = (*MyCustomHandler)(nil) + +func NewMyCustomHandler(name string, ch chan string) *MyCustomHandler { + return &MyCustomHandler{ + name: name, + ch: ch, + } +} + +func (m *MyCustomHandler) HandleText(context.Context, string) { + m.ch <- fmt.Sprintf("[HandleText] %s", m.name) +} + +func (m *MyCustomHandler) HandleLLMStart(context.Context, []string) { + m.ch <- fmt.Sprintf("[HandleLLMStart] %s", m.name) +} + +func (m *MyCustomHandler) HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) { + m.ch <- fmt.Sprintf("[HandleLLMGenerateContentStart] %s", m.name) +} + +func (m *MyCustomHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) { + m.ch <- fmt.Sprintf("[HandleLLMGenerateContentEnd] %s", m.name) +} + +func (m *MyCustomHandler) HandleLLMError(context.Context, error) { + m.ch <- fmt.Sprintf("[HandleLLMError] %s", m.name) +} + +func (m *MyCustomHandler) HandleChainStart(context.Context, map[string]any) { + m.ch <- fmt.Sprintf("[HandleChainStart] %s", m.name) +} + +func (m *MyCustomHandler) HandleChainEnd(context.Context, map[string]any) { + m.ch <- fmt.Sprintf("[HandleChainEnd] %s", m.name) +} + +func (m *MyCustomHandler) HandleChainError(context.Context, error) { + m.ch <- fmt.Sprintf("[HandleChainError] %s", m.name) +} + +func (m *MyCustomHandler) HandleToolStart(context.Context, string) { + m.ch <- fmt.Sprintf("[HandleToolStart] %s", m.name) +} + +func (m *MyCustomHandler) HandleToolEnd(context.Context, string) { + m.ch <- fmt.Sprintf("[HandleToolEnd] %s", m.name) +} + +func (m *MyCustomHandler) HandleToolError(context.Context, error) { + m.ch <- fmt.Sprintf("[HandleToolError] %s", m.name) +} + +func (m *MyCustomHandler) HandleAgentAction(context.Context, schema.AgentAction) { + m.ch <- fmt.Sprintf("[HandleAgentAction] %s", m.name) +} + +func (m *MyCustomHandler) HandleAgentFinish(ctx context.Context, finish schema.AgentFinish) { + m.ch <- fmt.Sprintf("[HandleAgentFinish] %s", m.name) +} + +func (m *MyCustomHandler) HandleRetrieverStart(ctx context.Context, query string) { + m.ch <- fmt.Sprintf("[HandleRetrieverStart] %s", m.name) +} + +func (m *MyCustomHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) { + m.ch <- fmt.Sprintf("[HandleRetrieverEnd] %s", m.name) +} + +func (m *MyCustomHandler) HandleStreamingFunc(context.Context, []byte) { + m.ch <- fmt.Sprintf("[HandleStreamingFunc] %s", m.name) +} + +func TestStackHandler(t *testing.T) { + ch := make(chan string, 2) + defer close(ch) + + h := NewStackHandler( + &SimpleHandler{}, + NewMyCustomHandler("my-custom-handler-1", ch), + NewMyCustomHandler("my-custom-handler-2", ch), + ) + + tests := []struct { + name string + fn func() + }{ + {"HandleText", func() { h.HandleText(context.Background(), "text") }}, + {"HandleLLMStart", func() { h.HandleLLMStart(context.Background(), []string{"prompt"}) }}, + {"HandleLLMGenerateContentStart", func() { h.HandleLLMGenerateContentStart(context.Background(), nil) }}, + {"HandleLLMGenerateContentEnd", func() { h.HandleLLMGenerateContentEnd(context.Background(), nil) }}, + {"HandleLLMError", func() { h.HandleLLMError(context.Background(), fmt.Errorf("error")) }}, + {"HandleChainStart", func() { h.HandleChainStart(context.Background(), map[string]any{"input": "value"}) }}, + {"HandleChainEnd", func() { h.HandleChainEnd(context.Background(), map[string]any{"output": "value"}) }}, + {"HandleChainError", func() { h.HandleChainError(context.Background(), fmt.Errorf("error")) }}, + {"HandleToolStart", func() { h.HandleToolStart(context.Background(), "input") }}, + {"HandleToolEnd", func() { h.HandleToolEnd(context.Background(), "output") }}, + {"HandleToolError", func() { h.HandleToolError(context.Background(), fmt.Errorf("error")) }}, + {"HandleAgentAction", func() { h.HandleAgentAction(context.Background(), schema.AgentAction{}) }}, + {"HandleAgentFinish", func() { h.HandleAgentFinish(context.Background(), schema.AgentFinish{}) }}, + {"HandleRetrieverStart", func() { h.HandleRetrieverStart(context.Background(), "query") }}, + {"HandleRetrieverEnd", func() { h.HandleRetrieverEnd(context.Background(), "query", nil) }}, + {"HandleStreamingFunc", func() { h.HandleStreamingFunc(context.Background(), nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.fn() + + for _, name := range []string{ + "my-custom-handler-1", + "my-custom-handler-2", + } { + if got := <-ch; got != fmt.Sprintf("[%s] %s", tt.name, name) { + t.Errorf("unexpected value: got %q", got) + } + } + }) + } +} From 5946324c14383b1a8082c25d9467d67062502b4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 23 Jul 2024 10:04:20 +0800 Subject: [PATCH 2/4] feat(callbacks): Add a `StackHandler` to support multiple Callbacks --- callbacks/stack_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/callbacks/stack_test.go b/callbacks/stack_test.go index 18a433a7d..9905336b9 100644 --- a/callbacks/stack_test.go +++ b/callbacks/stack_test.go @@ -35,7 +35,7 @@ func (m *MyCustomHandler) HandleLLMGenerateContentStart(context.Context, []llms. m.ch <- fmt.Sprintf("[HandleLLMGenerateContentStart] %s", m.name) } -func (m *MyCustomHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) { +func (m *MyCustomHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse) { m.ch <- fmt.Sprintf("[HandleLLMGenerateContentEnd] %s", m.name) } @@ -71,11 +71,11 @@ func (m *MyCustomHandler) HandleAgentAction(context.Context, schema.AgentAction) m.ch <- fmt.Sprintf("[HandleAgentAction] %s", m.name) } -func (m *MyCustomHandler) HandleAgentFinish(ctx context.Context, finish schema.AgentFinish) { +func (m *MyCustomHandler) HandleAgentFinish(context.Context, schema.AgentFinish) { m.ch <- fmt.Sprintf("[HandleAgentFinish] %s", m.name) } -func (m *MyCustomHandler) HandleRetrieverStart(ctx context.Context, query string) { +func (m *MyCustomHandler) HandleRetrieverStart(context.Context, string) { m.ch <- fmt.Sprintf("[HandleRetrieverStart] %s", m.name) } @@ -87,7 +87,7 @@ func (m *MyCustomHandler) HandleStreamingFunc(context.Context, []byte) { m.ch <- fmt.Sprintf("[HandleStreamingFunc] %s", m.name) } -func TestStackHandler(t *testing.T) { +func TestStackHandler(t *testing.T) { //nolint:paralleltest ch := make(chan string, 2) defer close(ch) @@ -119,6 +119,7 @@ func TestStackHandler(t *testing.T) { {"HandleStreamingFunc", func() { h.HandleStreamingFunc(context.Background(), nil) }}, } + //nolint:paralleltest for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.fn() From 16f44b6e2bfc7dbce0cc4fe262b135a0755c07a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 23 Jul 2024 10:06:05 +0800 Subject: [PATCH 3/4] feat(callbacks): Add a `StackHandler` to support multiple Callbacks --- callbacks/stack.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/stack.go b/callbacks/stack.go index a85fb55eb..4928842cb 100644 --- a/callbacks/stack.go +++ b/callbacks/stack.go @@ -20,8 +20,8 @@ var _ Handler = (*StackHandler)(nil) // // h := NewStackHandler( // &SimpleHandler{}, -// &LogHandler{}, -// &MyCustomHandler{}, +// &LogHandler{}, +// &MyCustomHandler{}, // ) // // h.HandleText(ctx, "Hello, world!") From 4923fbaf8445f52509aa2444e0d7bcd826b2c987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 23 Jul 2024 13:47:43 +0800 Subject: [PATCH 4/4] callbacks: supplement single test coverage --- callbacks/combining.go | 4 + .../{stack_test.go => combining_test.go} | 4 +- callbacks/stack.go | 126 ------------------ 3 files changed, 6 insertions(+), 128 deletions(-) rename callbacks/{stack_test.go => combining_test.go} (98%) delete mode 100644 callbacks/stack.go diff --git a/callbacks/combining.go b/callbacks/combining.go index 2e95e80aa..747d34689 100644 --- a/callbacks/combining.go +++ b/callbacks/combining.go @@ -14,6 +14,10 @@ type CombiningHandler struct { var _ Handler = CombiningHandler{} +func NewCombiningHandler(callbacks ...Handler) *CombiningHandler { + return &CombiningHandler{Callbacks: callbacks} +} + func (l CombiningHandler) HandleText(ctx context.Context, text string) { for _, handle := range l.Callbacks { handle.HandleText(ctx, text) diff --git a/callbacks/stack_test.go b/callbacks/combining_test.go similarity index 98% rename from callbacks/stack_test.go rename to callbacks/combining_test.go index 9905336b9..e3b615685 100644 --- a/callbacks/stack_test.go +++ b/callbacks/combining_test.go @@ -87,11 +87,11 @@ func (m *MyCustomHandler) HandleStreamingFunc(context.Context, []byte) { m.ch <- fmt.Sprintf("[HandleStreamingFunc] %s", m.name) } -func TestStackHandler(t *testing.T) { //nolint:paralleltest +func TestCombiningHandler(t *testing.T) { //nolint:paralleltest ch := make(chan string, 2) defer close(ch) - h := NewStackHandler( + h := NewCombiningHandler( &SimpleHandler{}, NewMyCustomHandler("my-custom-handler-1", ch), NewMyCustomHandler("my-custom-handler-2", ch), diff --git a/callbacks/stack.go b/callbacks/stack.go deleted file mode 100644 index 4928842cb..000000000 --- a/callbacks/stack.go +++ /dev/null @@ -1,126 +0,0 @@ -package callbacks - -import ( - "context" - - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/schema" -) - -type StackHandler struct { - handlers []Handler -} - -var _ Handler = (*StackHandler)(nil) - -// NewStackHandler creates a new stack handler with the given handlers. -// The handlers will be called in the order they are provided. -// -// Example: -// -// h := NewStackHandler( -// &SimpleHandler{}, -// &LogHandler{}, -// &MyCustomHandler{}, -// ) -// -// h.HandleText(ctx, "Hello, world!") -func NewStackHandler(handlers ...Handler) StackHandler { - return StackHandler{handlers: handlers} -} - -func (s *StackHandler) HandleText(ctx context.Context, text string) { - for _, h := range s.handlers { - h.HandleText(ctx, text) - } -} - -func (s *StackHandler) HandleLLMStart(ctx context.Context, prompts []string) { - for _, h := range s.handlers { - h.HandleLLMStart(ctx, prompts) - } -} - -func (s *StackHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) { - for _, h := range s.handlers { - h.HandleLLMGenerateContentStart(ctx, ms) - } -} - -func (s *StackHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) { - for _, h := range s.handlers { - h.HandleLLMGenerateContentEnd(ctx, res) - } -} - -func (s *StackHandler) HandleLLMError(ctx context.Context, err error) { - for _, h := range s.handlers { - h.HandleLLMError(ctx, err) - } -} - -func (s *StackHandler) HandleChainStart(ctx context.Context, inputs map[string]any) { - for _, h := range s.handlers { - h.HandleChainStart(ctx, inputs) - } -} - -func (s *StackHandler) HandleChainEnd(ctx context.Context, outputs map[string]any) { - for _, h := range s.handlers { - h.HandleChainEnd(ctx, outputs) - } -} - -func (s *StackHandler) HandleChainError(ctx context.Context, err error) { - for _, h := range s.handlers { - h.HandleChainError(ctx, err) - } -} - -func (s *StackHandler) HandleToolStart(ctx context.Context, input string) { - for _, h := range s.handlers { - h.HandleToolStart(ctx, input) - } -} - -func (s *StackHandler) HandleToolEnd(ctx context.Context, output string) { - for _, h := range s.handlers { - h.HandleToolEnd(ctx, output) - } -} - -func (s *StackHandler) HandleToolError(ctx context.Context, err error) { - for _, h := range s.handlers { - h.HandleToolError(ctx, err) - } -} - -func (s *StackHandler) HandleAgentAction(ctx context.Context, action schema.AgentAction) { - for _, h := range s.handlers { - h.HandleAgentAction(ctx, action) - } -} - -func (s *StackHandler) HandleAgentFinish(ctx context.Context, finish schema.AgentFinish) { - for _, h := range s.handlers { - h.HandleAgentFinish(ctx, finish) - } -} - -func (s *StackHandler) HandleRetrieverStart(ctx context.Context, query string) { - for _, h := range s.handlers { - h.HandleRetrieverStart(ctx, query) - } -} - -func (s *StackHandler) HandleRetrieverEnd(ctx context.Context, query string, documents []schema.Document) { - for _, h := range s.handlers { - h.HandleRetrieverEnd(ctx, query, documents) - } -} - -func (s *StackHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) { - for _, h := range s.handlers { - h.HandleStreamingFunc(ctx, chunk) - } -}