
Merge pull request #272 from uezo/support-llm-custom
Support custom request parameters and headers for LLM APIs
uezo committed Dec 31, 2023
2 parents 9c355a8 + 43ac59c commit 1d8c956
Showing 6 changed files with 120 additions and 9 deletions.
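
The new custom parameters and headers are read from the "StateData" entry of the payloads dictionary passed to GenerateContentAsync, under per-service keys (CustomParameterKey / CustomHeaderKey, e.g. "ChatGPTParameters" / "ChatGPTHeaders"). A minimal caller-side sketch follows; the key names come from this commit, while the surrounding setup, the values, and the messages variable are assumed for illustration:

// Hypothetical usage sketch. Only the keys "StateData", "ChatGPTParameters"
// and "ChatGPTHeaders" are taken from this commit; the values are illustrative.
var stateData = new Dictionary<string, object>()
{
    { "ChatGPTParameters", new Dictionary<string, string>() { { "user", "player-001" } } },
    { "ChatGPTHeaders", new Dictionary<string, string>() { { "OpenAI-Organization", "org-xxxxxxxx" } } }
};
var payloads = new Dictionary<string, object>() { { "StateData", stateData } };
// messages would normally come from MakePromptAsync
var session = await chatGPTService.GenerateContentAsync(messages, payloads, token: token);
// Each parameter is merged into the JSON request body; each header is set on
// the outgoing UnityWebRequest. Values are strings, so they serialize as JSON strings.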
39 changes: 37 additions & 2 deletions Scripts/LLM/ChatGPT/ChatGPTService.cs
@@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.ChatGPT
public class ChatGPTService : LLMServiceBase
{
public string HistoryKey = "ChatGPTHistories";
public string CustomParameterKey = "ChatGPTParameters";
public string CustomHeaderKey = "ChatGPTHeaders";

[Header("API configuration")]
public string ApiKey;
@@ -21,6 +23,11 @@ public class ChatGPTService : LLMServiceBase
public bool IsAzure;
public int MaxTokens = 0;
public float Temperature = 0.5f;
public float FrequencyPenalty = 0.0f;
public bool Logprobs = false; // Not available on gpt-4v
public int TopLogprobs = 0; // Set Logprobs to true to use TopLogprobs
public float PresencePenalty = 0.0f;
public List<string> Stop;

[Header("Network configuration")]
[SerializeField]
@@ -120,11 +127,18 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,

public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var chatGPTSession = new ChatGPTSession();
chatGPTSession.Contexts = messages;
chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, useFunctions, token);
chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, customParameters, customHeaders, useFunctions, token);
chatGPTSession.FunctionName = await WaitForFunctionName(chatGPTSession, token);

// Retry
if (chatGPTSession.ResponseType == ResponseType.Timeout)
{
if (retryCounter > 0)
@@ -143,16 +157,19 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return chatGPTSession;
}

public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Make request data
var data = new Dictionary<string, object>()
{
{ "model", Model },
{ "temperature", Temperature },
{ "messages", chatGPTSession.Contexts },
{ "frequency_penalty", FrequencyPenalty },
{ "presence_penalty", PresencePenalty },
{ "stream", true },
};

if (MaxTokens > 0)
{
data.Add("max_tokens", MaxTokens);
@@ -161,6 +178,19 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
{
data.Add("functions", llmTools);
}
if (Logprobs)
{
data.Add("logprobs", true);
data.Add("top_logprobs", TopLogprobs);
}
if (Stop != null && Stop.Count > 0)
{
data.Add("stop", Stop);
}
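// Merge custom parameters last so a matching key overrides the defaults above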
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// Prepare API request
using var streamRequest = new UnityWebRequest(
@@ -178,6 +208,11 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
streamRequest.SetRequestHeader("Authorization", "Bearer " + ApiKey);
}
streamRequest.SetRequestHeader("Content-Type", "application/json");
foreach (var h in customHeaders)
{
streamRequest.SetRequestHeader(h.Key, h.Value);
}

if (DebugMode)
{
Debug.Log($"Request to ChatGPT: {JsonConvert.SerializeObject(data)}");
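One detail of the merge loop above is easy to miss (a small illustration with hypothetical values, not code from the repository): the custom values arrive as Dictionary<string, string> and are assigned after the defaults, so a colliding key replaces the default value and is serialized as a JSON string.

// Hypothetical values, showing the override semantics of data[p.Key] = p.Value
var data = new Dictionary<string, object>() { { "temperature", 0.5f }, { "stream", true } };
var customParameters = new Dictionary<string, string>() { { "temperature", "0.9" } };
foreach (var p in customParameters)
{
    data[p.Key] = p.Value;
}
// data["temperature"] is now the string "0.9"; JsonConvert emits
// "temperature":"0.9" (a JSON string), not the number 0.9.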
23 changes: 22 additions & 1 deletion Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs
@@ -30,7 +30,7 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, ChatGPTSession> sessions { get; set; } = new Dictionary<string, ChatGPTSession>();

public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Add session for callback
var sessionId = Guid.NewGuid().ToString();
@@ -42,6 +42,8 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
{ "model", Model },
{ "temperature", Temperature },
{ "messages", chatGPTSession.Contexts },
{ "frequency_penalty", FrequencyPenalty },
{ "presence_penalty", PresencePenalty },
{ "stream", true },
};
if (MaxTokens > 0)
@@ -52,6 +54,25 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
{
data.Add("functions", llmTools);
}
if (Logprobs)
{
data.Add("logprobs", true);
data.Add("top_logprobs", TopLogprobs);
}
if (Stop != null && Stop.Count > 0)
{
data.Add("stop", Stop);
}
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// TODO: Support custom headers later...
if (customHeaders.Count > 0)
{
Debug.LogWarning("Custom headers for ChatGPT on WebGL are not supported yet.");
}

// Start API stream
isChatCompletionJSDone = false;
22 changes: 20 additions & 2 deletions Scripts/LLM/Claude/ClaudeService.cs
@@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.Claude
public class ClaudeService : LLMServiceBase
{
public string HistoryKey = "ClaudeHistories";
public string CustomParameterKey = "ClaudeParameters";
public string CustomHeaderKey = "ClaudeHeaders";

[Header("API configuration")]
public string ApiKey;
@@ -81,11 +83,18 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,

public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var claudeSession = new ClaudeSession();
claudeSession.Contexts = messages;
claudeSession.StreamingTask = StartStreamingAsync(claudeSession, useFunctions, token);
claudeSession.StreamingTask = StartStreamingAsync(claudeSession, customParameters, customHeaders, useFunctions, token);
await WaitForResponseType(claudeSession, token);

// Retry
if (claudeSession.ResponseType == ResponseType.Timeout)
{
if (retryCounter > 0)
@@ -104,7 +113,7 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return claudeSession;
}

public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Make request data
var data = new Dictionary<string, object>()
@@ -121,6 +130,10 @@ public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bo
{
data.Add("top_k", TopK);
}
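// Merge custom parameters last so a matching key overrides the defaults above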
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// Prepare API request
using var streamRequest = new UnityWebRequest(
@@ -132,6 +145,11 @@ public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bo
streamRequest.SetRequestHeader("anthropic-beta", "messages-2023-12-15");
streamRequest.SetRequestHeader("Content-Type", "application/json");
streamRequest.SetRequestHeader("x-api-key", ApiKey);
foreach (var h in customHeaders)
{
streamRequest.SetRequestHeader(h.Key, h.Value);
}

if (DebugMode)
{
Debug.Log($"Request to Claude: {JsonConvert.SerializeObject(data)}");
12 changes: 11 additions & 1 deletion Scripts/LLM/Claude/ClaudeServiceWebGL.cs
@@ -30,7 +30,7 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, ClaudeSession> sessions { get; set; } = new Dictionary<string, ClaudeSession>();

public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Add session for callback
var sessionId = Guid.NewGuid().ToString();
@@ -51,6 +51,16 @@ public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, b
{
data.Add("top_k", TopK);
}
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// TODO: Support custom headers later...
if (customHeaders.Count > 0)
{
Debug.LogWarning("Custom headers for Claude on WebGL are not supported yet.");
}

// Start API stream
isChatCompletionJSDone = false;
21 changes: 19 additions & 2 deletions Scripts/LLM/Gemini/GeminiService.cs
@@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.Gemini
public class GeminiService : LLMServiceBase
{
public string HistoryKey = "GeminiHistories";
public string CustomParameterKey = "GeminiParameters";
public string CustomHeaderKey = "GeminiHeaders";

[Header("API configuration")]
public string ApiKey;
@@ -117,9 +119,15 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,

public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var geminiSession = new GeminiSession();
geminiSession.Contexts = messages;
geminiSession.StreamingTask = StartStreamingAsync(geminiSession, useFunctions, token);
geminiSession.StreamingTask = StartStreamingAsync(geminiSession, customParameters, customHeaders, useFunctions, token);
await WaitForResponseType(geminiSession, token);

if (geminiSession.ResponseType == ResponseType.Timeout)
@@ -140,7 +148,7 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return geminiSession;
}

public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// GenerationConfig
var generationConfig = new GeminiGenerationConfig()
@@ -158,6 +166,10 @@ public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bo
{ "contents", geminiSession.Contexts },
{ "generationConfig", generationConfig }
};
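// Merge custom parameters into the top level of the request body (not into
// generationConfig); a matching key overrides the defaults above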
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// Set tools. Multimodal model doesn't support function calling for now (2023.12.29)
if (useFunctions && llmTools.Count > 0 && !Model.ToLower().Contains("vision"))
@@ -176,6 +188,11 @@ public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bo
);
streamRequest.timeout = responseTimeoutSec;
streamRequest.SetRequestHeader("Content-Type", "application/json");
foreach (var h in customHeaders)
{
streamRequest.SetRequestHeader(h.Key, h.Value);
}

if (DebugMode)
{
Debug.Log($"Request to Gemini: {JsonConvert.SerializeObject(data)}");
12 changes: 11 additions & 1 deletion Scripts/LLM/Gemini/GeminiServiceWebGL.cs
@@ -30,7 +30,7 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, GeminiSession> sessions { get; set; } = new Dictionary<string, GeminiSession>();

public override async UniTask StartStreamingAsync(GeminiSession geminiSession, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(GeminiSession geminiSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Add session for callback
var sessionId = Guid.NewGuid().ToString();
Expand All @@ -52,6 +52,16 @@ public override async UniTask StartStreamingAsync(GeminiSession geminiSession, b
{ "contents", geminiSession.Contexts },
{ "generationConfig", generationConfig }
};
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// TODO: Support custom headers later...
if (customHeaders.Count > 0)
{
Debug.LogWarning("Custom headers for Gemini on WebGL are not supported yet.");
}

// Set tools. Multimodal model doesn't support function calling for now (2023.12.29)
if (useFunctions && llmTools.Count > 0 && !Model.ToLower().Contains("vision"))
