From de26d49fbe63a88d8c30de1f429945cb818f5884 Mon Sep 17 00:00:00 2001
From: Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com>
Date: Sat, 25 May 2024 22:19:23 -0400
Subject: [PATCH] duo: v5

---
 examples/duo/duo.cpp | 41 +++++++++++++++++++++--------------------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp
index ff65df739bbd6..46329cbc70c38 100644
--- a/examples/duo/duo.cpp
+++ b/examples/duo/duo.cpp
@@ -82,7 +82,8 @@ static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool
     }
     batch.logits[batch.n_tokens - 1] = true;
     int res = 0;
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode(ctx, batch) != 0)
+    {
         fprintf(stderr, "llama_decode() failed\n");
         res = 1;
     }
@@ -126,7 +127,8 @@ static int speculation(
         }
 
         auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
-        if (next_tokens.size() != 1) {
+        if (next_tokens.size() != 1)
+        {
             fprintf(stderr, "invalid next tokens\n");
             return 1;
         }
@@ -157,9 +159,7 @@ static int speculation(
             }
         }
 
-
         decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
-
         logit_idx = local.size() - match_len - 1;
     }
 
@@ -179,9 +179,8 @@ static int target(
     llama_batch batch = llama_batch_init(512, 0, 1);
     decode(ctx, input.begin(), input.end(), 0, false, batch);
 
-    // TODO: rename to n_accepted
-    size_t n_cur = input.size();
-    size_t n_decode = 0;
+    size_t n_accepted = input.size();
+    size_t n_decoded = 0;
 
     const auto t_main_start = ggml_time_us();
 
@@ -192,7 +191,7 @@
     llama_tokens input_seq, next_tokens;
    input_seq.push_back(input.back());
 
-    while (n_decode <= n_predict)
+    while (n_decoded < n_predict)
     {
         next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
         if (next_tokens.size() != input_seq.size())
@@ -201,16 +200,16 @@
             return 1;
         }
 
-        size_t next_tokens_pos = n_cur;
+        size_t next_tokens_pos = n_accepted;
         // we always accept at least one new token
-        n_cur += 1;
-        n_decode += 1;
+        n_accepted += 1;
+        n_decoded += 1;
         for (size_t i = 0; i + 1 < input_seq.size(); i++)
         {
             if (next_tokens[i] == input_seq[i + 1])
             {
-                n_cur += 1;
-                n_decode += 1;
+                n_accepted += 1;
+                n_decoded += 1;
             }
             else
             {
@@ -222,7 +221,7 @@
 
         // empty the non-matching portion of kv cache. 
         // n_cur is incremented at least once and will be > 0
-        llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
+        llama_kv_cache_seq_rm(ctx, 0, n_accepted - 1, -1);
 
         bool done = false;
         for (size_t i = 0; i < next_tokens.size(); i++)
@@ -263,14 +262,14 @@
                     spec.push_back(tok);
                 }
             }
-            input_seq.assign(spec.begin() + n_cur - 1, spec.end());
+            input_seq.assign(spec.begin() + n_accepted - 1, spec.end());
         }
 
-        if (n_decode >= n_predict || done)
+        if (n_decoded >= n_predict || done)
        {
             break;
         }
-        decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
+        decode(ctx, input_seq.begin(), input_seq.end(), n_accepted - 1, true, batch);
         logits_from = 0;
         logits_to = input_seq.size();
 
@@ -279,7 +278,7 @@
     const auto t_main_end = ggml_time_us();
 
     fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
-            n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+            n_decoded, (t_main_end - t_main_start) / 1000000.0f, n_decoded / ((t_main_end - t_main_start) / 1000000.0f));
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
 
@@ -295,11 +294,13 @@
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    if (gpt_params_parse(argc, argv, params) == false) {
+    if (gpt_params_parse(argc, argv, params) == false)
+    {
         return 1;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
+    if (params.seed == LLAMA_DEFAULT_SEED)
+    {
         params.seed = time(NULL);
     }