Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server security with API key and SSL #238

Merged
merged 7 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions Editor/LLMEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,47 @@ protected override Type[] GetPropertyTypes()
return new Type[] { typeof(LLM) };
}

public void AddSecuritySettings(SerializedObject llmScriptSO, LLM llmScript)
{
void AddSSLLoad(string type, Callback<string> setterCallback)
{
if (GUILayout.Button("Load SSL " + type, GUILayout.Width(buttonWidth)))
{
EditorApplication.delayCall += () =>
{
string path = EditorUtility.OpenFilePanel("Select a SSL " + type + " file", "", "");
if (!string.IsNullOrEmpty(path)) setterCallback(path);
};
}
}

void AddSSLInfo(string propertyName, string type, Callback<string> setterCallback)
{
string path = llmScriptSO.FindProperty(propertyName).stringValue;
if (path != "")
{
EditorGUILayout.BeginHorizontal();
EditorGUILayout.LabelField("SSL " + type + " path", path);
if (GUILayout.Button(trashIcon, GUILayout.Height(actionColumnWidth), GUILayout.Width(actionColumnWidth))) setterCallback("");
EditorGUILayout.EndHorizontal();
}
}

EditorGUILayout.LabelField("Server Security Settings", EditorStyles.boldLabel);
EditorGUILayout.PropertyField(llmScriptSO.FindProperty("APIKey"));

if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
EditorGUILayout.BeginHorizontal();
AddSSLLoad("certificate", llmScript.SetSSLCert);
AddSSLLoad("key", llmScript.SetSSLKey);
EditorGUILayout.EndHorizontal();
AddSSLInfo("SSLCertPath", "certificate", llmScript.SetSSLCert);
AddSSLInfo("SSLKeyPath", "key", llmScript.SetSSLKey);
}
Space();
}

public void AddModelLoadersSettings(SerializedObject llmScriptSO, LLM llmScript)
{
EditorGUILayout.LabelField("Model Settings", EditorStyles.boldLabel);
Expand Down Expand Up @@ -422,6 +463,7 @@ public override void OnInspectorGUI()

AddOptionsToggles(llmScriptSO);
AddSetupSettings(llmScriptSO);
if (llmScriptSO.FindProperty("remote").boolValue) AddSecuritySettings(llmScriptSO, llmScript);
AddModelLoadersSettings(llmScriptSO, llmScript);
AddChatSettings(llmScriptSO);

Expand Down
68 changes: 66 additions & 2 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ public class LLM : MonoBehaviour
/// <summary> enable use of flash attention </summary>
[ModelExtras] public bool flashAttention = false;

/// <summary> API key to use for the server (optional) </summary>
public string APIKey;
// SSL certificate
[SerializeField]
private string SSLCert = "";
public string SSLCertPath = "";
// SSL key
[SerializeField]
private string SSLKey = "";
public string SSLKeyPath = "";

/// \cond HIDE

IntPtr LLMObject = IntPtr.Zero;
Expand Down Expand Up @@ -302,6 +313,41 @@ public void SetTemplate(string templateName, bool setDirty = true)
#endif
}

/// \cond HIDE

string ReadFileContents(string path)
{
if (String.IsNullOrEmpty(path)) return "";
else if (!File.Exists(path))
{
LLMUnitySetup.LogError($"File {path} not found!");
return "";
}
return File.ReadAllText(path);
}

/// \endcond

/// <summary>
/// Use a SSL certificate for the LLM server.
/// </summary>
/// <param name="templateName">the SSL certificate path </param>
public void SetSSLCert(string path)
{
SSLCertPath = path;
SSLCert = ReadFileContents(path);
}

/// <summary>
/// Use a SSL key for the LLM server.
/// </summary>
/// <param name="templateName">the SSL key path </param>
public void SetSSLKey(string path)
{
SSLKeyPath = path;
SSLKey = ReadFileContents(path);
}

/// <summary>
/// Returns the chat template of the LLM.
/// </summary>
Expand All @@ -314,6 +360,12 @@ public string GetTemplate()
protected virtual string GetLlamaccpArguments()
{
// Start the LLM server in a cross-platform way
if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != ""))
{
LLMUnitySetup.LogError($"Both SSL certificate and key need to be provided!");
return null;
}

if (model == "")
{
LLMUnitySetup.LogError("No model file provided!");
Expand Down Expand Up @@ -344,7 +396,11 @@ protected virtual string GetLlamaccpArguments()

int slots = GetNumClients();
string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
if (remote) arguments += $" --port {port} --host 0.0.0.0";
if (remote)
{
arguments += $" --port {port} --host 0.0.0.0";
if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}";
}
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
Expand Down Expand Up @@ -427,7 +483,15 @@ private void InitService(string arguments)
if (debug) CallWithLock(SetupLogging);
CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
if (remote)
{
if (SSLCert != "" && SSLKey != "")
{
LLMUnitySetup.Log("Using SSL");
CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey));
}
CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
}
CallWithLock(() => CheckLLMStatus(false));
}

Expand Down
37 changes: 32 additions & 5 deletions Runtime/LLMCharacter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public class LLMCharacter : MonoBehaviour
[Remote] public int port = 13333;
/// <summary> number of retries to use for the LLM server requests (-1 = infinite) </summary>
[Remote] public int numRetries = -1;
/// <summary> allows to use a server with API key </summary>
[Remote] public string APIKey;
/// <summary> file to save the chat history.
/// The file is saved only for Chat calls with addToHistory set to true.
/// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). </summary>
Expand All @@ -38,6 +40,8 @@ public class LLMCharacter : MonoBehaviour
[LLM] public bool saveCache = false;
/// <summary> select to log the constructed prompt the Unity Editor. </summary>
[LLM] public bool debugPrompt = false;
/// <summary> allows to bypass certificate checks (note: only for development!). </summary>
[LLMAdvanced] public bool bypassCertificate = false;
/// <summary> option to receive the reply from the model as it is produced (recommended!).
/// If it is not selected, the full reply from the model is received in one go </summary>
[Model] public bool stream = true;
Expand Down Expand Up @@ -125,8 +129,9 @@ public class LLMCharacter : MonoBehaviour
private string chatTemplate;
private ChatTemplate template = null;
public string grammarString;
private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
private List<(string, string)> requestHeaders;
private List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
bool prebypassCertificate = false;
/// \endcond

