duo: r0
okuvshynov committed May 31, 2024
1 parent c29f83d commit a96934d
Showing 1 changed file with 33 additions and 40 deletions.
examples/duo/duo.cpp (73 changes: 33 additions & 40 deletions)
@@ -37,7 +37,7 @@ enum Turn
     MAIN = 2
 };

-struct speculation_context
+struct shared_context
 {
     llama_tokens candidate;
     std::mutex mtx;
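The hunk above is truncated after the first two fields. Judging from the members used elsewhere in this diff (cv, turn, done, and the Turn enum at the top), the full definition presumably looks like the following sketch; llama_tokens is presumably the example's alias for std::vector<llama_token>:

    // reconstructed from usage in this diff; illustrative, not the verbatim source
    struct shared_context
    {
        llama_tokens            candidate; // token sequence handed between the two threads
        std::mutex              mtx;       // guards candidate, turn, and done
        std::condition_variable cv;        // wakes the thread whose turn it is
        Turn                    turn;      // whose move: SPEC, MAIN, or NONE
        bool                    done;      // set by target() to stop the draft loop
    };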
@@ -90,13 +90,12 @@ static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool

 // this becomes more similar to sequential versions
 static int speculation(
-    llama_model * model,
-    speculation_context * spec_ctx,
-    llama_context * ctx,
+    llama_model * model,
+    llama_context * ctx,
+    shared_context * sctx,
     const llama_tokens & input,
     size_t n_draft)
 {
-
     llama_batch batch = llama_batch_init(512, 0, 1);
     decode(ctx, input.begin(), input.end(), 0, false, batch);

@@ -107,17 +106,16 @@ static int speculation(
     while (true)
     {
         {
-            std::unique_lock<std::mutex> lock(spec_ctx->mtx);
-            spec_ctx->cv.wait(lock, [&spec_ctx] { return spec_ctx->turn == Turn::SPEC || spec_ctx->done; });
-            if (spec_ctx->done)
+            std::unique_lock<std::mutex> lock(sctx->mtx);
+            sctx->cv.wait(lock, [&sctx] { return sctx->turn == Turn::SPEC || sctx->done; });
+            if (sctx->done)
             {
                 break;
             }
-            shared = spec_ctx->candidate;
-            spec_ctx->turn = Turn::NONE;
+            shared = sctx->candidate;
+            sctx->turn = Turn::NONE;
         }

-        // here we merge shared and local and clean the cache if needed
         bool match = true;
         match_len = local.size() - 1;
         for (size_t i = 0; i < std::min(shared.size(), local.size()); i++)
@@ -145,10 +143,10 @@ static int speculation(
         }

         {
-            std::unique_lock<std::mutex> lock(spec_ctx->mtx);
-            spec_ctx->candidate = local;
-            spec_ctx->turn = Turn::MAIN;
-            spec_ctx->cv.notify_one();
+            std::unique_lock<std::mutex> lock(sctx->mtx);
+            sctx->candidate = local;
+            sctx->turn = Turn::MAIN;
+            sctx->cv.notify_one();
         }
     }

@@ -157,10 +155,10 @@ static int speculation(
 }

 static int target(
-    llama_model * model,
-    speculation_context * spec_ctx,
-    llama_context * ctx,
-    const llama_tokens& input,
+    llama_model * model,
+    llama_context * ctx,
+    shared_context * sctx,
+    const llama_tokens & input,
     size_t n_predict)
 {
     dbg_color(to_string(ctx, input.begin(), input.end()));
@@ -220,20 +218,15 @@ static int target(
         }

         {
-            std::unique_lock<std::mutex> lock(spec_ctx->mtx);
-            spec_ctx->cv.wait(lock, [&spec_ctx] { return spec_ctx->turn == Turn::MAIN; });
-            auto & spec = spec_ctx->candidate;
+            std::unique_lock<std::mutex> lock(sctx->mtx);
+            sctx->cv.wait(lock, [&sctx] { return sctx->turn == Turn::MAIN; });
+            auto & spec = sctx->candidate;
             size_t n_match = 0;
-            for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
+            while (n_match < next_tokens.size()
+                && n_match + next_tokens_pos < spec.size()
+                && next_tokens[n_match] == spec[n_match + next_tokens_pos])
             {
-                if (next_tokens[i] == spec[i + next_tokens_pos])
-                {
-                    n_match++;
-                }
-                else
-                {
-                    break;
-                }
+                n_match++;
             }

             dbg_color(to_string(ctx, spec.begin() + next_tokens_pos, spec.begin() + next_tokens_pos + n_match), /* green */ "\033[32m");
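The rewritten loop above computes the length of the common prefix between next_tokens and the speculated tokens starting at next_tokens_pos. A sketch of an equivalent formulation with std::mismatch (illustrative only, not part of this commit; it assumes next_tokens_pos <= spec.size()):

    #include <algorithm> // std::mismatch, std::min

    // clamp the comparison range, then find the first mismatching position
    size_t max_n = std::min(next_tokens.size(), spec.size() - next_tokens_pos);
    auto   mis   = std::mismatch(next_tokens.begin(), next_tokens.begin() + max_n,
                                 spec.begin() + next_tokens_pos);
    size_t n_match = mis.first - next_tokens.begin();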
@@ -248,8 +241,8 @@ static int target(
             }
         }
         input_seq.assign(spec.begin() + n_accepted - 1, spec.end());
-        spec_ctx->turn = Turn::SPEC;
-        spec_ctx->cv.notify_one();
+        sctx->turn = Turn::SPEC;
+        sctx->cv.notify_one();
     }

     if (n_decoded >= n_predict || done)
@@ -271,8 +264,8 @@ static int target(
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
     {
-        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
-        spec_ctx->done = true;
+        std::lock_guard<std::mutex> _lock(sctx->mtx);
+        sctx->done = true;
     }

     llama_batch_free(batch);
@@ -294,16 +287,16 @@ int main(int argc, char ** argv) {

     llama_backend_init();
     llama_numa_init(params.numa);
-    speculation_context spec_ctx;
+    shared_context sctx;

     // main model and context
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);

     llama_tokens input = llama_tokenize(ctx, params.prompt, true);
-    spec_ctx.candidate = input;
-    spec_ctx.turn = Turn::SPEC;
+    sctx.candidate = input;
+    sctx.turn = Turn::SPEC;

     // prepare draft model and contexts.
     llama_model * draft_model = nullptr;
@@ -319,9 +312,9 @@

     params.rpc_servers = params.rpc_servers_draft;
     std::tie(draft_model, draft_ctx) = llama_init_from_gpt_params(params);
-    std::thread spec_thread = std::thread(speculation, draft_model, &spec_ctx, draft_ctx, input, params.n_draft);
+    std::thread spec_thread = std::thread(speculation, draft_model, draft_ctx, &sctx, input, params.n_draft);

-    target(model, &spec_ctx, ctx, input, params.n_predict);
+    target(model, ctx, &sctx, input, params.n_predict);

     spec_thread.join();

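Taken together, the two threads alternate strictly: speculation() drafts only while turn == SPEC, target() verifies only while turn == MAIN, and done ends the draft loop. A minimal, self-contained sketch of that handoff under the same mutex/condition-variable protocol (illustrative only; a plain int stands in for the token buffer):

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    enum class Turn { NONE, SPEC, MAIN };

    struct shared_state {
        int  candidate = 0;          // stands in for llama_tokens
        bool done      = false;
        Turn turn      = Turn::SPEC; // draft thread moves first
        std::mutex mtx;
        std::condition_variable cv;
    };

    int main() {
        shared_state s;

        // draft thread: works only when it holds the SPEC turn
        std::thread spec([&] {
            while (true) {
                std::unique_lock<std::mutex> lock(s.mtx);
                s.cv.wait(lock, [&] { return s.turn == Turn::SPEC || s.done; });
                if (s.done) break;
                s.candidate += 1;        // "speculate"
                s.turn = Turn::MAIN;     // hand over to the target thread
                s.cv.notify_one();
            }
        });

        // target side: verify the candidate, then hand the turn back
        for (int step = 0; step < 4; step++) {
            std::unique_lock<std::mutex> lock(s.mtx);
            s.cv.wait(lock, [&] { return s.turn == Turn::MAIN; });
            std::printf("verified candidate = %d\n", s.candidate);
            s.turn = Turn::SPEC;
            s.cv.notify_one();
        }

        {
            std::lock_guard<std::mutex> lock(s.mtx);
            s.done = true;               // stop the draft loop, as target() does above
        }
        s.cv.notify_all();
        spec.join();
        return 0;
    }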