Skip to content

Commit

Permalink
Merge pull request #3 from xica/magellan_rag
Browse files Browse the repository at this point in the history
Implement Magellan RAG Interface
  • Loading branch information
mrkn committed Jun 24, 2024
2 parents 67f6a6d + 873c989 commit 8733031
Show file tree
Hide file tree
Showing 14 changed files with 382 additions and 17 deletions.
1 change: 1 addition & 0 deletions Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ gem "sidekiq", "~> 7"
gem "sinatra"
gem "sinatra-contrib", require: false
gem "slack-ruby-client"
gem "faraday"

group :development, :test do
# See https://guides.rubyonrails.org/debugging_rails_applications.html#debugging-with-the-debug-gem
Expand Down
1 change: 1 addition & 0 deletions Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ DEPENDENCIES
cssbundling-rails
database_rewinder
debug
faraday
jbuilder
jsbundling-rails
pg (~> 1.1)
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

* `ADMIN_USER` and `ADMIN_PASSWORD` is used for Basic Auth of Sidekiq's management console
* `ALLOW_CHANNEL_IDS` for specifying channel IDs where users can communicate with the chatbot
* `MAGELLAN_RAG_CHANNEL_IDS` for specifying channel IDs where users can query about MAGELLAN's past reports
* `MAGELLAN_RAG_ENDPOINT` for specifying the endpoint of the Magellan RAG API in the `schema://host:port` format
* `SLACK_BOT_TOKEN` is for the Slack Bot's access token
* `SLACK_SIGNING_SECRET` is for signing secret to check the request coming from Slack
* `OPENAI_ACCESS_TOKEN` is for OpenAI API Access Token
Expand Down
17 changes: 4 additions & 13 deletions app/jobs/chat_completion_job.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
require "slack_bot/utils"
require "utils"

class ChatCompletionJob < ApplicationJob
class ChatCompletionJob < SlackResponseJob
queue_as :default

