From cbb5dd7b12faa6871fdfe6c3fe4e6f2acbbb51b4 Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Fri, 16 Aug 2024 16:42:44 +0800 Subject: [PATCH 1/5] change batch.logits to batch.output --- common/common.cpp | 2 +- common/log.h | 2 +- examples/batched-bench/batched-bench.cpp | 4 +-- examples/batched.swift/Sources/main.swift | 6 ++-- examples/batched/batched.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/gritlm/gritlm.cpp | 12 ++++---- examples/imatrix/imatrix.cpp | 2 +- .../llama/src/main/cpp/llama-android.cpp | 6 ++-- .../llama.cpp.swift/LibLlama.swift | 6 ++-- examples/parallel/parallel.cpp | 4 +-- examples/passkey/passkey.cpp | 4 +-- examples/perplexity/perplexity.cpp | 18 ++++++------ examples/retrieval/retrieval.cpp | 2 +- examples/server/server.cpp | 6 ++-- examples/simple/simple.cpp | 2 +- include/llama.h | 12 ++++---- src/llama.cpp | 28 +++++++++---------- 18 files changed, 60 insertions(+), 60 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 382d585a5e6f9..1a0a3aa6b6e5d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2666,7 +2666,7 @@ void llama_batch_add( for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } - batch.logits [batch.n_tokens] = logits; + batch.output [batch.n_tokens] = logits; batch.n_tokens++; } diff --git a/common/log.h b/common/log.h index 1bc5328ce3e11..2a7dadd7d1c3f 100644 --- a/common/log.h +++ b/common/log.h @@ -686,7 +686,7 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch) << ":pos " << std::to_string(batch.pos[i]) << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) << ":seq_id " << std::to_string(batch.seq_id[i][0]) - << ":logits " << std::to_string(batch.logits[i]); + << ":logits " << std::to_string(batch.output[i]); } buf << " ]"; diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25e7c775a0095..94589b7fa4518 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -94,7 +94,7 @@ int main(int argc, char ** argv) { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, 0, 0, 0, // unused }; @@ -149,7 +149,7 @@ int main(int argc, char ** argv) { llama_batch_add(batch, 0, i, { j }, false); } } - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; const auto t_pp_start = ggml_time_us(); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 616494d2d841d..273c24350a8a5 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -86,11 +86,11 @@ for (i, token) in tokens.enumerated() { if let seq_id = batch.seq_id[i] { seq_id[0] = 0 } - batch.logits[i] = 0 + batch.output[i] = 0 } // llama_decode will output logits only for the last token of the prompt -batch.logits[Int(batch.n_tokens) - 1] = 1 +batch.output[Int(batch.n_tokens) - 1] = 1 if llama_decode(context, batch) != 0 { print("llama_decode() failed") @@ -178,7 +178,7 @@ while n_cur <= n_len { if let seq_id = batch.seq_id[Int(batch.n_tokens)] { seq_id[0] = Int32(i) } - batch.logits[Int(batch.n_tokens)] = 1 + batch.output[Int(batch.n_tokens)] = 1 i_batch[i] = batch.n_tokens diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 53fbfb0a8cf2a..265433fedcfa5 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -122,7 +122,7 @@ int main(int argc, char ** argv) { } // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; if (llama_decode(ctx, batch) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index b05aa006e7da5..f0cbf8b168cd9 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -52,7 +52,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + if (!batch.output[i]) { continue; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2c61c2e1eb3bc..69bb96d7f51c0 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -102,21 +102,21 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); - llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); + llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector inputs = llama_tokenize(mdl, prompt, false, true); int32_t i_current_token = 0; while (true) { - llama_batch_clear(bat); + llama_batch_clear(batch); auto n_inputs = (int32_t)inputs.size(); for (int32_t i = 0; i < n_inputs; i++) { - llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + llama_batch_add(batch, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); } inputs.clear(); - llama_decode(ctx, bat); - auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); + llama_decode(ctx, batch); + auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); auto candidates = std::vector(llama_n_vocab(mdl)); auto n_candidates = (int32_t)candidates.size(); @@ -145,7 +145,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo std::printf("\n"); } - llama_batch_free(bat); + llama_batch_free(batch); return result; } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 83b85d72b043a..f988e141982ca 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -513,7 +513,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } - // TODO: use batch.logits to save computations instead of relying on logits_all == true + // TODO: use batch.output to save computations instead of relying on logits_all == true if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2aafe23167557..816a37e705800 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( llama_batch_add(*batch, 0, i, { 0 }, false); } - batch->logits[batch->n_tokens - 1] = true; + batch->output[batch->n_tokens - 1] = true; llama_kv_cache_clear(context); const auto t_pp_start = ggml_time_us(); @@ -306,7 +306,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, for (int i = 0; i < n_tokens; ++i) { batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens); return reinterpret_cast(batch); } @@ -363,7 +363,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( } // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; + batch->output[batch->n_tokens - 1] = true; if (llama_decode(context, *batch) != 0) { LOGe("llama_decode() failed"); diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 58c32ca533bb1..e3eabdd2412ff 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama for i in 0.. 0) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } client.n_prompt = tokens_prompt.size(); @@ -308,7 +308,7 @@ int main(int argc, char ** argv) { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, 0, 0, 0, // unused }; diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1e0a9..7fe7001eba454 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -140,7 +140,7 @@ int main(int argc, char ** argv) { } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } if (llama_decode(ctx, batch) != 0) { @@ -174,7 +174,7 @@ int main(int argc, char ** argv) { } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } if (llama_decode(ctx, batch) != 0) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 484dd589109c7..c7a98ccfd6a21 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -407,7 +407,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - // TODO: use llama_batch.logits instead of relying on logits_all == true + // TODO: use llama_batch.output instead of relying on logits_all == true if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; @@ -601,9 +601,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par batch.pos [idx] = j*n_batch + k; batch.n_seq_id[idx] = 1; batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + batch.output [idx] = batch.pos[idx] >= first ? 1 : 0; - n_outputs += batch.logits[idx] != 0; + n_outputs += batch.output[idx] != 0; } batch.n_tokens += batch_size; @@ -697,7 +697,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, 0, 0, 0, // unused }; @@ -709,7 +709,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< int n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - n_outputs += batch_view.logits[i] != 0; + n_outputs += batch_view.output[i] != 0; } memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float)); @@ -917,7 +917,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) { llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < 4; ++s) { @@ -1196,7 +1196,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { for (size_t i = 0; i < data[i1].common_prefix; ++i) { llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; n_logits += 1; for (int s = 0; s < 2; ++s) { @@ -1565,7 +1565,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { @@ -1794,7 +1794,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); } - // TODO: use llama_batch.logits instead of relying on logits_all == true + // TODO: use llama_batch.output instead of relying on logits_all == true if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index aab9d81058af9..09ec539316354 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -91,7 +91,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + if (!batch.output[i]) { continue; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e073f5813d459..322bbfa759bb4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1480,7 +1480,7 @@ struct server_context { std::vector embd_res(n_embd, 0.0f); for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) { continue; } @@ -2269,7 +2269,7 @@ struct server_context { GGML_ASSERT(batch.n_tokens > 0); // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; @@ -2341,7 +2341,7 @@ struct server_context { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, 0, 0, 0, // unused }; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 69a92cf7dc0c0..e8a3253a25d6e 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -93,7 +93,7 @@ int main(int argc, char ** argv) { } // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; if (llama_decode(ctx, batch) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); diff --git a/include/llama.h b/include/llama.h index 188ae76f8001e..532ab42b52137 100644 --- a/include/llama.h +++ b/include/llama.h @@ -220,7 +220,7 @@ extern "C" { // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence // - seq_id : the sequence to which the respective token belongs - // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // - output : if zero, the logits (and/or the embeddings) for the respective token will not be output // typedef struct llama_batch { int32_t n_tokens; @@ -230,7 +230,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int8_t * output; // Previously named 'logits', renamed to 'output' now. // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below @@ -328,7 +328,7 @@ extern "C" { enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] // Keep the booleans together to avoid misalignment during copy-by-value. - bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) + bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.output instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] @@ -859,9 +859,9 @@ extern "C" { LLAMA_API void llama_synchronize(struct llama_context * ctx); // Token logits obtained from the last call to llama_decode() - // The logits for which llama_batch.logits[i] != 0 are stored contiguously + // The logits for which llama_batch.output[i] != 0 are stored contiguously // in the order they have appeared in the batch. - // Rows: number of tokens for which llama_batch.logits[i] != 0 + // Rows: number of tokens for which llama_batch.output[i] != 0 // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); @@ -873,7 +873,7 @@ extern "C" { // Get all output token embeddings. // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, - // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously + // the embeddings for which llama_batch.output[i] != 0 are stored contiguously // in the order they have appeared in the batch. // shape: [n_outputs*n_embd] // Otherwise, returns NULL. diff --git a/src/llama.cpp b/src/llama.cpp index 7e9149eb98302..3c36d9495f96b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14500,10 +14500,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } - } else if (batch.logits) { + } else if (batch.output) { int32_t n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - if (batch.logits[i]) { + if (batch.output[i]) { data[n_outputs++] = i; } } @@ -14972,13 +14972,13 @@ static int llama_decode_internal( std::vector seq_id_arr; std::vector> seq_id; - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + // this indicates we are doing pooled embedding, so we ignore batch.output and output all tokens const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; // count outputs - if (batch_all.logits && !embd_pooled) { + if (batch_all.output && !embd_pooled) { for (uint32_t i = 0; i < n_tokens_all; ++i) { - n_outputs += batch_all.logits[i] != 0; + n_outputs += batch_all.output[i] != 0; } } else if (lctx.logits_all || embd_pooled) { n_outputs = n_tokens_all; @@ -14994,10 +14994,10 @@ static int llama_decode_internal( }; // set output mappings - if (batch_all.logits) { + if (batch_all.output) { int32_t i_logits = 0; for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch_all.logits[i]) { + if (batch_all.output[i]) { lctx.output_ids[i] = i_logits++; } } @@ -15016,7 +15016,7 @@ static int llama_decode_internal( /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, - /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, + /* .logits = */ batch_all.output ? batch_all.output + cur_token : nullptr, /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, /* .all_pos_1 = */ batch_all.all_pos_1, /* .all_seq_id = */ batch_all.all_seq_id, @@ -15026,9 +15026,9 @@ static int llama_decode_internal( { int32_t n_outputs_new = 0; - if (u_batch.logits && !embd_pooled) { + if (u_batch.output && !embd_pooled) { for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.logits[i] != 0; + n_outputs_new += u_batch.output[i] != 0; } } else if (n_outputs == n_tokens_all) { n_outputs_new = n_tokens; @@ -18881,7 +18881,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ } batch.seq_id[n_tokens_alloc] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); return batch; } @@ -18897,7 +18897,7 @@ void llama_batch_free(struct llama_batch batch) { } free(batch.seq_id); } - if (batch.logits) free(batch.logits); + if (batch.output) free(batch.output); } int32_t llama_encode( @@ -18975,7 +18975,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { } if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); + throw std::runtime_error(format("batch.output[%d] != true", i)); } if (j >= ctx->n_outputs) { // This should not happen @@ -19020,7 +19020,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { } if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); + throw std::runtime_error(format("batch.output[%d] != true", i)); } if (j >= ctx->n_outputs) { // This should not happen From b0c6ad778d62fcf7dfda6ad88b1c24c1bb6b519f Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Tue, 20 Aug 2024 17:16:27 +0800 Subject: [PATCH 2/5] avoid relying on 'logits_all == true' in perplexity_v2 --- examples/perplexity/perplexity.cpp | 32 ++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index c7a98ccfd6a21..517f93e073bcf 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -367,17 +367,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {tokens, -1, logit_history, prob_history}; } - const int calc_chunk = n_ctx; + fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), n_ctx); - fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk); - - if (int(tokens.size()) <= calc_chunk) { + if (int(tokens.size()) <= n_ctx) { fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__, tokens.size(), n_ctx, params.ppl_stride); return {tokens, -1, logit_history, prob_history}; } - const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride; + const int n_chunk_max = (tokens.size() - n_ctx + params.ppl_stride - 1) / params.ppl_stride; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); @@ -386,13 +384,13 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & int count = 0; double nll = 0.0; + const int num_batches = (n_ctx + n_batch - 1) / n_batch; + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); for (int i = 0; i < n_chunk; ++i) { const int start = i * params.ppl_stride; - const int end = start + calc_chunk; - - const int num_batches = (calc_chunk + n_batch - 1) / n_batch; + const int end = start + n_ctx; //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches); std::vector logits; @@ -406,13 +404,27 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); + llama_batch batch = llama_batch_init(batch_size, 0, 1); + for (int k = 0; k < batch_size; ++k) { + const int idx = batch_start + k; + batch.token [k] = tokens[idx]; + batch.output [k] = 1; + } + batch.n_tokens = batch_size; + batch.pos = nullptr; + batch.n_seq_id = nullptr; + batch.seq_id = nullptr; + batch.all_pos_0 = j*n_batch; + batch.all_pos_1 = 1; + batch.all_seq_id = 0; + //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - // TODO: use llama_batch.output instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + if (llama_decode(ctx, batch)) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } + llama_batch_free(batch); // save original token and restore it after eval const auto token_org = tokens[batch_start]; From 27ecb076cf2fd1809f523b7f19699ea97b7ec53a Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Thu, 22 Aug 2024 11:28:16 +0800 Subject: [PATCH 3/5] simplify the code --- examples/perplexity/perplexity.cpp | 11 +---------- src/llama.cpp | 6 +++--- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 517f93e073bcf..9fec608785929 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -399,7 +399,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // clear the KV cache llama_kv_cache_clear(ctx); - for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -407,16 +406,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & llama_batch batch = llama_batch_init(batch_size, 0, 1); for (int k = 0; k < batch_size; ++k) { const int idx = batch_start + k; - batch.token [k] = tokens[idx]; - batch.output [k] = 1; + llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true); } - batch.n_tokens = batch_size; - batch.pos = nullptr; - batch.n_seq_id = nullptr; - batch.seq_id = nullptr; - batch.all_pos_0 = j*n_batch; - batch.all_pos_1 = 1; - batch.all_seq_id = 0; //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); if (llama_decode(ctx, batch)) { diff --git a/src/llama.cpp b/src/llama.cpp index 7e798a07ea47f..11fbd9313ef46 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2874,17 +2874,17 @@ struct llama_sbatch { ubatch.output[ubatch.n_tokens + i] = 1; out_ids.push_back(ids[seq.offset + i]); } - } else if (batch->logits) { + } else if (batch->output) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; + int8_t is_output = batch->output[id]; ubatch.output[ubatch.n_tokens + i] = is_output; if (is_output) { out_ids.push_back(id); } } } else { // simple split - ubatch.output = batch->logits + seq.offset; + ubatch.output = batch->output + seq.offset; for (size_t i = 0; i < length; ++i) { if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } } From 0451b1f9ef8a2f95f0360fe881054d249a3f5a3b Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Thu, 22 Aug 2024 14:33:01 +0800 Subject: [PATCH 4/5] re-use same llama_batch --- examples/perplexity/perplexity.cpp | 8 ++++++-- include/llama.h | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9fec608785929..ae21bfbaf8db4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -385,6 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & double nll = 0.0; const int num_batches = (n_ctx + n_batch - 1) / n_batch; + llama_batch batch = llama_batch_init(n_batch, 0, 1); fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); @@ -403,7 +404,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - llama_batch batch = llama_batch_init(batch_size, 0, 1); + llama_batch_clear(batch); + for (int k = 0; k < batch_size; ++k) { const int idx = batch_start + k; llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true); @@ -415,7 +417,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {tokens, -1, logit_history, prob_history}; } - llama_batch_free(batch); // save original token and restore it after eval const auto token_org = tokens[batch_start]; @@ -468,6 +469,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & } fflush(stdout); } + + llama_batch_free(batch); + printf("\n"); return {tokens, std::exp(nll / count), logit_history, prob_history}; diff --git a/include/llama.h b/include/llama.h index d1b2238c334d2..e942c9d3f9353 100644 --- a/include/llama.h +++ b/include/llama.h @@ -230,7 +230,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * output; // Previously named 'logits', renamed to 'output' now. + int8_t * output; // Previously named 'logits', renamed to 'output' now. // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below From 395ae48cb01831508c689bcc5f678ac31b7534b4 Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Thu, 22 Aug 2024 14:39:43 +0800 Subject: [PATCH 5/5] Format modification --- examples/perplexity/perplexity.cpp | 1 + include/llama.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index ae21bfbaf8db4..e1ea57224cfe0 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -400,6 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // clear the KV cache llama_kv_cache_clear(ctx); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); diff --git a/include/llama.h b/include/llama.h index e942c9d3f9353..ed5f9bd420511 100644 --- a/include/llama.h +++ b/include/llama.h @@ -230,7 +230,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * output; // Previously named 'logits', renamed to 'output' now. + int8_t * output; // Previously named "logits", renamed to "output" now. // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below