Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom request parameters and headers for LLM APIs #272

Merged
merged 3 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions Scripts/LLM/ChatGPT/ChatGPTService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.ChatGPT
public class ChatGPTService : LLMServiceBase
{
public string HistoryKey = "ChatGPTHistories";
public string CustomParameterKey = "ChatGPTParameters";
public string CustomHeaderKey = "ChatGPTHeaders";

[Header("API configuration")]
public string ApiKey;
Expand All @@ -21,6 +23,11 @@ public class ChatGPTService : LLMServiceBase
public bool IsAzure;
public int MaxTokens = 0;
public float Temperature = 0.5f;
public float FrequencyPenalty = 0.0f;
public bool Logprobs = false; // Not available on gpt-4v
public int TopLogprobs = 0; // Set true to Logprobs to use TopLogprobs
public float PresencePenalty = 0.0f;
public List<string> Stop;

[Header("Network configuration")]
[SerializeField]
Expand Down Expand Up @@ -120,11 +127,18 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,

public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var chatGPTSession = new ChatGPTSession();
chatGPTSession.Contexts = messages;
chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, useFunctions, token);
chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, customParameters, customHeaders, useFunctions, token);
chatGPTSession.FunctionName = await WaitForFunctionName(chatGPTSession, token);

// Retry
if (chatGPTSession.ResponseType == ResponseType.Timeout)
{
if (retryCounter > 0)
Expand All @@ -143,16 +157,19 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return chatGPTSession;
}

public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Make request data
var data = new Dictionary<string, object>()
{
{ "model", Model },
{ "temperature", Temperature },
{ "messages", chatGPTSession.Contexts },
{ "frequency_penalty", FrequencyPenalty },
{ "presence_penalty", PresencePenalty },
{ "stream", true },
};

