Allow continuation in Instruct and Interact executors; fix a minor leak #852

Merged · 2 commits · Jul 15, 2024
Changes from 1 commit
6 changes: 3 additions & 3 deletions LLama/LLamaExecutorBase.cs
@@ -251,7 +251,7 @@
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
-protected abstract Task PreprocessInputs(string text, InferStateArgs args);
+protected abstract Task PreprocessInputs(string? text, InferStateArgs args);

/// <summary>
/// Do some post processing after the inference.
@@ -296,11 +296,11 @@
/// <summary>
/// Execute the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
-public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();
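
As a minimal usage sketch of the new contract (the executor and inferenceParams setup is assumed, and the prompt text is hypothetical), a null prompt resumes the previous generation:

// First call: a normal prompt.
await foreach (var piece in executor.InferAsync("Once upon a time", inferenceParams))
    Console.Write(piece);

// Second call: null resumes exactly where the previous call stopped,
// e.g. after MaxTokens truncated the output.
await foreach (var piece in executor.InferAsync(null, inferenceParams))
    Console.Write(piece);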
@@ -419,16 +419,16 @@
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public LLamaToken[] Embeds { get; set; }

Check warning (line 422): Non-nullable property 'Embeds' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("embd_inps")]
public LLamaToken[] EmbedInps { get; set; }

Check warning (line 425): Non-nullable property 'EmbedInps' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("session_tokens")]
public LLamaToken[] SessionTokens { get; set; }

Check warning (line 428): Non-nullable property 'SessionTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("last_n_tokens")]
public LLamaToken[] LastTokens { get; set; }

Check warning (line 431): Non-nullable property 'LastTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
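
These JSON-annotated properties are what round-trips through the executor's file-based state methods. A rough sketch (the file name is hypothetical, and the async signatures are assumed from this diff's LoadState implementation):

// Persist the session (Embeds, EmbedInps, SessionTokens, LastTokens, ...).
await executor.SaveState("executor_state.json");

// Later, on a fresh executor over the same model/context:
await freshExecutor.LoadState("executor_state.json");
// Inference can now continue where the saved session left off.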
28 changes: 17 additions & 11 deletions LLama/LLamaInstructExecutor.cs
@@ -106,7 +106,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);

Check warning (line 109): Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -117,41 +117,47 @@
}

/// <inheritdoc />
-protected override Task PreprocessInputs(string text, InferStateArgs args)
+protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
    args.Antiprompts ??= new List<string>();
-   args.Antiprompts.Add(_instructionPrefix);
+   if (!args.Antiprompts.Contains(_instructionPrefix)) args.Antiprompts.Add(_instructionPrefix);
    if (_is_prompt_run)
    {
        // When running the first input (prompt) in interactive mode, we should specially process it.
+       if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
        _embed_inps = Context.Tokenize(text, true, true).ToList();
    }
    else
    {
-       if (!text.EndsWith("\n"))
-       {
-           text += "\n";
-       }
        _consumedTokensCount = _embed_inps.Count;
-       _embed_inps.AddRange(_inp_pfx);
-
-       var line_inp = Context.Tokenize(text, false, true);
-       _embed_inps.AddRange(line_inp);
-
-       _embed_inps.AddRange(_inp_sfx);
-
-       args.RemainedTokens -= line_inp.Length;
+       // Don't append the template tokens if continuation is requested (by providing a null prompt)
+       if (text != null)
+       {
+           if (!text.EndsWith("\n"))
+           {
+               text += "\n";
+           }
+           _embed_inps.AddRange(_inp_pfx);
+
+           var line_inp = Context.Tokenize(text, false, true);
+           _embed_inps.AddRange(line_inp);
+
+           _embed_inps.AddRange(_inp_sfx);
+
+           args.RemainedTokens -= line_inp.Length;
+       }
    }

    return Task.CompletedTask;
}
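
Two behavioural changes land here: the Contains guard stops _instructionPrefix from being appended to args.Antiprompts on every call (previously the list grew by one entry per turn), and a null prompt skips the instruction template entirely so the model keeps writing its previous response. A rough end-to-end sketch (the model path and prompts are hypothetical, not part of this diff):

using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf");             // hypothetical path
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InstructExecutor(context);

var inferenceParams = new InferenceParams { MaxTokens = 64 };

// Normal turn: the text is wrapped in the instruction prefix/suffix tokens.
await foreach (var piece in executor.InferAsync("Summarize the plot of Hamlet.", inferenceParams))
    Console.Write(piece);

// Continuation: no template tokens are appended; the response simply resumes.
await foreach (var piece in executor.InferAsync(null, inferenceParams))
    Console.Write(piece);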

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning (line 156): This async method lacks 'await' operators and will run synchronously.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning (line 160): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'.
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
@@ -217,7 +223,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

Check warning (line 226): Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

LLamaToken id;
@@ -277,12 +283,12 @@
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public LLamaToken[] InputPrefixTokens { get; set; }

Check warning (line 286): Non-nullable property 'InputPrefixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
public LLamaToken[] InputSuffixTokens { get; set; }

Check warning (line 291): Non-nullable property 'InputSuffixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
}
}
}
31 changes: 18 additions & 13 deletions LLama/LLamaInteractExecutor.cs
@@ -98,7 +98,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

Check warning (line 101): Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -112,11 +112,12 @@
}

/// <inheritdoc />
-protected override Task PreprocessInputs(string text, InferStateArgs args)
+protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
    if (_is_prompt_run)
    {
        // When running the first input (prompt) in interactive mode, we should specially process it.
+       if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
        if (!this.IsMultiModal)
        {
            _embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -128,20 +129,24 @@
    }
    else
    {
-       if (!text.EndsWith("\n"))
-       {
-           text += "\n";
-       }
+       // Don't add any tokens if continuation is requested (by providing a null prompt)
+       if (text != null)
+       {
+           if (!text.EndsWith("\n"))
+           {
+               text += "\n";
+           }

-       if (!this.IsMultiModal)
-       {
-           var line_inp = Context.Tokenize(text, false, true);
-           _embed_inps.AddRange(line_inp);
-           args.RemainedTokens -= line_inp.Length;
-       }
-       else
-       {
-           PreprocessLlava(text, args, false);
+           if (!this.IsMultiModal)
+           {
+               var line_inp = Context.Tokenize(text, false, true);
+               _embed_inps.AddRange(line_inp);
+               args.RemainedTokens -= line_inp.Length;
+           }
+           else
+           {
+               PreprocessLlava(text, args, false);
+           }
        }
    }

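The interactive variant enables the same pattern in a loop (setup assumed; the antiprompt value and round count are hypothetical), with later rounds passing null so no extra newline or prompt tokens are injected:

var executor = new InteractiveExecutor(context);
var inferenceParams = new InferenceParams
{
    MaxTokens = 32,                  // deliberately small, forces truncation
    AntiPrompts = new[] { "User:" }  // hypothetical stop string
};

string? prompt = "User: Explain quicksort.\nAssistant:";
for (var round = 0; round < 4; round++)
{
    await foreach (var piece in executor.InferAsync(prompt, inferenceParams))
        Console.Write(piece);
    prompt = null; // all later rounds continue the same response
}
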
@@ -159,7 +164,7 @@
{
foreach (var image in Images)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, image));

Check warning (line 167): Dereference of a possibly null reference.
}

int imageIndex = text.IndexOf("<image>");
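
For the multimodal path above, a rough usage sketch (file names hypothetical; assumes a llava-style projector loaded alongside the main model): queued images are spliced in wherever the "<image>" tag appears in the prompt.

using var llava = LLavaWeights.LoadFromFile("mmproj.gguf"); // hypothetical projector path
var executor = new InteractiveExecutor(context, llava);

executor.Images.Add(await File.ReadAllBytesAsync("photo.jpg"));

await foreach (var piece in executor.InferAsync(
                   "USER: <image>\nDescribe the image.\nASSISTANT:", inferenceParams))
    Console.Write(piece);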