llms: add watsonx #577

Merged: 23 commits, May 10, 2024

@@ -2,4 +2,11 @@
sidebar_label: Local
---

# Local
import CodeBlock from "@theme/CodeBlock";
import LocalExample from "@examples/local-llm-example/local_llm_example.go";

# Local

## Example

<CodeBlock language="go">{LocalExample}</CodeBlock>
15 changes: 9 additions & 6 deletions docs/docs/modules/model_io/models/llms/Integrations/openai.mdx
@@ -12,11 +12,14 @@ This example goes over how to use LangChain to interact with OpenAI models.

There are two options to set the OpenAI API key.

1. We can do this by setting the environment variable OPENAI_API_KEY to the api key.
1. We can do this by setting the environment variable `OPENAI_API_KEY` to the API key.

2. Or we can do it when initializing the wrapper along with other arguments.
```go
model, err := openai.New(openai.WithToken(apiToken))
```
2. Or we can do it when initializing the wrapper along with other arguments:

<CodeBlock language="go">{ExampleOpenAI}</CodeBlock>
```go
model, err := openai.New(openai.WithToken(apiToken))
```

## Example

<CodeBlock language="go">{ExampleOpenAI}</CodeBlock>
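
For reference (not part of this diff), a minimal sketch of the first option: rely on the `OPENAI_API_KEY` environment variable and construct the wrapper without a token option, assuming the wrapper falls back to the environment as the docs above state.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	// No token passed here; per the docs above, the wrapper reads
	// OPENAI_API_KEY from the environment.
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	completion, err := llms.GenerateFromSinglePrompt(context.Background(), llm, "Say hello in one word.")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}
```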
@@ -8,6 +8,6 @@ import ExampleVertexAICompletion from "@examples/vertex-completion-example/verte

# Vertex AI
To use the Vertex AI LLM you need to set the google project ID.
You can do this by setting the GOOGLE_CLOUD_PROJECT environment variable or giving it as a variadic option when creating the wrapper.
You can do this by setting the `GOOGLE_CLOUD_PROJECT` environment variable or giving it as a variadic option when creating the wrapper.

<CodeBlock language="go">{ExampleVertexAICompletion}</CodeBlock>
35 changes: 35 additions & 0 deletions docs/docs/modules/model_io/models/llms/Integrations/watsonx.mdx
@@ -0,0 +1,35 @@
---
sidebar_label: watsonx
---

import CodeBlock from "@theme/CodeBlock";
import WatsonxExample from "@examples/watsonx-llm-example/watsonx_example.go";

# watsonx

Integration support for [IBM watsonx](https://www.ibm.com/watsonx) foundation models.

## Setup

You will need to set the following environment variables to use the watsonx AI API.

- `IBMCLOUD_API_KEY`: generate from your [IBM Cloud account](https://cloud.ibm.com/iam/apikeys).
- `WATSONX_PROJECT_ID`: copy from your [watsonx project settings](https://dataplatform.cloud.ibm.com/projects/?context=wx).

Alternatively, these can be passed into the model on creation:

```go
import (
	wx "github.com/h0rv/go-watsonx/models"
	"github.com/tmc/langchaingo/llms/watsonx"
)
...
llm, _ := watsonx.New(
	wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"),
	wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"),
)
```

## Example

<CodeBlock language="go">{WatsonxExample}</CodeBlock>
16 changes: 16 additions & 0 deletions examples/watsonx-llm-example/go.mod
@@ -0,0 +1,16 @@
module github.com/tmc/langchaingo/examples/watsonx-llm-example

go 1.22.0

toolchain go1.22.2

require (
	github.com/h0rv/go-watsonx v0.2.1
	github.com/tmc/langchaingo v0.1.9-0.20240403145928-594021b91d0d
)

require (
	github.com/dlclark/regexp2 v1.10.0 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/pkoukk/tiktoken-go v0.1.6 // indirect
)
18 changes: 18 additions & 0 deletions examples/watsonx-llm-example/go.sum
@@ -0,0 +1,18 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/h0rv/go-watsonx v0.2.1 h1:m3NSenpQP3txjLMzFX32WeNS6MSTAz4vigob47rUCs4=
github.com/h0rv/go-watsonx v0.2.1/go.mod h1:QHED4UARKVpcbkzZWqfeskcfzkOqkRYepdnIYaHWxZw=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tmc/langchaingo v0.1.9-0.20240403145928-594021b91d0d h1:XgRR1R7zSaBvXmfZWEOKEZPg9dcrXEC6xCr3OzYzcHg=
github.com/tmc/langchaingo v0.1.9-0.20240403145928-594021b91d0d/go.mod h1:uKbZdnjKcNNmve7BD8OO441YhS4nQwhKtb0ze28DAdI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
40 changes: 40 additions & 0 deletions examples/watsonx-llm-example/watsonx_example.go
@@ -0,0 +1,40 @@
package main

import (
	"context"
	"fmt"
	"log"

	wx "github.com/h0rv/go-watsonx/models"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/watsonx"
)

func main() {
	llm, err := watsonx.New(
		// Optional parameters (read from the environment when not set here):
		// wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"),
		// wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"),
		wx.WithModel(wx.LLAMA_2_70B_CHAT),
	)
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// Generate a completion for a single prompt with custom sampling options.
	prompt := "What would be a good company name for a company that makes colorful socks?"
	completion, err := llms.GenerateFromSinglePrompt(
		ctx,
		llm,
		prompt,
		llms.WithTopK(10),
		llms.WithTopP(0.95),
		llms.WithSeed(13),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}
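
For reference (not part of this diff), the same model can also be driven through the `GenerateContent` method implemented in `llms/watsonx/watsonxllm.go` below; a minimal sketch, assuming the watsonx environment variables from the setup section are set:

```go
package main

import (
	"context"
	"fmt"
	"log"

	wx "github.com/h0rv/go-watsonx/models"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/watsonx"
)

func main() {
	llm, err := watsonx.New(wx.WithModel(wx.LLAMA_2_70B_CHAT))
	if err != nil {
		log.Fatal(err)
	}

	// A single human message; the wrapper currently uses only the first
	// text part of the first message (see getPrompt below).
	content := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeHuman, "Name three colors commonly used on socks."),
	}

	resp, err := llm.GenerateContent(context.Background(), content, llms.WithMaxTokens(256))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Content)
}
```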
1 change: 1 addition & 0 deletions go.mod
@@ -196,6 +196,7 @@ require (
	github.com/gocolly/colly v1.2.0
	github.com/google/generative-ai-go v0.11.0
	github.com/google/go-cmp v0.6.0
	github.com/h0rv/go-watsonx v0.2.1
	github.com/jackc/pgx/v5 v5.5.5
	github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80
	github.com/mattn/go-sqlite3 v1.14.17
2 changes: 2 additions & 0 deletions go.sum
@@ -380,6 +380,8 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaW
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 h1:RtRsiaGvWxcwd8y3BiRZxsylPT8hLWZ5SPcfI+3IDNk=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0/go.mod h1:TzP6duP4Py2pHLVPPQp42aoYI92+PCrVotyR5e8Vqlk=
github.com/h0rv/go-watsonx v0.2.1 h1:m3NSenpQP3txjLMzFX32WeNS6MSTAz4vigob47rUCs4=
github.com/h0rv/go-watsonx v0.2.1/go.mod h1:QHED4UARKVpcbkzZWqfeskcfzkOqkRYepdnIYaHWxZw=
github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
141 changes: 141 additions & 0 deletions llms/watsonx/watsonxllm.go
@@ -0,0 +1,141 @@
package watsonx

import (
	"context"
	"errors"

	wx "github.com/h0rv/go-watsonx/models"
	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/llms"
)

var (
	ErrInvalidPrompt = errors.New("invalid prompt")
	ErrEmptyResponse = errors.New("no response")
)

type LLM struct {
	CallbacksHandler callbacks.Handler
	client           *wx.Model
}

var _ llms.Model = (*LLM)(nil)

// Call implements the LLM interface.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	return llms.GenerateFromSinglePrompt(ctx, o, prompt, options...)
}

// GenerateContent implements the Model interface.
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace

	if o.CallbacksHandler != nil {
		o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
	}

	prompt, err := getPrompt(messages)
	if err != nil {
		return nil, err
	}

	result, err := o.client.GenerateText(
		prompt,
		toWatsonxOptions(&options)...,
	)
	if err != nil {
		if o.CallbacksHandler != nil {
			o.CallbacksHandler.HandleLLMError(ctx, err)
		}
		return nil, err
	}

	if result.Text == "" {
		return nil, ErrEmptyResponse
	}

	resp := &llms.ContentResponse{
		Choices: []*llms.ContentChoice{
			{
				Content: result.Text,
			},
		},
	}
	return resp, nil
}

func New(opts ...wx.ModelOption) (*LLM, error) {
	c, err := wx.NewModel(opts...)
	if err != nil {
		return nil, err
	}

	return &LLM{
		client: c,
	}, nil
}

func getPrompt(messages []llms.MessageContent) (string, error) {
	// Assume we get a single text message
	msg0 := messages[0]
	part := msg0.Parts[0]
	prompt, ok := part.(llms.TextContent)
	if !ok {
		return "", ErrInvalidPrompt
	}

	return prompt.Text, nil
}

func getDefaultCallOptions() *llms.CallOptions {
	return &llms.CallOptions{
		TopP:              -1,
		TopK:              -1,
		Temperature:       -1,
		Seed:              -1,
		RepetitionPenalty: -1,
		MaxTokens:         -1,
	}
}

func toWatsonxOptions(options *[]llms.CallOption) []wx.GenerateOption {
	opts := getDefaultCallOptions()
	for _, opt := range *options {
		opt(opts)
	}

	o := []wx.GenerateOption{}
	if opts.TopP != -1 {
		o = append(o, wx.WithTopP(opts.TopP))
	}
	if opts.TopK != -1 {
		o = append(o, wx.WithTopK(uint(opts.TopK)))
	}
	if opts.Temperature != -1 {
		o = append(o, wx.WithTemperature(opts.Temperature))
	}
	if opts.Seed != -1 {
		o = append(o, wx.WithRandomSeed(uint(opts.Seed)))
	}
	if opts.RepetitionPenalty != -1 {
		o = append(o, wx.WithRepetitionPenalty(opts.RepetitionPenalty))
	}
	if opts.MaxTokens != -1 {
		o = append(o, wx.WithMaxNewTokens(uint(opts.MaxTokens)))
	}
	if len(opts.StopWords) > 0 {
		o = append(o, wx.WithStopSequences(opts.StopWords))
	}

	/*
		watsonx options not supported:

		wx.WithMinNewTokens(minNewTokens)
		wx.WithDecodingMethod(decodingMethod)
		wx.WithLengthPenalty(decayFactor, startIndex)
		wx.WithTimeLimit(timeLimit)
		wx.WithTruncateInputTokens(truncateInputTokens)
		wx.WithReturnOptions(inputText, generatedTokens, inputTokens, tokenLogProbs, tokenRanks, topNTokens)
	*/

	return o
}