/// <summary>
Expand All @@ -142,6 +147,8 @@ public void Awake()
{
// Start the LLM server in a cross-platform way
if (!enabled) return;

requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
if (!remote)
{
AssignLLM();
Expand All @@ -153,6 +160,10 @@ public void Awake()
int slotFromServer = llm.Register(this);
if (slot == -1) slot = slotFromServer;
}
else
{
if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey));
}

InitGrammar();
InitHistory();
Expand All @@ -162,6 +173,7 @@ void OnValidate()
{
AssignLLM();
if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set");
if (prebypassCertificate != bypassCertificate && bypassCertificate) LLMUnitySetup.LogWarning("All of the certificates will be accepted. Use only for development!");
}

void Reset()
Expand Down Expand Up @@ -286,14 +298,17 @@ private bool CheckTemplate()
return true;
}

private async Task InitNKeep()
private async Task<bool> InitNKeep()
{
if (setNKeepToPrompt && nKeep == -1)
{
if (!CheckTemplate()) return;
if (!CheckTemplate()) return false;
string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
await Tokenize(systemPrompt, SetNKeep);
List<int> tokens = await Tokenize(systemPrompt);
if (tokens == null) return false;
SetNKeep(tokens);
}
return true;
}

private void InitGrammar()
Expand Down Expand Up @@ -485,7 +500,7 @@ public async Task<string> Chat(string query, Callback<string> callback = null, E
// call the completionCallback function when the answer is fully received
await LoadTemplate();
if (!CheckTemplate()) return null;
await InitNKeep();
if (!await InitNKeep()) return null;

string json;
await chatLock.WaitAsync();
Expand Down Expand Up @@ -781,6 +796,7 @@ protected async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoi
WIPRequests.Add(request);

request.method = "POST";
if (bypassCertificate) request.certificateHandler = new BypassCertificateHandler();
if (requestHeaders != null)
{
for (int i = 0; i < requestHeaders.Count; i++)
Expand Down Expand Up @@ -814,6 +830,7 @@ protected async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoi
{
result = default;
error = request.error;
if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break;
}
}
tryNr--;
Expand All @@ -837,5 +854,15 @@ public class ChatListWrapper
{
public List<ChatMessage> chat;
}

// Custom certificate handler that bypasses validation
class BypassCertificateHandler : CertificateHandler
{
protected override bool ValidateCertificate(byte[] certificateData)
{
// Always accept the certificate
return true;
}
}
/// \endcond
}
3 changes: 3 additions & 0 deletions Runtime/LLMLib.cs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ public LLMLib(string arch)
LLM_Started = LibraryLoader.GetSymbolDelegate<LLM_StartedDelegate>(libraryHandle, "LLM_Started");
LLM_Stop = LibraryLoader.GetSymbolDelegate<LLM_StopDelegate>(libraryHandle, "LLM_Stop");
LLM_SetTemplate = LibraryLoader.GetSymbolDelegate<LLM_SetTemplateDelegate>(libraryHandle, "LLM_SetTemplate");
LLM_SetSSL = LibraryLoader.GetSymbolDelegate<LLM_SetSSLDelegate>(libraryHandle, "LLM_SetSSL");
LLM_Tokenize = LibraryLoader.GetSymbolDelegate<LLM_TokenizeDelegate>(libraryHandle, "LLM_Tokenize");
LLM_Detokenize = LibraryLoader.GetSymbolDelegate<LLM_DetokenizeDelegate>(libraryHandle, "LLM_Detokenize");
LLM_Embeddings = LibraryLoader.GetSymbolDelegate<LLM_EmbeddingsDelegate>(libraryHandle, "LLM_Embeddings");
Expand Down Expand Up @@ -479,6 +480,7 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public delegate bool LLM_StartedDelegate(IntPtr LLMObject);
public delegate void LLM_StopDelegate(IntPtr LLMObject);
public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
public delegate void LLM_SetSSLDelegate(IntPtr LLMObject, string SSLCert, string SSLKey);
public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
Expand All @@ -503,6 +505,7 @@ public string GetStringWrapperResult(IntPtr stringWrapper)
public LLM_StartedDelegate LLM_Started;
public LLM_StopDelegate LLM_Stop;
public LLM_SetTemplateDelegate LLM_SetTemplate;
public LLM_SetSSLDelegate LLM_SetSSL;
public LLM_TokenizeDelegate LLM_Tokenize;
public LLM_DetokenizeDelegate LLM_Detokenize;
public LLM_CompletionDelegate LLM_Completion;
Expand Down
2 changes: 1 addition & 1 deletion Runtime/LLMUnitySetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public class LLMUnitySetup
/// <summary> LLM for Unity version </summary>
public static string Version = "v2.2.2";
/// <summary> LlamaLib version </summary>
public static string LlamaLibVersion = "v1.1.10";
public static string LlamaLibVersion = "v1.1.12";
/// <summary> LlamaLib release url </summary>
public static string LlamaLibReleaseURL = $"https://github.com/undreamai/LlamaLib/releases/download/{LlamaLibVersion}";
/// <summary> LlamaLib url </summary>
Expand Down
Loading