VALID_MODELS = [
Expand Down Expand Up @@ -221,23 +221,14 @@ def perform(params)
end

private def start_query(message, name=REACTION_SYMBOL)
client = Slack::Web::Client.new
client.reactions_add(channel: message.conversation.slack_id, timestamp: message.slack_ts, name:)
rescue
nil
start_response(message, name)
end

private def finish_query(message, name=REACTION_SYMBOL)
client = Slack::Web::Client.new
client.reactions_remove(channel: message.conversation.slack_id, timestamp: message.slack_ts, name:)
rescue
nil
finish_response(message, name)
end

private def error_query(message, name=ERROR_REACTION_SYMBOL)
client = Slack::Web::Client.new
client.reactions_add(channel: message.conversation.slack_id, timestamp: message.slack_ts, name:)
rescue
nil
error_response(message, name)
end
end
147 changes: 147 additions & 0 deletions app/jobs/magellan_rag_query_job.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
class MagellanRagQueryJob < SlackResponseJob
queue_as :default

VALID_MODELS = [
# OpenAI API
"gpt-4o".freeze,
]

DEFAULT_MODEL = ENV.fetch("DEFAULT_MODEL", VALID_MODELS[0])

class InvalidOptionError < StandardError; end

Options = Struct.new(
:model,
keyword_init: true
) do
def initialize(**kw)
super
self.model ||= DEFAULT_MODEL
end

def validate!
validate_model! unless model.nil?
end

def validate_model!
MagellanRagQueryJob::VALID_MODELS.each do |valid_model|
case valid_model
when Regexp
return if valid_model.match?(model)
else
return if model == valid_model
end
end
raise InvalidOptionError, "Invalid model is specified: #{model}"
end
end

def model_for_message(message)
message.conversation.model || DEFAULT_MODEL
end

def perform(params)
if params["message_id"].blank?
logger.warn "Empty message_id is given"
return
end

message = Message.find(params["message_id"])
if message.blank?
logger.warn "Unable to find Message with id=#{message_id}"
return
end

if message.query.present?
logger.warn "Message with id=#{message.id} already has its query and response"
return
end

options = Options.new(**params.fetch("options", {}))

begin
start_response(message)
process_query(message, options)
ensure
finish_response(message)
end
rescue Exception => error
logger.error "ERROR: #{error.message}"
raise unless Rails.env.production?
end

private def process_query(message, options)
if message.slack_ts != message.slack_thread_ts
# NOTE: This job currently does not support queries in threads.
error_response("スレッドでの問い合わせには対応していません。")
return
end

model = options.model
query = Query.new(
message: message,
text: "[RAG QUERY] #{message.text}",
body: {
parameters: {
model: model
}
}
)

rag_response = Utils::MagellanRAG.generate_answer(message.text)
logger.info "RAG Response:\n" + rag_response.pretty_inspect.each_line.map {|l| "> #{l}" }.join("")

unless rag_response.key? "answer"
Util.post_message(
channel: message.conversation.slack_id,
thread_ts: message.slack_thread_ts,
text: ":#{ERROR_REACTION_SYMBOL}: *ERROR*: No answer key in the response from RAG: #{rag_response.inspect}",
mrkdwn: true
)
error_response(message)
return
end

answer = rag_response["answer"]
logger.info "RAG Answer:\n" + answer.each_line.map {|l| "> #{l}" }.join("")

response = Response.new(
query: query,
text: "[RAG ANSWER] #{answer}",
n_query_tokens: 0,
n_response_tokens: 0,
body: rag_response,
slack_thread_ts: message.slack_thread_ts
)

post_params = format_rag_response(
answer,
user: message.user
)

posted_message = Utils.post_message(
channel: message.conversation.slack_id,
thread_ts: message.slack_thread_ts,
**post_params
)
logger.info posted_message.inspect

unless posted_message.ok
error_response(message)
return
end

response.slack_ts = posted_message.ts
response.slack_thread_ts = message.slack_thread_ts

Query.transaction do
query.save!
response.save!
end
end

private def format_rag_response(answer, user:)
text = "<@#{user.slack_id}> #{answer}"
SlackBot.format_chat_gpt_response(text)
end
end
28 changes: 28 additions & 0 deletions app/jobs/slack_response_job.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
require "slack_bot/utils"

class SlackResponseJob < ApplicationJob
DEFAULT_REACTION_SYMBOL = "hourglass_flowing_sand".freeze
REACTION_SYMBOL = ENV.fetch("SLACK_REACTION_SYMBOL", DEFAULT_REACTION_SYMBOL)
ERROR_REACTION_SYMBOL = "bangbang".freeze

private def start_response(message, name=REACTION_SYMBOL)
client = Slack::Web::Client.new
client.reactions_add(channel: message.conversation.slack_id, timestamp: message.slack_ts, name:)
rescue
nil
end

private def finish_response(message, name=REACTION_SYMBOL)
client = Slack::Web::Client.new
client.reactions_remove(channel: message.conversation.slack_id, timestamp: message.slack_ts, name:)
rescue
nil
end

private def error_response(message, name=ERROR_REACTION_SYMBOL)
client = Slack::Web::Client.new
client.reactions_add(channel: message.conversation.slack_id, timestamp: message.slack_ts, name:)
rescue
nil
end
end
69 changes: 65 additions & 4 deletions lib/slack_bot/app.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Application < Sinatra::Base
end

ALLOW_CHANNEL_IDS = ENV.fetch("ALLOW_CHANNEL_IDS", "").split(/\s+|,\s*/)
MAGELLAN_RAG_CHANNEL_IDS = ENV.fetch("MAGELLAN_RAG_CHANNEL_IDS", "").split(/\s+|,\s*/)
MAGELLAN_RAG_ENDPOINT = ENV.fetch("MAGELLAN_RAG_ENDPOINT", "localhost:12345")

private def allowed_channel?(channel)
if ALLOW_CHANNEL_IDS.empty?
Expand All @@ -36,6 +38,14 @@ class Application < Sinatra::Base
end
end

private def magellan_rag_channel?(channel)
if MAGELLAN_RAG_CHANNEL_IDS.empty?
false
else
MAGELLAN_RAG_CHANNEL_IDS.include?(channel.slack_id)
end
end

private def thread_allowed_channel?(channel)
true # channel.thread_allowed?
end
Expand All @@ -61,6 +71,14 @@ class Application < Sinatra::Base
status 400
end

def mention?(event)
event["type"] == "app_mention"
end

def direct_message?(event)
event["type"] == "message" && event["channel_type"] == "im"
end

post "/events" do
request_data = JSON.parse(request.body.read)

Expand All @@ -71,7 +89,7 @@ class Application < Sinatra::Base
when "event_callback"
event = request_data["event"]

# Non-thread message:
# Non-thread message: {{{
# {"client_msg_id"=>"58e6675f-c164-451a-8354-7c7085ddba33",
# "type"=>"app_mention",
# "text"=>"<@U04U7QNHCD9> 以下を日本語に翻訳してください:\n" + "\n" + "Brands and...",
Expand All @@ -89,8 +107,9 @@ class Application < Sinatra::Base
# "team"=>"T036WLG7F",
# "channel"=>"C036WLG7Z",
# "event_ts"=>"1679639978.922569"}
# }}}

# Reply in thread:
# Reply in thread: {{{
# {"client_msg_id"=>"56846932-4600-4161-9947-a164084bf559",
# "type"=>"app_mention",
# "text"=>"<@U04U7QNHCD9> スレッド返信のテストです。",
Expand All @@ -109,15 +128,25 @@ class Application < Sinatra::Base
# "parent_user_id"=>"U04U7QNHCD9",
# "channel"=>"C036WLG7Z",
# "event_ts"=>"1679644228.326869"}
# }}}

if event["type"] == "app_mention" || (event["type"] == "message" && event["channel_type"] == "im")
if mention?(event) || direct_message?(event)
channel = ensure_conversation(event["channel"])
user = ensure_user(event["user"], channel)
ts = event["ts"]
thread_ts = event["thread_ts"]
text = event["text"]

if allowed_channel?(channel)
case
when magellan_rag_channel?(channel)
logger.info "Event:\n" + event.pretty_inspect.each_line.map {|l| "> #{l}" }.join("")
logger.info "#{channel.slack_id}: #{text}"
if thread_ts and not thread_allowed_channel?(channel)
notify_do_not_allowed_thread_context(channel, user, ts)
else
process_magellan_rag_message(channel, user, ts, thread_ts, text)
end
when allowed_channel?(channel)
logger.info "Event:\n" + event.pretty_inspect.each_line.map {|l| "> #{l}" }.join("")
logger.info "#{channel.slack_id}: #{text}"

Expand All @@ -137,6 +166,7 @@ class Application < Sinatra::Base
payload = JSON.parse(params["payload"])
case payload["type"]
when "block_actions"
# Example payload: {{{
# {"type"=>"block_actions",
# "user"=>
# {"id"=>"U02M703H8UD",
Expand Down Expand Up @@ -205,6 +235,7 @@ class Application < Sinatra::Base
# "style"=>"primary",
# "type"=>"button",
# "action_ts"=>"1680230940.375437"}]}
# }}}

feedback_value = payload.dig("actions", 0, "value")
response = Response.find_by!(slack_ts: payload["message"]["ts"])
Expand Down Expand Up @@ -318,6 +349,7 @@ class Application < Sinatra::Base
return unless text =~ /^<@#{bot_id}>\s+/

message_body = Regexp.last_match.post_match

options = process_options(message_body)
return if options.nil?

Expand Down Expand Up @@ -372,6 +404,35 @@ class Application < Sinatra::Base
options
end

private def process_magellan_rag_message(channel, user, ts, thread_ts, text)
return unless text =~ /^<@#{bot_id}>\s+/

message_body = Regexp.last_match.post_match
options = process_magellan_rag_options(message_body)
return if options.nil?

begin
options.validate!
rescue MagellanRagQeuryJob::InvalidOptionError => error
reply_as_ephemeral(channel, user, ts, error.message)
return
end

message = Message.create!(
conversation: channel,
user: user,
text: message_body,
slack_ts: ts,
slack_thread_ts: thread_ts || ts
)
MagellanRagQeuryJob.perform_later("message_id" => message.id, "options" => options.to_h)
end

private def process_magellan_rag_options(message_body)
# TODO: implement options
MagellanRagQeuryJob::Options.new
end

private def check_command_permission!(channel, user)
# TODO
end
Expand Down
2 changes: 2 additions & 0 deletions lib/utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ module Utils
messages
end
end

require_relative "utils/magellan_rag"
Loading

0 comments on commit 8733031

Please sign in to comment.