duo: v5
okuvshynov committed May 26, 2024
1 parent 7c8699a commit de26d49
Showing 1 changed file with 21 additions and 20 deletions.
examples/duo/duo.cpp: 21 additions, 20 deletions
@@ -82,7 +82,8 @@ static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool
}
batch.logits[batch.n_tokens - 1] = true;
int res = 0;
-if (llama_decode(ctx, batch) != 0) {
+if (llama_decode(ctx, batch) != 0)
+{
fprintf(stderr, "llama_decode() failed\n");
res = 1;
}
@@ -126,7 +127,8 @@ static int speculation(
}

auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
-if (next_tokens.size() != 1) {
+if (next_tokens.size() != 1)
+{
fprintf(stderr, "invalid next tokens\n");
return 1;
}
@@ -157,9 +159,7 @@
}
}

-
decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
-
logit_idx = local.size() - match_len - 1;
}

@@ -179,9 +179,8 @@ static int target(
llama_batch batch = llama_batch_init(512, 0, 1);
decode(ctx, input.begin(), input.end(), 0, false, batch);

-// TODO: rename to n_accepted
-size_t n_cur = input.size();
-size_t n_decode = 0;
+size_t n_accepted = input.size();
+size_t n_decoded = 0;

const auto t_main_start = ggml_time_us();

@@ -192,7 +191,7 @@
llama_tokens input_seq, next_tokens;
input_seq.push_back(input.back());

-while (n_decode <= n_predict)
+while (n_decoded < n_predict)
{
next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
if (next_tokens.size() != input_seq.size())
@@ -201,16 +200,16 @@
return 1;
}

-size_t next_tokens_pos = n_cur;
+size_t next_tokens_pos = n_accepted;
// we always accept at least one new token
-n_cur += 1;
-n_decode += 1;
+n_accepted += 1;
+n_decoded += 1;
for (size_t i = 0; i + 1 < input_seq.size(); i++)
{
if (next_tokens[i] == input_seq[i + 1])
{
-n_cur += 1;
-n_decode += 1;
+n_accepted += 1;
+n_decoded += 1;
}
else
{
@@ -222,7 +221,7 @@

// empty the non-matching portion of kv cache.
// n_cur is incremented at least once and will be > 0
-llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
+llama_kv_cache_seq_rm(ctx, 0, n_accepted - 1, -1);

bool done = false;
for (size_t i = 0; i < next_tokens.size(); i++)
@@ -263,14 +262,14 @@ static int target(
spec.push_back(tok);
}
}
-input_seq.assign(spec.begin() + n_cur - 1, spec.end());
+input_seq.assign(spec.begin() + n_accepted - 1, spec.end());
}
-if (n_decode >= n_predict || done)
+if (n_decoded >= n_predict || done)
{
break;
}

-decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
+decode(ctx, input_seq.begin(), input_seq.end(), n_accepted - 1, true, batch);

logits_from = 0;
logits_to = input_seq.size();
@@ -279,7 +278,7 @@
const auto t_main_end = ggml_time_us();

fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
-n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+n_decoded, (t_main_end - t_main_start) / 1000000.0f, n_decoded / ((t_main_end - t_main_start) / 1000000.0f));

llama_print_timings(ctx);
fprintf(stderr, "\n");
@@ -295,11 +294,13 @@ static int target(
int main(int argc, char ** argv) {
gpt_params params;

-if (gpt_params_parse(argc, argv, params) == false) {
+if (gpt_params_parse(argc, argv, params) == false)
+{
return 1;
}

-if (params.seed == LLAMA_DEFAULT_SEED) {
+if (params.seed == LLAMA_DEFAULT_SEED)
+{
params.seed = time(NULL);
}
