diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e01a40cc..9b2b1761 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -251,7 +251,7 @@ protected virtual void TryReuseMatchingPrefix()
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract Task PreprocessInputs(string text, InferStateArgs args);
+        protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
@@ -296,11 +296,11 @@ protected virtual void TryReuseMatchingPrefix()
         /// <summary>
         /// Execute the inference.
         /// </summary>
-        /// <param name="text"></param>
+        /// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             inferenceParams ??= new InferenceParams();
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 057c44c0..ec41aa7f 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -116,30 +116,38 @@ protected override Task GetLoopCondition(InferStateArgs args)
         }

         /// <inheritdoc />
-        protected override Task PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args)
         {
             args.Antiprompts ??= [ ];
-            args.Antiprompts.Add(_instructionPrefix);
+            if (!args.Antiprompts.Contains(_instructionPrefix))
+                args.Antiprompts.Add(_instructionPrefix);
+
             if (_is_prompt_run)
             {
                 // When running the first input (prompt) in inteactive mode, we should specially process it.
+                if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
                 _embed_inps = Context.Tokenize(text, true, true).ToList();
             }
             else
             {
-                if (!text.EndsWith("\n"))
-                {
-                    text += "\n";
-                }
                 _consumedTokensCount = _embed_inps.Count;
-                _embed_inps.AddRange(_inp_pfx);
-                var line_inp = Context.Tokenize(text, false, true);
-                _embed_inps.AddRange(line_inp);
+                // Don't append the template tokens if continuation is requested (by providing a null prompt)
+                if (text != null)
+                {
+                    if (!text.EndsWith("\n"))
+                    {
+                        text += "\n";
+                    }
+                    _embed_inps.AddRange(_inp_pfx);
+
+                    var line_inp = Context.Tokenize(text, false, true);
+                    _embed_inps.AddRange(line_inp);

-                _embed_inps.AddRange(_inp_sfx);
+                    _embed_inps.AddRange(_inp_sfx);

-                args.RemainedTokens -= line_inp.Length;
+                    args.RemainedTokens -= line_inp.Length;
+                }
             }

             return Task.CompletedTask;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 068cf129..f97a2b63 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -111,11 +111,12 @@ protected override Task GetLoopCondition(InferStateArgs args)
         }

         /// <inheritdoc />
-        protected override Task PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args)
         {
             if (_is_prompt_run)
             {
                 // When running the first input (prompt) in interactive mode, we should specially process it.
+                if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
                 if (!this.IsMultiModal)
                 {
                     _embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -127,20 +128,24 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
             }
             else
             {
-                if (!text.EndsWith("\n"))
+                // Don't add any tokens if continuation is requested (by providing a null prompt)
+                if (text != null)
                 {
-                    text += "\n";
-                }
+                    if (!text.EndsWith("\n"))
+                    {
+                        text += "\n";
+                    }

-                if (!this.IsMultiModal)
-                {
-                    var line_inp = Context.Tokenize(text, false, true);
-                    _embed_inps.AddRange(line_inp);
-                    args.RemainedTokens -= line_inp.Length;
-                }
-                else
-                {
-                    PreprocessLlava(text, args, false);
+                    if (!this.IsMultiModal)
+                    {
+                        var line_inp = Context.Tokenize(text, false, true);
+                        _embed_inps.AddRange(line_inp);
+                        args.RemainedTokens -= line_inp.Length;
+                    }
+                    else
+                    {
+                        PreprocessLlava(text, args, false);
+                    }
                 }
             }
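Taken together, these changes let a caller resume a generation that stopped early (for example, because MaxTokens ran out) by passing null as the prompt on a later call. Below is a minimal usage sketch, not part of the diff itself: the model path and parameter values are illustrative placeholders, and it assumes the standard LLamaSharp setup (ModelParams, LLamaWeights, InteractiveExecutor).

```csharp
using LLama;
using LLama.Common;

// Illustrative setup; "model.gguf" is a placeholder path.
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Deliberately small token budget so the first call may stop mid-answer.
var inferenceParams = new InferenceParams { MaxTokens = 32 };

// First call: supply the prompt as before.
await foreach (var token in executor.InferAsync("Write a short poem about the sea.", inferenceParams))
    Console.Write(token);

// Second call: a null prompt continues where the previous call left off.
// (Passing null before any prompt has been given throws ArgumentException.)
await foreach (var token in executor.InferAsync(null, inferenceParams))
    Console.Write(token);
```

Note the difference between the two executors on a null prompt: InstructExecutor skips the instruction prefix/suffix template tokens, while InteractExecutor appends no tokens at all, so in both cases the continuation picks up from the existing token state rather than starting a new turn.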