if (MaxTokens > 0)
{
data.Add("max_tokens", MaxTokens);
Expand All @@ -161,6 +178,19 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
{
data.Add("functions", llmTools);
}
if (Logprobs == true)
{
data.Add("logprobs", true);
data.Add("top_logprobs", TopLogprobs);
}
if (Stop != null && Stop.Count > 0)
{
data.Add("stop", Stop);
}
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// Prepare API request
using var streamRequest = new UnityWebRequest(
Expand All @@ -178,6 +208,11 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
streamRequest.SetRequestHeader("Authorization", "Bearer " + ApiKey);
}
streamRequest.SetRequestHeader("Content-Type", "application/json");
foreach (var h in customHeaders)
{
streamRequest.SetRequestHeader(h.Key, h.Value);
}

if (DebugMode)
{
Debug.Log($"Request to ChatGPT: {JsonConvert.SerializeObject(data)}");
Expand Down
23 changes: 22 additions & 1 deletion Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, ChatGPTSession> sessions { get; set; } = new Dictionary<string, ChatGPTSession>();

public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Add session for callback
var sessionId = Guid.NewGuid().ToString();
Expand All @@ -42,6 +42,8 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
{ "model", Model },
{ "temperature", Temperature },
{ "messages", chatGPTSession.Contexts },
{ "frequency_penalty", FrequencyPenalty },
{ "presence_penalty", PresencePenalty },
{ "stream", true },
};
if (MaxTokens > 0)
Expand All @@ -52,6 +54,25 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession,
{
data.Add("functions", llmTools);
}
if (Logprobs == true)
{
data.Add("logprobs", true);
data.Add("top_logprobs", TopLogprobs);
}
if (Stop != null && Stop.Count > 0)
{
data.Add("stop", Stop);
}
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// TODO: Support custom headers later...
// Bug fix: `Count >= 0` is always true, so the unsupported-headers warning
// fired on every request, even when no custom headers were supplied.
// Warn only when the caller actually passed headers that will be dropped.
if (customHeaders.Count > 0)
{
    Debug.LogWarning("Custom headers for ChatGPT on WebGL is not supported for now.");
}

// Start API stream
isChatCompletionJSDone = false;
Expand Down
22 changes: 20 additions & 2 deletions Scripts/LLM/Claude/ClaudeService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.Claude
public class ClaudeService : LLMServiceBase
{
public string HistoryKey = "ClaudeHistories";
public string CustomParameterKey = "ClaudeParameters";
public string CustomHeaderKey = "ClaudeHeaders";

[Header("API configuration")]
public string ApiKey;
Expand Down Expand Up @@ -81,11 +83,18 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,

public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var claudeSession = new ClaudeSession();
claudeSession.Contexts = messages;
claudeSession.StreamingTask = StartStreamingAsync(claudeSession, useFunctions, token);
claudeSession.StreamingTask = StartStreamingAsync(claudeSession, customParameters, customHeaders, useFunctions, token);
await WaitForResponseType(claudeSession, token);

// Retry
if (claudeSession.ResponseType == ResponseType.Timeout)
{
if (retryCounter > 0)
Expand All @@ -104,7 +113,7 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return claudeSession;
}

public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Make request data
var data = new Dictionary<string, object>()
Expand All @@ -121,6 +130,10 @@ public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bo
{
data.Add("top_k", TopK);
}
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// Prepare API request
using var streamRequest = new UnityWebRequest(
Expand All @@ -132,6 +145,11 @@ public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bo
streamRequest.SetRequestHeader("anthropic-beta", "messages-2023-12-15");
streamRequest.SetRequestHeader("Content-Type", "application/json");
streamRequest.SetRequestHeader("x-api-key", ApiKey);
foreach (var h in customHeaders)
{
streamRequest.SetRequestHeader(h.Key, h.Value);
}

if (DebugMode)
{
Debug.Log($"Request to Claude: {JsonConvert.SerializeObject(data)}");
Expand Down
12 changes: 11 additions & 1 deletion Scripts/LLM/Claude/ClaudeServiceWebGL.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, ClaudeSession> sessions { get; set; } = new Dictionary<string, ClaudeSession>();

public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Add session for callback
var sessionId = Guid.NewGuid().ToString();
Expand All @@ -51,6 +51,16 @@ public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, b
{
data.Add("top_k", TopK);
}
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// TODO: Support custom headers later...
// Bug fix: `Count >= 0` is always true, so the unsupported-headers warning
// fired on every request, even with an empty header dictionary.
// Warn only when the caller actually passed headers that will be dropped.
if (customHeaders.Count > 0)
{
    Debug.LogWarning("Custom headers for Claude on WebGL is not supported for now.");
}

// Start API stream
isChatCompletionJSDone = false;
Expand Down
21 changes: 19 additions & 2 deletions Scripts/LLM/Gemini/GeminiService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.Gemini
public class GeminiService : LLMServiceBase
{
public string HistoryKey = "GeminiHistories";
public string CustomParameterKey = "GeminiParameters";
public string CustomHeaderKey = "GeminiHeaders";

[Header("API configuration")]
public string ApiKey;
Expand Down Expand Up @@ -117,9 +119,15 @@ public override async UniTask<List<ILLMMessage>> MakePromptAsync(string userId,

public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage> messages, Dictionary<string, object> payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default)
{
// Custom parameters and headers
var stateData = (Dictionary<string, object>)payloads["StateData"];
var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary<string, string>)stateData[CustomParameterKey] : new Dictionary<string, string>();
var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? (Dictionary<string, string>)stateData[CustomHeaderKey] : new Dictionary<string, string>();

// Start streaming session
var geminiSession = new GeminiSession();
geminiSession.Contexts = messages;
geminiSession.StreamingTask = StartStreamingAsync(geminiSession, useFunctions, token);
geminiSession.StreamingTask = StartStreamingAsync(geminiSession, customParameters, customHeaders, useFunctions, token);
await WaitForResponseType(geminiSession, token);

if (geminiSession.ResponseType == ResponseType.Timeout)
Expand All @@ -140,7 +148,7 @@ public override async UniTask<ILLMSession> GenerateContentAsync(List<ILLMMessage
return geminiSession;
}

public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bool useFunctions = true, CancellationToken token = default)
public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// GenerationConfig
var generationConfig = new GeminiGenerationConfig()
Expand All @@ -158,6 +166,10 @@ public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bo
{ "contents", geminiSession.Contexts },
{ "generationConfig", generationConfig }
};
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// Set tools. Multimodal model doesn't support function calling for now (2023.12.29)
if (useFunctions && llmTools.Count > 0 && !Model.ToLower().Contains("vision"))
Expand All @@ -176,6 +188,11 @@ public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bo
);
streamRequest.timeout = responseTimeoutSec;
streamRequest.SetRequestHeader("Content-Type", "application/json");
foreach (var h in customHeaders)
{
streamRequest.SetRequestHeader(h.Key, h.Value);
}

if (DebugMode)
{
Debug.Log($"Request to Gemini: {JsonConvert.SerializeObject(data)}");
Expand Down
12 changes: 11 additions & 1 deletion Scripts/LLM/Gemini/GeminiServiceWebGL.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public override bool IsEnabled
protected bool isChatCompletionJSDone { get; set; } = false;
protected Dictionary<string, GeminiSession> sessions { get; set; } = new Dictionary<string, GeminiSession>();

public override async UniTask StartStreamingAsync(GeminiSession geminiSession, bool useFunctions = true, CancellationToken token = default)
public override async UniTask StartStreamingAsync(GeminiSession geminiSession, Dictionary<string, string> customParameters, Dictionary<string, string> customHeaders, bool useFunctions = true, CancellationToken token = default)
{
// Add session for callback
var sessionId = Guid.NewGuid().ToString();
Expand All @@ -52,6 +52,16 @@ public override async UniTask StartStreamingAsync(GeminiSession geminiSession, b
{ "contents", geminiSession.Contexts },
{ "generationConfig", generationConfig }
};
foreach (var p in customParameters)
{
data[p.Key] = p.Value;
}

// TODO: Support custom headers later...
// Bug fix: `Count >= 0` is always true, so the unsupported-headers warning
// fired on every request, even with an empty header dictionary.
// Warn only when the caller actually passed headers that will be dropped.
if (customHeaders.Count > 0)
{
    Debug.LogWarning("Custom headers for Gemini on WebGL is not supported for now.");
}

// Set tools. Multimodal model doesn't support function calling for now (2023.12.29)
if (useFunctions && llmTools.Count > 0 && !Model.ToLower().Contains("vision"))
Expand Down