Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semantic Kernel Cleanup #878

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel;
using System.Text.Json;
using System.Text.Json.Serialization;

Expand All @@ -12,30 +12,30 @@ public class ChatRequestSettings : PromptExecutionSettings
/// The higher the temperature, the more random the completion.
/// </summary>
[JsonPropertyName("temperature")]
public double Temperature { get; set; } = 0;
public double Temperature { get; set; }

/// <summary>
/// TopP controls the diversity of the completion.
/// The higher the TopP, the more diverse the completion.
/// </summary>
[JsonPropertyName("top_p")]
public double TopP { get; set; } = 0;
public double TopP { get; set; }

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on whether they appear in the text so far, increasing the
/// model's likelihood to talk about new topics.
/// </summary>
[JsonPropertyName("presence_penalty")]
public double PresencePenalty { get; set; } = 0;
public double PresencePenalty { get; set; }

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on their existing frequency in the text so far, decreasing
/// the model's likelihood to repeat the same line verbatim.
/// </summary>
[JsonPropertyName("frequency_penalty")]
public double FrequencyPenalty { get; set; } = 0;
public double FrequencyPenalty { get; set; }

/// <summary>
/// Sequences where the completion will stop generating further tokens.
Expand Down Expand Up @@ -71,21 +71,18 @@ public class ChatRequestSettings : PromptExecutionSettings
/// <returns>An instance of ChatRequestSettings</returns>
public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? requestSettings, int? defaultMaxTokens = null)
{
if (requestSettings is null)
requestSettings ??= new ChatRequestSettings
{
return new ChatRequestSettings()
{
MaxTokens = defaultMaxTokens
};
}
MaxTokens = defaultMaxTokens
};

if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
{
return requestSettingsChatRequestSettings;
}

var json = JsonSerializer.Serialize(requestSettings);
var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, SerializerOptions);

if (chatRequestSettings is not null)
{
Expand All @@ -95,20 +92,13 @@ public static ChatRequestSettings FromRequestSettings(PromptExecutionSettings? r
throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
}

private static readonly JsonSerializerOptions s_options = CreateOptions();

private static JsonSerializerOptions CreateOptions()
private static readonly JsonSerializerOptions SerializerOptions = new()
{
JsonSerializerOptions options = new()
{
WriteIndented = true,
MaxDepth = 20,
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
Converters = { new ChatRequestSettingsConverter() }
};

return options;
}
WriteIndented = true,
MaxDepth = 20,
AllowTrailingCommas = true,
PropertyNameCaseInsensitive = true,
ReadCommentHandling = JsonCommentHandling.Skip,
Converters = { new ChatRequestSettingsConverter() }
};
}
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

