diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
index 683f8c452..0974fb691 100644
--- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
@@ -1,4 +1,4 @@
-using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel;
 using System.Text.Json;
 using System.Text.Json.Serialization;
 
@@ -12,14 +12,14 @@ public class ChatRequestSettings : PromptExecutionSettings
     /// The higher the temperature, the more random the completion.
     /// </summary>
     [JsonPropertyName("temperature")]
-    public double Temperature { get; set; } = 0;
+    public double Temperature { get; set; }
 
     /// <summary>
     /// TopP controls the diversity of the completion.
     /// The higher the TopP, the more diverse the completion.
     /// </summary>
     [JsonPropertyName("top_p")]
-    public double TopP { get; set; } = 0;
+    public double TopP { get; set; }
 
     /// <summary>
     /// Number between -2.0 and 2.0. Positive values penalize new tokens
@@ -27,7 +27,7 @@ public class ChatRequestSettings : PromptExecutionSettings
     /// model's likelihood to talk about new topics.
     /// </summary>
     [JsonPropertyName("presence_penalty")]
-    public double PresencePenalty { get; set; } = 0;
+    public double PresencePenalty { get; set; }
 
     /// <summary>
     /// Number between -2.0 and 2.0. Positive values penalize new tokens
@@ -35,7 +35,7 @@ public class ChatRequestSettings : PromptExecutionSettings
     /// the model's likelihood to repeat the same line verbatim.
     /// </summary>
     [JsonPropertyName("frequency_penalty")]
-    public double FrequencyPenalty { get; set; } = 0;
+    public double FrequencyPenalty { get; set; }
 
     /// <summary>
     /// Sequences where the completion will stop generating further tokens.
@@ -71,13 +71,10 @@ public class ChatRequestSettings : PromptExecutionSettings
     /// <returns>An instance of OpenAIRequestSettings</returns>
     public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
     {
-        if (requestSettings is null)
+        requestSettings ??= new ChatRequestSettings
         {
-            return new ChatRequestSettings()
-            {
-                MaxTokens = defaultMaxTokens
-            };
-        }
+            MaxTokens = defaultMaxTokens
+        };
 
         if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
         {
@@ -85,7 +82,7 @@ public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
         }
 
         var json = JsonSerializer.Serialize(requestSettings);
-        var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
+        var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, SerializerOptions);
 
         if (chatRequestSettings is not null)
         {
@@ -95,20 +92,13 @@ public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
         throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
     }
 
-    private static readonly JsonSerializerOptions s_options = CreateOptions();
-
-    private static JsonSerializerOptions CreateOptions()
+    private static readonly JsonSerializerOptions SerializerOptions = new()
     {
-        JsonSerializerOptions options = new()
-        {
-            WriteIndented = true,
-            MaxDepth = 20,
-            AllowTrailingCommas = true,
-            PropertyNameCaseInsensitive = true,
-            ReadCommentHandling = JsonCommentHandling.Skip,
-            Converters = { new ChatRequestSettingsConverter() }
-        };
-
-        return options;
-    }
+        WriteIndented = true,
+        MaxDepth = 20,
+        AllowTrailingCommas = true,
+        PropertyNameCaseInsensitive = true,
+        ReadCommentHandling = JsonCommentHandling.Skip,
+        Converters = { new ChatRequestSettingsConverter() }
+    };
 }
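The `??=` rewrite in `FromRequestSettings` above replaces an early-return block with C#'s null-coalescing assignment; behaviour is unchanged. A minimal standalone sketch of the same pattern, using a hypothetical `Settings` type rather than this patch's classes:

```csharp
public class Settings
{
    public int? MaxTokens { get; set; }
}

public static class SettingsFactory
{
    public static Settings From(Settings? settings, int? defaultMaxTokens = null)
    {
        // ??= assigns the right-hand side only when the left-hand side is null,
        // so a null argument is replaced by a fresh instance carrying the default limit.
        settings ??= new Settings { MaxTokens = defaultMaxTokens };
        return settings;
    }
}
```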
diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
index 15bc45cd4..ca14e278e 100644
--- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
@@ -1,18 +1,17 @@
-using System;
-using System.Collections.Generic;
 using System.Text.Json;
 using System.Text.Json.Serialization;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
 
 /// <summary>
-/// JSON converter for <see cref="ChatRequestSettings"/>
+/// JSON converter for <see cref="ChatRequestSettings"/>
 /// </summary>
 [Obsolete("Use LLamaSharpPromptExecutionSettingsConverter instead")]
-public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
+public class ChatRequestSettingsConverter
+    : JsonConverter<ChatRequestSettings>
 {
     /// <inheritdoc/>
-    public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    public override ChatRequestSettings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
     {
         var requestSettings = new ChatRequestSettings();
 
@@ -20,13 +19,7 @@ public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
         {
             if (reader.TokenType == JsonTokenType.PropertyName)
             {
-                string? propertyName = reader.GetString();
-
-                if (propertyName is not null)
-                {
-                    // normalise property name to uppercase
-                    propertyName = propertyName.ToUpperInvariant();
-                }
+                var propertyName = reader.GetString()?.ToUpperInvariant();
 
                 reader.Read();
diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
index f1a0ebcb6..08815819c 100644
--- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
+++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
@@ -1,5 +1,4 @@
-using LLama.Common;
-using System.Text;
+using LLama.Common;
 using static LLama.LLamaTransforms;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -10,7 +9,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 public class HistoryTransform : DefaultHistoryTransform
 {
     /// <inheritdoc/>
-    public override string HistoryToText(global::LLama.Common.ChatHistory history)
+    public override string HistoryToText(ChatHistory history)
     {
         return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
     }
diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 01a061db1..413c9ed35 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,10 +1,7 @@
-using LLama;
+using LLama;
 using LLama.Abstractions;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
-using Microsoft.SemanticKernel.Services;
-using System;
-using System.IO;
 using System.Runtime.CompilerServices;
 using System.Text;
 using static LLama.InteractiveExecutor;
@@ -18,16 +15,16 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 public sealed class LLamaSharpChatCompletion : IChatCompletionService
 {
     private readonly ILLamaExecutor _model;
-    private LLamaSharpPromptExecutionSettings defaultRequestSettings;
-    private readonly IHistoryTransform historyTransform;
-    private readonly ITextStreamTransform outputTransform;
+    private readonly LLamaSharpPromptExecutionSettings _defaultRequestSettings;
+    private readonly IHistoryTransform _historyTransform;
+    private readonly ITextStreamTransform _outputTransform;
 
     private readonly Dictionary<string, object?> _attributes = new();
     private readonly bool _isStatefulExecutor;
 
-    public IReadOnlyDictionary<string, object?> Attributes => this._attributes;
+    public IReadOnlyDictionary<string, object?> Attributes => _attributes;
 
-    static LLamaSharpPromptExecutionSettings GetDefaultSettings()
+    private static LLamaSharpPromptExecutionSettings GetDefaultSettings()
     {
         return new LLamaSharpPromptExecutionSettings
         {
@@ -43,11 +40,11 @@ public LLamaSharpChatCompletion(ILLamaExecutor model,
         IHistoryTransform? historyTransform = null,
         ITextStreamTransform? outputTransform = null)
     {
-        this._model = model;
-        this._isStatefulExecutor = this._model is StatefulExecutorBase;
-        this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
-        this.historyTransform = historyTransform ?? new HistoryTransform();
-        this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
+        _model = model;
+        _isStatefulExecutor = _model is StatefulExecutorBase;
+        _defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
+        _historyTransform = historyTransform ?? new HistoryTransform();
+        _outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
             $"{LLama.Common.AuthorRole.Assistant}:",
             $"{LLama.Common.AuthorRole.System}:"});
     }
@@ -69,12 +66,12 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
     {
         var settings = executionSettings != null
             ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
-            : defaultRequestSettings;
+            : _defaultRequestSettings;
 
-        string prompt = this._getFormattedPrompt(chatHistory);
+        var prompt = _getFormattedPrompt(chatHistory);
         var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
-        var output = outputTransform.TransformAsync(result);
+        var output = _outputTransform.TransformAsync(result);
 
         var sb = new StringBuilder();
         await foreach (var token in output)
@@ -90,12 +87,12 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync
     {
         var settings = executionSettings != null
             ? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
-            : defaultRequestSettings;
+            : _defaultRequestSettings;
 
-        string prompt = this._getFormattedPrompt(chatHistory);
+        var prompt = _getFormattedPrompt(chatHistory);
         var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
-        var output = outputTransform.TransformAsync(result);
+        var output = _outputTransform.TransformAsync(result);
 
         await foreach (var token in output)
         {
@@ -109,24 +106,26 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync
     /// </summary>
     /// <param name="chatHistory"></param>
     /// <returns>The formatted prompt</returns>
-    private string _getFormattedPrompt(ChatHistory chatHistory){
+    private string _getFormattedPrompt(ChatHistory chatHistory)
+    {
         string prompt;
-        if (this._isStatefulExecutor){
-            InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)this._model).GetStateData();
+        if (_isStatefulExecutor)
+        {
+            var state = (InteractiveExecutorState)((StatefulExecutorBase)_model).GetStateData();
             if (state.IsPromptRun)
             {
-                prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
+                prompt = _historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
             }
             else
             {
-                ChatHistory temp_history = new();
-                temp_history.AddUserMessage(chatHistory.Last().Content);
-                prompt = historyTransform.HistoryToText(temp_history.ToLLamaSharpChatHistory());
+                ChatHistory tempHistory = new();
+                tempHistory.AddUserMessage(chatHistory.Last().Content ?? "");
+                prompt = _historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory());
             }
         }
         else
         {
-            prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
+            prompt = _historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
         }
 
         return prompt;
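For context on how the renamed service is driven end to end, here is a hedged usage sketch (not part of the patch; the model path is hypothetical, and it assumes the usual LLamaSharp loading APIs). Because `InteractiveExecutor` derives from `StatefulExecutorBase`, `_isStatefulExecutor` is true and only the latest user message is re-sent after the first prompt run:

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;

var parameters = new ModelParams("path/to/model.gguf"); // hypothetical model path
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

var chat = new LLamaSharpChatCompletion(new InteractiveExecutor(context));

var history = new ChatHistory();
history.AddSystemMessage("You are a concise assistant.");
history.AddUserMessage("What is Semantic Kernel?");

// No PromptExecutionSettings passed, so _defaultRequestSettings applies.
var replies = await chat.GetChatMessageContentsAsync(history);
Console.WriteLine(replies[0].Content);
```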
""); + prompt = _historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory()); } } else { - prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); + prompt = _historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory()); } return prompt; diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs index 086999aa3..c63ee42b8 100644 --- a/LLama.SemanticKernel/ExtensionMethods.cs +++ b/LLama.SemanticKernel/ExtensionMethods.cs @@ -1,21 +1,25 @@ -using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.ChatCompletion; +using AuthorRole = LLama.Common.AuthorRole; + namespace LLamaSharp.SemanticKernel; public static class ExtensionMethods { - public static global::LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory chatHistory, bool ignoreCase = true) + public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory chatHistory, bool ignoreCase = true) { if (chatHistory is null) { throw new ArgumentNullException(nameof(chatHistory)); } - var history = new global::LLama.Common.ChatHistory(); + var history = new LLama.Common.ChatHistory(); foreach (var chat in chatHistory) { - var role = Enum.TryParse(chat.Role.Label, ignoreCase, out var _role) ? _role : global::LLama.Common.AuthorRole.Unknown; - history.AddMessage(role, chat.Content); + if (!Enum.TryParse(chat.Role.Label, ignoreCase, out var role)) + role = AuthorRole.Unknown; + + history.AddMessage(role, chat.Content ?? ""); } return history; @@ -26,7 +30,7 @@ public static class ExtensionMethods /// /// /// - internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings) + internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings) { if (requestSettings is null) { @@ -34,10 +38,12 @@ public static class ExtensionMethods } var antiPrompts = new List(requestSettings.StopSequences) - { LLama.Common.AuthorRole.User.ToString() + ":" , - LLama.Common.AuthorRole.Assistant.ToString() + ":", - LLama.Common.AuthorRole.System.ToString() + ":"}; - return new global::LLama.Common.InferenceParams + { + $"{AuthorRole.User}:", + $"{AuthorRole.Assistant}:", + $"{AuthorRole.System}:" + }; + return new LLama.Common.InferenceParams { Temperature = (float)requestSettings.Temperature, TopP = (float)requestSettings.TopP, diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs index 5e8a66697..77fe9a75c 100644 --- a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs +++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs @@ -1,15 +1,3 @@ - -/* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)' -Before: -using Microsoft.SemanticKernel; -After: -using LLamaSharp; -using LLamaSharp.SemanticKernel; -using LLamaSharp.SemanticKernel; -using LLamaSharp.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel; -*/ -using LLamaSharp.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel; using System.Text.Json; using System.Text.Json.Serialization; @@ -23,14 +11,14 @@ public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings /// The higher the temperature, the more random the completion. /// [JsonPropertyName("temperature")] - public double Temperature { get; set; } = 0; + public double Temperature { get; set; } /// /// TopP controls the diversity of the completion. 
diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
index 5e8a66697..77fe9a75c 100644
--- a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
+++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettings.cs
@@ -1,15 +1,3 @@
-
-/* Unmerged change from project 'LLamaSharp.SemanticKernel (netstandard2.0)'
-Before:
-using Microsoft.SemanticKernel;
-After:
-using LLamaSharp;
-using LLamaSharp.SemanticKernel;
-using LLamaSharp.SemanticKernel;
-using LLamaSharp.SemanticKernel.ChatCompletion;
-using Microsoft.SemanticKernel;
-*/
-using LLamaSharp.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -23,14 +11,14 @@ public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
     /// The higher the temperature, the more random the completion.
     /// </summary>
     [JsonPropertyName("temperature")]
-    public double Temperature { get; set; } = 0;
+    public double Temperature { get; set; }
 
     /// <summary>
     /// TopP controls the diversity of the completion.
     /// The higher the TopP, the more diverse the completion.
     /// </summary>
     [JsonPropertyName("top_p")]
-    public double TopP { get; set; } = 0;
+    public double TopP { get; set; }
 
     /// <summary>
     /// Number between -2.0 and 2.0. Positive values penalize new tokens
@@ -38,7 +26,7 @@ public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
     /// model's likelihood to talk about new topics.
     /// </summary>
     [JsonPropertyName("presence_penalty")]
-    public double PresencePenalty { get; set; } = 0;
+    public double PresencePenalty { get; set; }
 
     /// <summary>
     /// Number between -2.0 and 2.0. Positive values penalize new tokens
@@ -46,7 +34,7 @@ public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
     /// the model's likelihood to repeat the same line verbatim.
     /// </summary>
     [JsonPropertyName("frequency_penalty")]
-    public double FrequencyPenalty { get; set; } = 0;
+    public double FrequencyPenalty { get; set; }
 
     /// <summary>
     /// Sequences where the completion will stop generating further tokens.
@@ -88,13 +76,10 @@ public class LLamaSharpPromptExecutionSettings : PromptExecutionSettings
     /// <returns>An instance of OpenAIRequestSettings</returns>
     public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
     {
-        if (requestSettings is null)
+        requestSettings ??= new LLamaSharpPromptExecutionSettings
         {
-            return new LLamaSharpPromptExecutionSettings()
-            {
-                MaxTokens = defaultMaxTokens
-            };
-        }
+            MaxTokens = defaultMaxTokens
+        };
 
         if (requestSettings is LLamaSharpPromptExecutionSettings requestSettingsChatRequestSettings)
         {
@@ -102,7 +87,7 @@ public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
         }
 
         var json = JsonSerializer.Serialize(requestSettings);
-        var chatRequestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, s_options);
+        var chatRequestSettings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, SerializerOptions);
 
         if (chatRequestSettings is not null)
         {
@@ -112,20 +97,13 @@ public static LLamaSharpPromptExecutionSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
         throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(LLamaSharpPromptExecutionSettings)}", nameof(requestSettings));
     }
 
-    private static readonly JsonSerializerOptions s_options = CreateOptions();
-
-    private static JsonSerializerOptions CreateOptions()
+    private static readonly JsonSerializerOptions SerializerOptions = new()
     {
-        JsonSerializerOptions options = new()
-        {
-            WriteIndented = true,
-            MaxDepth = 20,
-            AllowTrailingCommas = true,
-            PropertyNameCaseInsensitive = true,
-            ReadCommentHandling = JsonCommentHandling.Skip,
-            Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
-        };
-
-        return options;
-    }
+        WriteIndented = true,
+        MaxDepth = 20,
+        AllowTrailingCommas = true,
+        PropertyNameCaseInsensitive = true,
+        ReadCommentHandling = JsonCommentHandling.Skip,
+        Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
+    };
 }
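When the incoming settings are not already a `LLamaSharpPromptExecutionSettings`, `FromRequestSettings` converts them through a JSON round-trip using `SerializerOptions`. A sketch of why property-name casing never matters on that path; it assumes the converter's switch (its body is elided in this diff) recognises the same snake_case names as the `[JsonPropertyName]` attributes above:

```csharp
using System.Text.Json;
using LLamaSharp.SemanticKernel;

var options = new JsonSerializerOptions
{
    PropertyNameCaseInsensitive = true,
    Converters = { new LLamaSharpPromptExecutionSettingsConverter() }
};

// The converter upper-cases every property name before matching, so these
// spellings all bind to the same properties.
var json = """{ "TEMPERATURE": 0.7, "Top_P": 0.9 }""";
var settings = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);
Console.WriteLine(settings?.Temperature); // 0.7
```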
diff --git a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs
index 36ca9c6cf..b7c3bdeb0 100644
--- a/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs
+++ b/LLama.SemanticKernel/LLamaSharpPromptExecutionSettingsConverter.cs
@@ -1,17 +1,16 @@
-using System;
-using System.Collections.Generic;
 using System.Text.Json;
 using System.Text.Json.Serialization;
 
 namespace LLamaSharp.SemanticKernel;
 
 /// <summary>
-/// JSON converter for <see cref="LLamaSharpPromptExecutionSettings"/>
+/// JSON converter for <see cref="LLamaSharpPromptExecutionSettings"/>
 /// </summary>
-public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter<LLamaSharpPromptExecutionSettings>
+public class LLamaSharpPromptExecutionSettingsConverter
+    : JsonConverter<LLamaSharpPromptExecutionSettings>
 {
     /// <inheritdoc/>
-    public override LLamaSharpPromptExecutionSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    public override LLamaSharpPromptExecutionSettings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
     {
         var requestSettings = new LLamaSharpPromptExecutionSettings();
 
@@ -19,13 +18,7 @@ public class LLamaSharpPromptExecutionSettingsConverter : JsonConverter<LLamaSharpPromptExecutionSettings>
         {
             if (reader.TokenType == JsonTokenType.PropertyName)
             {
-                string? propertyName = reader.GetString();
-
-                if (propertyName is not null)
-                {
-                    // normalise property name to uppercase
-                    propertyName = propertyName.ToUpperInvariant();
-                }
+                var propertyName = reader.GetString()?.ToUpperInvariant();
 
                 reader.Read();
diff --git a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
--- a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
+++ b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
@@ ... @@
-    private readonly ILLamaExecutor executor;
+    private readonly ILLamaExecutor _executor;
 
     private readonly Dictionary<string, object?> _attributes = new();
 
-    public IReadOnlyDictionary<string, object?> Attributes => this._attributes;
+    public IReadOnlyDictionary<string, object?> Attributes => _attributes;
 
     public LLamaSharpTextCompletion(ILLamaExecutor executor)
     {
-        this.executor = executor;
+        _executor = executor;
     }
 
     /// <inheritdoc/>
     public async Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
     {
         var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
-        var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
+        var result = _executor.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
         var sb = new StringBuilder();
         await foreach (var token in result)
         {
@@ -37,7 +36,7 @@ public async Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt
     public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
     {
         var settings = LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings);
-        var result = executor.InferAsync(prompt, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
+        var result = _executor.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
         await foreach (var token in result)
         {
             yield return new StreamingTextContent(token);
diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
index 6889ba6a7..9514e1711 100644
--- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
+++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
@@ -1,4 +1,4 @@
-using LLama;
+using LLama;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.Embeddings;
 
@@ -10,7 +10,7 @@ public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
 
     private readonly Dictionary<string, object?> _attributes = new();
 
-    public IReadOnlyDictionary<string, object?> Attributes => this._attributes;
+    public IReadOnlyDictionary<string, object?> Attributes => _attributes;
 
     public LLamaSharpEmbeddingGeneration(LLamaEmbedder embedder)
     {
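Finally, a hedged sketch of the text-completion path after the `executor` to `_executor` rename (not part of the patch; the model path is hypothetical). `StatelessExecutor` also implements `ILLamaExecutor`, so it can back the service without any chat state:

```csharp
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.TextCompletion;

var parameters = new ModelParams("path/to/model.gguf"); // hypothetical model path
using var weights = LLamaWeights.LoadFromFile(parameters);

var completion = new LLamaSharpTextCompletion(new StatelessExecutor(weights, parameters));

// Streams tokens as they are inferred; with no settings passed, the service
// falls back to LLamaSharpPromptExecutionSettings.FromRequestSettings(null).
await foreach (var chunk in completion.GetStreamingTextContentsAsync("The capital of France is"))
{
    Console.Write(chunk.Text);
}
```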