Feature/lambdahook bedrock #11

Merged · 6 commits · Oct 26, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.8] - 2023-10-05
## [0.1.8] - 2023-10-10
### Added
- Added Mistral 7b Instruct LLM - PR #10

5 changes: 5 additions & 0 deletions README.md
@@ -158,6 +158,11 @@ Optionally configure QnAbot to prompt the LLM directly by configuring the LLM Pl

When your Plugin CloudFormation stack status is CREATE_COMPLETE, choose the **Outputs** tab. Look for the outputs `QnAItemLambdaHookFunctionName` and `QnAItemLambdaHookArgs`. Use these values in the LambdaHook section of your no_hits item. You can change the value of "Prefix", or use "None" if you don't want to prefix the LLM answer.

The default behavior is to relay the user's query to the LLM as the prompt. If LLM_QUERY_GENERATION is enabled, the generated (disambiguated) query is used; otherwise the user's utterance is used. You can override this behavior by supplying an explicit `"Prompt"` key in the `QnAItemLambdaHookArgs` value. For example, setting `QnAItemLambdaHookArgs` to `{"Prefix": "LLM Answer:", "Model_params": {"modelId": "anthropic.claude-instant-v1", "temperature": 0}, "Prompt":"Why is the sky blue?"}` ignores the user's input and simply uses the configured prompt instead. Prompts supplied in this manner do not (yet) support variable substitution (e.g. substituting user attributes, session attributes, etc. into the prompt). If you feel that would be a useful feature, please create a feature request issue in the repo, or, better yet, implement it and submit a Pull Request!

Currently, the Lambda hook option is implemented only in the Bedrock and AI21 plugins.
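For reference, here is a minimal sketch of building the LambdaHook args JSON in Python, assuming the Bedrock plugin and the Claude Instant model used in the example above; the `"Prompt"` key is optional and overrides the user's query as described earlier.

```python
import json

# Sketch only: builds the LambdaHook args JSON described above.
lambdahook_args = {
    "Prefix": "LLM Answer:",  # or "None" to suppress the prefix
    "Model_params": {
        "modelId": "anthropic.claude-instant-v1",
        "temperature": 0,
    },
    # "Prompt": "Why is the sky blue?",  # optional: ignore the user's input and use this prompt
}

# Paste the resulting one-line JSON string into the Lambda Hook Args field of your no_hits item.
print(json.dumps(lambdahook_args))
```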


<img src="./images/qnaitem_lambdahook.png" alt="LambdaHook" width="600">

<img src="./images/qnaitem_lambdahook_example.png" alt="LambdaHook" width="600">
5 changes: 3 additions & 2 deletions lambdas/ai21-llm/src/lambdahook.py
@@ -91,9 +91,10 @@ def format_response(event, llm_response, prefix):

def lambda_handler(event, context):
    print("Received event: %s" % json.dumps(event))
    prompt = event["req"]["question"]
    # args = {"Prefix:"<Prefix|None>", "Model_params":{"max_tokens":256}}
    # args = {"Prefix":"<Prefix|None>", "Model_params":{"max_tokens":256}, "Prompt":"<prompt>"}
    args = get_args_from_lambdahook_args(event)
    # prompt set from args, or from req.question if not specified in args.
    prompt = args.get("Prompt", event["req"]["question"])
    model_params = args.get("Model_params",{})
    llm_response = get_llm_response(model_params, prompt)
    prefix = args.get("Prefix","LLM Answer:")
137 changes: 137 additions & 0 deletions lambdas/bedrock-embeddings-and-llm/src/lambdahook.py
@@ -0,0 +1,137 @@
import boto3
import json
import os

# Defaults
DEFAULT_MODEL_ID = os.environ.get("DEFAULT_MODEL_ID","anthropic.claude-instant-v1")
AWS_REGION = os.environ["AWS_REGION_OVERRIDE"] if "AWS_REGION_OVERRIDE" in os.environ else os.environ["AWS_REGION"]
ENDPOINT_URL = os.environ.get("ENDPOINT_URL", f'https://bedrock-runtime.{AWS_REGION}.amazonaws.com')
DEFAULT_MAX_TOKENS = 256

# global variables - avoid creating a new client for every request
client = None

def get_client():
    print("Connecting to Bedrock Service: ", ENDPOINT_URL)
    client = boto3.client(service_name='bedrock-runtime', region_name=AWS_REGION, endpoint_url=ENDPOINT_URL)
    return client

def get_request_body(modelId, parameters, prompt):
    provider = modelId.split(".")[0]
    request_body = None
    if provider == "anthropic":
        request_body = {
            "prompt": prompt,
            "max_tokens_to_sample": DEFAULT_MAX_TOKENS
        }
        request_body.update(parameters)
    elif provider == "ai21":
        request_body = {
            "prompt": prompt,
            "maxTokens": DEFAULT_MAX_TOKENS
        }
        request_body.update(parameters)
    elif provider == "amazon":
        textGenerationConfig = {
            "maxTokenCount": DEFAULT_MAX_TOKENS
        }
        textGenerationConfig.update(parameters)
        request_body = {
            "inputText": prompt,
            "textGenerationConfig": textGenerationConfig
        }
    else:
        raise Exception("Unsupported provider: ", provider)
    return request_body

def get_generate_text(modelId, response):
    provider = modelId.split(".")[0]
    generated_text = None
    if provider == "anthropic":
        response_body = json.loads(response.get("body").read().decode())
        generated_text = response_body.get("completion")
    elif provider == "ai21":
        response_body = json.loads(response.get("body").read())
        generated_text = response_body.get("completions")[0].get("data").get("text")
    elif provider == "amazon":
        response_body = json.loads(response.get("body").read())
        generated_text = response_body.get("results")[0].get("outputText")
    else:
        raise Exception("Unsupported provider: ", provider)
    return generated_text


def format_prompt(modelId, prompt):
    # TODO - replace prompt template placeholders - eg query, input, chatHistory, session attributes, user info
    provider = modelId.split(".")[0]
    if provider == "anthropic":
        print("Model provider is Anthropic. Checking prompt format.")
        if not prompt.startswith("\n\nHuman:"):
            prompt = "\n\nHuman: " + prompt
            print("Prepended '\\n\\nHuman:'")
        if not prompt.endswith("\n\nAssistant:"):
            prompt = prompt + "\n\nAssistant:"
            print("Appended '\\n\\nAssistant:'")
    print(f"Prompt: {json.dumps(prompt)}")
    return prompt

def get_llm_response(parameters, prompt):
    global client
    modelId = parameters.pop("modelId", DEFAULT_MODEL_ID)
    prompt = format_prompt(modelId, prompt)
    body = get_request_body(modelId, parameters, prompt)
    print("ModelId", modelId, "- Body: ", body)
    if (client is None):
        client = get_client()
    response = client.invoke_model(body=json.dumps(body), modelId=modelId, accept='application/json', contentType='application/json')
    generated_text = get_generate_text(modelId, response)
    return generated_text

def get_args_from_lambdahook_args(event):
    parameters = {}
    lambdahook_args_list = event["res"]["result"].get("args",[])
    print("LambdaHook args: ", lambdahook_args_list)
    if len(lambdahook_args_list):
        try:
            parameters = json.loads(lambdahook_args_list[0])
        except Exception as e:
            print("Failed to parse JSON:", lambdahook_args_list[0], e)
            print("..continuing")
    return parameters

def format_response(event, llm_response, prefix):
    # set plaintext, markdown, & ssml response
    if prefix in ["None", "N/A", "Empty"]:
        prefix = None
    plaintext = llm_response
    markdown = llm_response
    ssml = llm_response
    if prefix:
        plaintext = f"{prefix}\n\n{plaintext}"
        markdown = f"**{prefix}**\n\n{markdown}"
    # add plaintext, markdown, and ssml fields to event.res
    event["res"]["message"] = plaintext
    event["res"]["session"]["appContext"] = {
        "altMessages": {
            "markdown": markdown,
            "ssml": ssml
        }
    }
    #TODO - can we determine when LLM has a good answer or not?
    #For now, always assume it's a good answer.
    #QnAbot sets session attribute qnabot_gotanswer True when got_hits > 0
    event["res"]["got_hits"] = 1
    return event

def lambda_handler(event, context):
    print("Received event: %s" % json.dumps(event))
    # args = {"Prefix":"<Prefix|None>", "Model_params":{"modelId":"anthropic.claude-instant-v1", "max_tokens":256}, "Prompt":"<prompt>"}
    args = get_args_from_lambdahook_args(event)
    # prompt set from args, or from req.question if not specified in args.
    prompt = args.get("Prompt", event["req"]["question"])
    model_params = args.get("Model_params",{})
    llm_response = get_llm_response(model_params, prompt)
    prefix = args.get("Prefix","LLM Answer:")
    event = format_response(event, llm_response, prefix)
    print("Returning response: %s" % json.dumps(event))
    return event
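For local testing, here is a minimal sketch that invokes the handler above with a QnAbot-style event. The event shape is inferred from the handler code (req.question, res.result.args, res.session); running it requires AWS credentials with Bedrock model access and a boto3 version that includes the `bedrock-runtime` client, and the region value shown is an assumption.

```python
import json
import os

os.environ.setdefault("AWS_REGION", "us-east-1")  # the module reads AWS_REGION at import time; region is an assumption
from lambdahook import lambda_handler  # assumes the file above is saved as lambdahook.py

# Minimal QnAbot-style test event (shape inferred from the handler above).
test_event = {
    "req": {"question": "Why is the sky blue?"},
    "res": {
        "result": {
            "args": [json.dumps({
                "Prefix": "LLM Answer:",
                "Model_params": {"modelId": "anthropic.claude-instant-v1", "temperature": 0}
            })]
        },
        "session": {}
    }
}

response = lambda_handler(test_event, None)
print(response["res"]["message"])
```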
4 changes: 2 additions & 2 deletions lambdas/bedrock-embeddings-and-llm/src/llm.py
@@ -3,7 +3,7 @@
import os

# Defaults
DEFAULT_MODEL_ID = os.environ.get("DEFAULT_MODEL_ID","amazon.titan-text-express-v1")
DEFAULT_MODEL_ID = os.environ.get("DEFAULT_MODEL_ID","anthropic.claude-instant-v1")
AWS_REGION = os.environ["AWS_REGION_OVERRIDE"] if "AWS_REGION_OVERRIDE" in os.environ else os.environ["AWS_REGION"]
ENDPOINT_URL = os.environ.get("ENDPOINT_URL", f'https://bedrock-runtime.{AWS_REGION}.amazonaws.com')
DEFAULT_MAX_TOKENS = 256
@@ -75,7 +75,7 @@ def call_llm(parameters, prompt):
"""
Example Test Event:
{
"prompt": "Human:Why is the sky blue?\nAssistant:",
"prompt": "\n\nHuman:Why is the sky blue?\n\nAssistant:",
"parameters": {
"modelId": "anthropic.claude-v1",
"temperature": 0
4 changes: 3 additions & 1 deletion lambdas/bedrock-embeddings-and-llm/src/settings.py
@@ -29,9 +29,11 @@ def getModelSettings(modelId):
        "modelId": modelId,
        "temperature": 0
    }
    lambdahook_args = {"Prefix":"LLM Answer:", "Model_params": params}
    settings = {
        'LLM_GENERATE_QUERY_MODEL_PARAMS': json.dumps(params),
        'LLM_QA_MODEL_PARAMS': json.dumps(params)
        'LLM_QA_MODEL_PARAMS': json.dumps(params),
        'QNAITEM_LAMBDAHOOK_ARGS': json.dumps(lambdahook_args)
    }
    provider = modelId.split(".")[0]
    if provider == "anthropic":
35 changes: 32 additions & 3 deletions lambdas/bedrock-embeddings-and-llm/template.yml
@@ -12,7 +12,7 @@ Parameters:

  LLMModelId:
    Type: String
    Default: amazon.titan-text-express-v1
    Default: anthropic.claude-instant-v1
    AllowedValues:
      - amazon.titan-text-express-v1
      - ai21.j2-ultra-v1
@@ -230,6 +230,27 @@ Resources:
          - id: W92
            reason: No requirements to set reserved concurrencies, function will not be invoked simultaneously.

  QnaItemLambdaHookFunction:
    Type: AWS::Lambda::Function
    Properties:
      # LambdaHook name must start with 'QNA-' to match QnAbot invoke policy
      FunctionName: !Sub "QNA-LAMBDAHOOK-${AWS::StackName}"
      Handler: "lambdahook.lambda_handler"
      Role: !GetAtt 'LambdaFunctionRole.Arn'
      Runtime: python3.11
      Timeout: 60
      MemorySize: 128
      Layers:
        - !Ref BedrockBoto3Layer
      Code: ./src
    Metadata:
      cfn_nag:
        rules_to_suppress:
          - id: W89
            reason: Lambda function is not communicating with any VPC resources.
          - id: W92
            reason: No requirements to set reserved concurrencies, function will not be invoked simultaneously.

  OutputSettingsFunctionRole:
    Type: AWS::IAM::Role
    Properties:
@@ -266,7 +287,7 @@ Resources:
      ServiceToken: !GetAtt OutputSettingsFunction.Arn
      EmbeddingsModelId: !Ref EmbeddingsModelId
      LLMModelId: !Ref LLMModelId
      LastUpdate: '10/02/2023'
      LastUpdate: '10/25/2023'

  TestBedrockModelFunction:
    Type: AWS::Lambda::Function
@@ -338,4 +359,12 @@ Outputs:

  QnABotSettingQAModelParams:
    Description: QnABot Designer Setting "LLM_QA_MODEL_PARAMS"
    Value: !GetAtt OutputSettings.LLM_QA_MODEL_PARAMS
    Value: !GetAtt OutputSettings.LLM_QA_MODEL_PARAMS

  QnAItemLambdaHookFunctionName:
    Description: QnA Item Lambda Hook Function Name (use with no_hits item for optional ask-the-LLM fallback)
    Value: !Ref QnaItemLambdaHookFunction

  QnAItemLambdaHookArgs:
    Description: QnA Item Lambda Hook Args (use with no_hits item for optional ask-the-LLM fallback)
    Value: !GetAtt OutputSettings.QNAITEM_LAMBDAHOOK_ARGS