/// <summary>
/// JSON converter for <see cref="OpenAIRequestSettings"/>
/// JSON converter for <see cref="ChatRequestSettings"/>
/// </summary>
[Obsolete("Use LLamaSharpPromptExecutionSettingsConverter instead")]
public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
public class ChatRequestSettingsConverter
: JsonConverter<ChatRequestSettings>
{
/// <inheritdoc/>
public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
public override ChatRequestSettings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var requestSettings = new ChatRequestSettings();

while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
string? propertyName = reader.GetString();

if (propertyName is not null)
{
// normalise property name to uppercase
propertyName = propertyName.ToUpperInvariant();
}
var propertyName = reader.GetString()?.ToUpperInvariant();

reader.Read();

Expand Down
5 changes: 2 additions & 3 deletions LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LLama.Common;
using System.Text;
using LLama.Common;
using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;
Expand All @@ -10,7 +9,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public class HistoryTransform : DefaultHistoryTransform
{
/// <inheritdoc/>
public override string HistoryToText(global::LLama.Common.ChatHistory history)
public override string HistoryToText(ChatHistory history)
{
return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
}
Expand Down
55 changes: 27 additions & 28 deletions LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
using LLama;
using LLama;
using LLama.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Services;
using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text;
using static LLama.InteractiveExecutor;
Expand All @@ -18,16 +15,16 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public sealed class LLamaSharpChatCompletion : IChatCompletionService
{
private readonly ILLamaExecutor _model;
private LLamaSharpPromptExecutionSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;
private readonly LLamaSharpPromptExecutionSettings _defaultRequestSettings;
private readonly IHistoryTransform _historyTransform;
private readonly ITextStreamTransform _outputTransform;

private readonly Dictionary<string, object?> _attributes = new();
private readonly bool _isStatefulExecutor;

public IReadOnlyDictionary<string, object?> Attributes => this._attributes;
public IReadOnlyDictionary<string, object?> Attributes => _attributes;

static LLamaSharpPromptExecutionSettings GetDefaultSettings()
private static LLamaSharpPromptExecutionSettings GetDefaultSettings()
{
return new LLamaSharpPromptExecutionSettings
{
Expand All @@ -43,11 +40,11 @@ public LLamaSharpChatCompletion(ILLamaExecutor model,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
this._model = model;
this._isStatefulExecutor = this._model is StatefulExecutorBase;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
_model = model;
_isStatefulExecutor = _model is StatefulExecutorBase;
_defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
_historyTransform = historyTransform ?? new HistoryTransform();
_outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
$"{LLama.Common.AuthorRole.Assistant}:",
$"{LLama.Common.AuthorRole.System}:"});
}
Expand All @@ -69,12 +66,12 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
{
var settings = executionSettings != null
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
: _defaultRequestSettings;

string prompt = this._getFormattedPrompt(chatHistory);
var prompt = _getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = outputTransform.TransformAsync(result);
var output = _outputTransform.TransformAsync(result);

var sb = new StringBuilder();
await foreach (var token in output)
Expand All @@ -90,12 +87,12 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
{
var settings = executionSettings != null
? LLamaSharpPromptExecutionSettings.FromRequestSettings(executionSettings)
: defaultRequestSettings;
: _defaultRequestSettings;

string prompt = this._getFormattedPrompt(chatHistory);
var prompt = _getFormattedPrompt(chatHistory);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

var output = outputTransform.TransformAsync(result);
var output = _outputTransform.TransformAsync(result);

await foreach (var token in output)
{
Expand All @@ -109,24 +106,26 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
/// </summary>
/// <param name="chatHistory"></param>
/// <returns>The formatted prompt</returns>
private string _getFormattedPrompt(ChatHistory chatHistory){
private string _getFormattedPrompt(ChatHistory chatHistory)
{
string prompt;
if (this._isStatefulExecutor){
InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)this._model).GetStateData();
if (_isStatefulExecutor)
{
var state = (InteractiveExecutorState)((StatefulExecutorBase)_model).GetStateData();
if (state.IsPromptRun)
{
prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
prompt = _historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
}
else
{
ChatHistory temp_history = new();
temp_history.AddUserMessage(chatHistory.Last().Content);
prompt = historyTransform.HistoryToText(temp_history.ToLLamaSharpChatHistory());
ChatHistory tempHistory = new();
tempHistory.AddUserMessage(chatHistory.Last().Content ?? "");
prompt = _historyTransform.HistoryToText(tempHistory.ToLLamaSharpChatHistory());
}
}
else
{
prompt = historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
prompt = _historyTransform.HistoryToText(chatHistory.ToLLamaSharpChatHistory());
}

return prompt;
Expand Down
26 changes: 16 additions & 10 deletions LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.ChatCompletion;
using AuthorRole = LLama.Common.AuthorRole;

namespace LLamaSharp.SemanticKernel;

public static class ExtensionMethods
{
public static global::LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory chatHistory, bool ignoreCase = true)
public static LLama.Common.ChatHistory ToLLamaSharpChatHistory(this ChatHistory chatHistory, bool ignoreCase = true)
{
if (chatHistory is null)
{
throw new ArgumentNullException(nameof(chatHistory));
}

var history = new global::LLama.Common.ChatHistory();
var history = new LLama.Common.ChatHistory();

foreach (var chat in chatHistory)
{
var role = Enum.TryParse<global::LLama.Common.AuthorRole>(chat.Role.Label, ignoreCase, out var _role) ? _role : global::LLama.Common.AuthorRole.Unknown;
history.AddMessage(role, chat.Content);
if (!Enum.TryParse<AuthorRole>(chat.Role.Label, ignoreCase, out var role))
role = AuthorRole.Unknown;

history.AddMessage(role, chat.Content ?? "");
}

return history;
Expand All @@ -26,18 +30,20 @@ public static class ExtensionMethods
/// </summary>
/// <param name="requestSettings"></param>
/// <returns></returns>
internal static global::LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings)
internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LLamaSharpPromptExecutionSettings requestSettings)
{
if (requestSettings is null)
{
throw new ArgumentNullException(nameof(requestSettings));
}

var antiPrompts = new List<string>(requestSettings.StopSequences)
{ LLama.Common.AuthorRole.User.ToString() + ":" ,
LLama.Common.AuthorRole.Assistant.ToString() + ":",
LLama.Common.AuthorRole.System.ToString() + ":"};
return new global::LLama.Common.InferenceParams
{
$"{AuthorRole.User}:",
$"{AuthorRole.Assistant}:",
$"{AuthorRole.System}:"
};
return new LLama.Common.InferenceParams
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
Expand Down
Loading
Loading