diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e01a40cc..9b2b1761 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -251,7 +251,7 @@ protected virtual void TryReuseMatchingPrefix()
/// <param name="text"></param>
/// <param name="args"></param>
/// <returns></returns>
- protected abstract Task PreprocessInputs(string text, InferStateArgs args);
+ protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
/// <summary>
/// Do some post processing after the inference.
@@ -296,11 +296,11 @@ protected virtual void TryReuseMatchingPrefix()
/// <summary>
/// Execute the inference.
/// </summary>
- /// <param name="text"></param>
+ /// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
- public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();
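
For context, a minimal sketch of how the nullable prompt is meant to be used from calling code (the executor construction, `context`, and `ct` variables are assumptions for illustration, not part of this diff):

    // Assumes an existing LLamaContext `context` and a CancellationToken `ct`.
    var executor = new InteractiveExecutor(context);
    var inferenceParams = new InferenceParams { MaxTokens = 32 };

    // First call: provide the prompt as usual.
    await foreach (var token in executor.InferAsync("Once upon a time", inferenceParams, ct))
        Console.Write(token);

    // Second call: a null prompt continues where the previous call stopped,
    // e.g. after hitting the MaxTokens limit.
    await foreach (var token in executor.InferAsync(null, inferenceParams, ct))
        Console.Write(token);
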
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 057c44c0..ec41aa7f 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -116,30 +116,38 @@ protected override Task GetLoopCondition(InferStateArgs args)
}
/// <inheritdoc />
- protected override Task PreprocessInputs(string text, InferStateArgs args)
+ protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
args.Antiprompts ??= [ ];
- args.Antiprompts.Add(_instructionPrefix);
+ // Add the instruction prefix as an antiprompt only once, even across repeated calls.
+ if (!args.Antiprompts.Contains(_instructionPrefix))
+     args.Antiprompts.Add(_instructionPrefix);
+
if (_is_prompt_run)
{
// When running the first input (prompt) in instruct mode, it needs special processing.
+ if (text == null) throw new ArgumentException("Prompt cannot be null on the first call; a null prompt only continues generation after an initial prompt has been provided.", nameof(text));
_embed_inps = Context.Tokenize(text, true, true).ToList();
}
else
{
- if (!text.EndsWith("\n"))
- {
- text += "\n";
- }
_consumedTokensCount = _embed_inps.Count;
- _embed_inps.AddRange(_inp_pfx);
- var line_inp = Context.Tokenize(text, false, true);
- _embed_inps.AddRange(line_inp);
+ // Don't append the template tokens if continuation is requested (by providing a null prompt)
+ if (text != null)
+ {
+ if (!text.EndsWith("\n"))
+ {
+ text += "\n";
+ }
+ _embed_inps.AddRange(_inp_pfx);
+
+ var line_inp = Context.Tokenize(text, false, true);
+ _embed_inps.AddRange(line_inp);
- _embed_inps.AddRange(_inp_sfx);
- args.RemainedTokens -= line_inp.Length;
+     _embed_inps.AddRange(_inp_sfx);
+     args.RemainedTokens -= line_inp.Length;
+ }
}
return Task.CompletedTask;
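
To make the instruct-executor change concrete: a non-null input is wrapped in the instruction prefix/suffix template tokens before decoding, while a null input appends nothing and decoding simply resumes. A hedged sketch, assuming the default instruction templates and an existing `context`:

    var executor = new InstructExecutor(context);
    var @params = new InferenceParams { MaxTokens = 48 };

    // Non-null input: tokenized as <prefix> + "Summarise the plot.\n" + <suffix>.
    await foreach (var token in executor.InferAsync("Summarise the plot.", @params))
        Console.Write(token);

    // Null input: no prefix/suffix tokens are added; the response continues mid-stream.
    await foreach (var token in executor.InferAsync(null, @params))
        Console.Write(token);
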
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 068cf129..f97a2b63 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -111,11 +111,12 @@ protected override Task GetLoopCondition(InferStateArgs args)
}
/// <inheritdoc />
- protected override Task PreprocessInputs(string text, InferStateArgs args)
+ protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
if (_is_prompt_run)
{
// When running the first input (prompt) in interactive mode, it needs special processing.
+ if (text == null) throw new ArgumentException("Prompt cannot be null on the first call; a null prompt only continues generation after an initial prompt has been provided.", nameof(text));
if (!this.IsMultiModal)
{
_embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -127,20 +128,24 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
else
{
- if (!text.EndsWith("\n"))
+ // Don't add any tokens if continuation is requested (by providing a null prompt)
+ if (text != null)
{
- text += "\n";
- }
+ if (!text.EndsWith("\n"))
+ {
+ text += "\n";
+ }
- if (!this.IsMultiModal)
- {
- var line_inp = Context.Tokenize(text, false, true);
- _embed_inps.AddRange(line_inp);
- args.RemainedTokens -= line_inp.Length;
- }
- else
- {
- PreprocessLlava(text, args, false);
+ if (!this.IsMultiModal)
+ {
+ var line_inp = Context.Tokenize(text, false, true);
+ _embed_inps.AddRange(line_inp);
+ args.RemainedTokens -= line_inp.Length;
+ }
+ else
+ {
+ PreprocessLlava(text, args, false);
+ }
}
}
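
One pattern the null prompt enables: generating a long reply in bounded chunks, continuing until the model produces nothing further. A sketch under the same assumed API; the chunk cap and the empty-chunk stop heuristic are illustrative choices, not part of this diff:

    var chunkParams = new InferenceParams { MaxTokens = 64, AntiPrompts = new[] { "User:" } };
    var sb = new System.Text.StringBuilder();

    string? prompt = "User: Explain quicksort.\nAssistant:";
    for (var i = 0; i < 8; i++) // hard cap on the number of chunks
    {
        var produced = 0;
        await foreach (var token in executor.InferAsync(prompt, chunkParams))
        {
            sb.Append(token);
            produced++;
        }
        if (produced == 0) break; // nothing more was generated
        prompt = null;            // later calls continue where the last one stopped
    }
    Console.WriteLine(sb.ToString());
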