diff --git a/Scripts/LLM/ChatGPT/ChatGPTService.cs b/Scripts/LLM/ChatGPT/ChatGPTService.cs index 851761a..0613e4a 100644 --- a/Scripts/LLM/ChatGPT/ChatGPTService.cs +++ b/Scripts/LLM/ChatGPT/ChatGPTService.cs @@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.ChatGPT public class ChatGPTService : LLMServiceBase { public string HistoryKey = "ChatGPTHistories"; + public string CustomParameterKey = "ChatGPTParameters"; + public string CustomHeaderKey = "ChatGPTHeaders"; [Header("API configuration")] public string ApiKey; @@ -21,6 +23,11 @@ public class ChatGPTService : LLMServiceBase public bool IsAzure; public int MaxTokens = 0; public float Temperature = 0.5f; + public float FrequencyPenalty = 0.0f; + public bool Logprobs = false; // Not available on gpt-4v + public int TopLogprobs = 0; // Set true to Logprobs to use TopLogprobs + public float PresencePenalty = 0.0f; + public List Stop; [Header("Network configuration")] [SerializeField] @@ -120,11 +127,18 @@ public override async UniTask> MakePromptAsync(string userId, public override async UniTask GenerateContentAsync(List messages, Dictionary payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default) { + // Custom parameters and headers + var stateData = (Dictionary)payloads["StateData"]; + var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary)stateData[CustomParameterKey] : new Dictionary(); + var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? 
(Dictionary)stateData[CustomHeaderKey] : new Dictionary(); + + // Start streaming session var chatGPTSession = new ChatGPTSession(); chatGPTSession.Contexts = messages; - chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, useFunctions, token); + chatGPTSession.StreamingTask = StartStreamingAsync(chatGPTSession, customParameters, customHeaders, useFunctions, token); chatGPTSession.FunctionName = await WaitForFunctionName(chatGPTSession, token); + // Retry if (chatGPTSession.ResponseType == ResponseType.Timeout) { if (retryCounter > 0) @@ -143,7 +157,7 @@ public override async UniTask GenerateContentAsync(List customParameters, Dictionary customHeaders, bool useFunctions = true, CancellationToken token = default) { // Make request data var data = new Dictionary() @@ -151,8 +165,11 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, { "model", Model }, { "temperature", Temperature }, { "messages", chatGPTSession.Contexts }, + { "frequency_penalty", FrequencyPenalty }, + { "presence_penalty", PresencePenalty }, { "stream", true }, }; + if (MaxTokens > 0) { data.Add("max_tokens", MaxTokens); @@ -161,6 +178,19 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, { data.Add("functions", llmTools); } + if (Logprobs == true) + { + data.Add("logprobs", true); + data.Add("top_logprobs", TopLogprobs); + } + if (Stop != null && Stop.Count > 0) + { + data.Add("stop", Stop); + } + foreach (var p in customParameters) + { + data[p.Key] = p.Value; + } // Prepare API request using var streamRequest = new UnityWebRequest( @@ -178,6 +208,11 @@ public virtual async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, streamRequest.SetRequestHeader("Authorization", "Bearer " + ApiKey); } streamRequest.SetRequestHeader("Content-Type", "application/json"); + foreach (var h in customHeaders) + { + streamRequest.SetRequestHeader(h.Key, h.Value); + } + if (DebugMode) { Debug.Log($"Request to ChatGPT: 
{JsonConvert.SerializeObject(data)}"); diff --git a/Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs b/Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs index a79fa1c..6810a65 100644 --- a/Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs +++ b/Scripts/LLM/ChatGPT/ChatGPTServiceWebGL.cs @@ -30,7 +30,7 @@ public override bool IsEnabled protected bool isChatCompletionJSDone { get; set; } = false; protected Dictionary sessions { get; set; } = new Dictionary(); - public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, bool useFunctions = true, CancellationToken token = default) + public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, Dictionary customParameters, Dictionary customHeaders, bool useFunctions = true, CancellationToken token = default) { // Add session for callback var sessionId = Guid.NewGuid().ToString(); @@ -42,6 +42,8 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, { "model", Model }, { "temperature", Temperature }, { "messages", chatGPTSession.Contexts }, + { "frequency_penalty", FrequencyPenalty }, + { "presence_penalty", PresencePenalty }, { "stream", true }, }; if (MaxTokens > 0) @@ -52,6 +54,25 @@ public override async UniTask StartStreamingAsync(ChatGPTSession chatGPTSession, { data.Add("functions", llmTools); } + if (Logprobs == true) + { + data.Add("logprobs", true); + data.Add("top_logprobs", TopLogprobs); + } + if (Stop != null && Stop.Count > 0) + { + data.Add("stop", Stop); + } + foreach (var p in customParameters) + { + data[p.Key] = p.Value; + } + + // TODO: Support custom headers later... 
+ if (customHeaders.Count > 0) + { + Debug.LogWarning("Custom headers for ChatGPT on WebGL is not supported for now."); + } // Start API stream isChatCompletionJSDone = false; diff --git a/Scripts/LLM/Claude/ClaudeService.cs b/Scripts/LLM/Claude/ClaudeService.cs index ae35ebe..f9be764 100644 --- a/Scripts/LLM/Claude/ClaudeService.cs +++ b/Scripts/LLM/Claude/ClaudeService.cs @@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.Claude public class ClaudeService : LLMServiceBase { public string HistoryKey = "ClaudeHistories"; + public string CustomParameterKey = "ClaudeParameters"; + public string CustomHeaderKey = "ClaudeHeaders"; [Header("API configuration")] public string ApiKey; @@ -81,11 +83,18 @@ public override async UniTask> MakePromptAsync(string userId, public override async UniTask GenerateContentAsync(List messages, Dictionary payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default) { + // Custom parameters and headers + var stateData = (Dictionary)payloads["StateData"]; + var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary)stateData[CustomParameterKey] : new Dictionary(); + var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? 
(Dictionary)stateData[CustomHeaderKey] : new Dictionary(); + + // Start streaming session var claudeSession = new ClaudeSession(); claudeSession.Contexts = messages; - claudeSession.StreamingTask = StartStreamingAsync(claudeSession, useFunctions, token); + claudeSession.StreamingTask = StartStreamingAsync(claudeSession, customParameters, customHeaders, useFunctions, token); await WaitForResponseType(claudeSession, token); + // Retry if (claudeSession.ResponseType == ResponseType.Timeout) { if (retryCounter > 0) @@ -104,7 +113,7 @@ public override async UniTask GenerateContentAsync(List customParameters, Dictionary customHeaders, bool useFunctions = true, CancellationToken token = default) { // Make request data var data = new Dictionary() @@ -121,6 +130,10 @@ public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bo { data.Add("top_k", TopK); } + foreach (var p in customParameters) + { + data[p.Key] = p.Value; + } // Prepare API request using var streamRequest = new UnityWebRequest( @@ -132,6 +145,11 @@ public virtual async UniTask StartStreamingAsync(ClaudeSession claudeSession, bo streamRequest.SetRequestHeader("anthropic-beta", "messages-2023-12-15"); streamRequest.SetRequestHeader("Content-Type", "application/json"); streamRequest.SetRequestHeader("x-api-key", ApiKey); + foreach (var h in customHeaders) + { + streamRequest.SetRequestHeader(h.Key, h.Value); + } + if (DebugMode) { Debug.Log($"Request to Claude: {JsonConvert.SerializeObject(data)}"); diff --git a/Scripts/LLM/Claude/ClaudeServiceWebGL.cs b/Scripts/LLM/Claude/ClaudeServiceWebGL.cs index 9775659..5b27e97 100644 --- a/Scripts/LLM/Claude/ClaudeServiceWebGL.cs +++ b/Scripts/LLM/Claude/ClaudeServiceWebGL.cs @@ -30,7 +30,7 @@ public override bool IsEnabled protected bool isChatCompletionJSDone { get; set; } = false; protected Dictionary sessions { get; set; } = new Dictionary(); - public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, bool useFunctions = 
true, CancellationToken token = default) + public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, Dictionary customParameters, Dictionary customHeaders, bool useFunctions = true, CancellationToken token = default) { // Add session for callback var sessionId = Guid.NewGuid().ToString(); @@ -51,6 +51,16 @@ public override async UniTask StartStreamingAsync(ClaudeSession claudeSession, b { data.Add("top_k", TopK); } + foreach (var p in customParameters) + { + data[p.Key] = p.Value; + } + + // TODO: Support custom headers later... + if (customHeaders.Count > 0) + { + Debug.LogWarning("Custom headers for Claude on WebGL is not supported for now."); + } // Start API stream isChatCompletionJSDone = false; diff --git a/Scripts/LLM/Gemini/GeminiService.cs b/Scripts/LLM/Gemini/GeminiService.cs index 416d79a..b4e0ea0 100644 --- a/Scripts/LLM/Gemini/GeminiService.cs +++ b/Scripts/LLM/Gemini/GeminiService.cs @@ -13,6 +13,8 @@ namespace ChatdollKit.LLM.Gemini public class GeminiService : LLMServiceBase { public string HistoryKey = "GeminiHistories"; + public string CustomParameterKey = "GeminiParameters"; + public string CustomHeaderKey = "GeminiHeaders"; [Header("API configuration")] public string ApiKey; @@ -117,9 +119,15 @@ public override async UniTask> MakePromptAsync(string userId, public override async UniTask GenerateContentAsync(List messages, Dictionary payloads, bool useFunctions = true, int retryCounter = 1, CancellationToken token = default) { + // Custom parameters and headers + var stateData = (Dictionary)payloads["StateData"]; + var customParameters = stateData.ContainsKey(CustomParameterKey) ? (Dictionary)stateData[CustomParameterKey] : new Dictionary(); + var customHeaders = stateData.ContainsKey(CustomHeaderKey) ? 
(Dictionary)stateData[CustomHeaderKey] : new Dictionary(); + + // Start streaming session var geminiSession = new GeminiSession(); geminiSession.Contexts = messages; - geminiSession.StreamingTask = StartStreamingAsync(geminiSession, useFunctions, token); + geminiSession.StreamingTask = StartStreamingAsync(geminiSession, customParameters, customHeaders, useFunctions, token); await WaitForResponseType(geminiSession, token); if (geminiSession.ResponseType == ResponseType.Timeout) @@ -140,7 +148,7 @@ public override async UniTask GenerateContentAsync(List customParameters, Dictionary customHeaders, bool useFunctions = true, CancellationToken token = default) { // GenerationConfig var generationConfig = new GeminiGenerationConfig() @@ -158,6 +166,10 @@ public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bo { "contents", geminiSession.Contexts }, { "generationConfig", generationConfig } }; + foreach (var p in customParameters) + { + data[p.Key] = p.Value; + } // Set tools. 
Multimodal model doesn't support function calling for now (2023.12.29) if (useFunctions && llmTools.Count > 0 && !Model.ToLower().Contains("vision")) @@ -176,6 +188,11 @@ public virtual async UniTask StartStreamingAsync(GeminiSession geminiSession, bo ); streamRequest.timeout = responseTimeoutSec; streamRequest.SetRequestHeader("Content-Type", "application/json"); + foreach (var h in customHeaders) + { + streamRequest.SetRequestHeader(h.Key, h.Value); + } + if (DebugMode) { Debug.Log($"Request to Gemini: {JsonConvert.SerializeObject(data)}"); diff --git a/Scripts/LLM/Gemini/GeminiServiceWebGL.cs b/Scripts/LLM/Gemini/GeminiServiceWebGL.cs index 3f0c8ba..c8c22b9 100644 --- a/Scripts/LLM/Gemini/GeminiServiceWebGL.cs +++ b/Scripts/LLM/Gemini/GeminiServiceWebGL.cs @@ -30,7 +30,7 @@ public override bool IsEnabled protected bool isChatCompletionJSDone { get; set; } = false; protected Dictionary sessions { get; set; } = new Dictionary(); - public override async UniTask StartStreamingAsync(GeminiSession geminiSession, bool useFunctions = true, CancellationToken token = default) + public override async UniTask StartStreamingAsync(GeminiSession geminiSession, Dictionary customParameters, Dictionary customHeaders, bool useFunctions = true, CancellationToken token = default) { // Add session for callback var sessionId = Guid.NewGuid().ToString(); @@ -52,6 +52,16 @@ public override async UniTask StartStreamingAsync(GeminiSession geminiSession, b { "contents", geminiSession.Contexts }, { "generationConfig", generationConfig } }; + foreach (var p in customParameters) + { + data[p.Key] = p.Value; + } + + // TODO: Support custom headers later... + if (customHeaders.Count > 0) + { + Debug.LogWarning("Custom headers for Gemini on WebGL is not supported for now."); + } // Set tools. Multimodal model doesn't support function calling for now (2023.12.29) if (useFunctions && llmTools.Count > 0 && !Model.ToLower().Contains("vision"))