Fix todo: avoid relying on logits_all == true in perplexity_v2 #9102

Open · wants to merge 6 commits into base: master

Changes from 2 commits
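This change stops perplexity_v2 from relying on logits_all == true: llama_batch.logits is renamed to llama_batch.output across the tree, and perplexity_v2 now builds its own batch with that flag set for every token of a chunk. A minimal sketch of the resulting calling pattern, assuming the renamed field and the llama_batch_* helpers from common/common.h that appear in this diff (the perplexity_v2 hunk below fills the batch fields by hand instead):

```cpp
// Sketch only: decode a span of `batch_size` tokens starting at `batch_start`,
// requesting logits for every token via the per-token output flag rather than
// logits_all. `ctx`, `tokens`, `batch_start` and `batch_size` are assumed to be
// in scope, as they are in perplexity_v2.
llama_batch batch = llama_batch_init(batch_size, 0, 1);

llama_batch_clear(batch);
for (int k = 0; k < batch_size; ++k) {
    // the last argument is what ends up in batch.output[k]
    llama_batch_add(batch, tokens[batch_start + k], batch_start + k, { 0 }, true);
}

if (llama_decode(ctx, batch) != 0) {
    fprintf(stderr, "llama_decode() failed\n");
}

// logits of the last flagged token
const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

llama_batch_free(batch);
```

Only tokens with the output flag set produce rows in the context's logits buffer, which is what lets the callers below drop logits_all (see the n_outputs bookkeeping in decode_helper in this diff).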
2 changes: 1 addition & 1 deletion common/common.cpp
@@ -2666,7 +2666,7 @@ void llama_batch_add(
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.output [batch.n_tokens] = logits;

batch.n_tokens++;
}
2 changes: 1 addition & 1 deletion common/log.h
@@ -686,7 +686,7 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
<< ":logits " << std::to_string(batch.output[i]);
}
buf << " ]";

4 changes: 2 additions & 2 deletions examples/batched-bench/batched-bench.cpp
@@ -94,7 +94,7 @@ int main(int argc, char ** argv) {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
0, 0, 0, // unused
};

@@ -149,7 +149,7 @@ int main(int argc, char ** argv) {
llama_batch_add(batch, 0, i, { j }, false);
}
}
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;

const auto t_pp_start = ggml_time_us();

6 changes: 3 additions & 3 deletions examples/batched.swift/Sources/main.swift
@@ -86,11 +86,11 @@ for (i, token) in tokens.enumerated() {
if let seq_id = batch.seq_id[i] {
seq_id[0] = 0
}
batch.logits[i] = 0
batch.output[i] = 0
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[Int(batch.n_tokens) - 1] = 1
batch.output[Int(batch.n_tokens) - 1] = 1

if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
@@ -178,7 +178,7 @@ while n_cur <= n_len {
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i)
}
batch.logits[Int(batch.n_tokens)] = 1
batch.output[Int(batch.n_tokens)] = 1

i_batch[i] = batch.n_tokens

2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
@@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
@@ -52,7 +52,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
if (!batch.output[i]) {
continue;
}

12 changes: 6 additions & 6 deletions examples/gritlm/gritlm.cpp
@@ -102,21 +102,21 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
int32_t i_current_token = 0;

while (true) {
llama_batch_clear(bat);
llama_batch_clear(batch);
auto n_inputs = (int32_t)inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
llama_batch_add(batch, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
}
inputs.clear();

llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
llama_decode(ctx, batch);
auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
auto n_candidates = (int32_t)candidates.size();
@@ -145,7 +145,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
std::printf("\n");
}

llama_batch_free(bat);
llama_batch_free(batch);

return result;
}
2 changes: 1 addition & 1 deletion examples/imatrix/imatrix.cpp
@@ -513,7 +513,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}

// TODO: use batch.logits to save computations instead of relying on logits_all == true
// TODO: use batch.output to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
6 changes: 3 additions & 3 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
llama_batch_add(*batch, 0, i, { 0 }, false);
}

batch->logits[batch->n_tokens - 1] = true;
batch->output[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);

const auto t_pp_start = ggml_time_us();
@@ -306,7 +306,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
for (int i = 0; i < n_tokens; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

return reinterpret_cast<jlong>(batch);
}
@@ -363,7 +363,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
}

// llama_decode will output logits only for the last token of the prompt
batch->logits[batch->n_tokens - 1] = true;
batch->output[batch->n_tokens - 1] = true;

if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() failed");
6 changes: 3 additions & 3 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
for i in 0..<seq_ids.count {
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
}
batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
batch.output [Int(batch.n_tokens)] = logits ? 1 : 0

batch.n_tokens += 1
}
@@ -132,7 +132,7 @@ actor LlamaContext {
let i = Int(i1)
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
batch.output[Int(batch.n_tokens) - 1] = 1 // true

if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
@@ -214,7 +214,7 @@ actor LlamaContext {
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
batch.output[Int(batch.n_tokens) - 1] = 1 // true

llama_kv_cache_clear(context)

4 changes: 2 additions & 2 deletions examples/parallel/parallel.cpp
@@ -265,7 +265,7 @@ int main(int argc, char ** argv) {

// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
}

client.n_prompt = tokens_prompt.size();
@@ -308,7 +308,7 @@ int main(int argc, char ** argv) {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
0, 0, 0, // unused
};

4 changes: 2 additions & 2 deletions examples/passkey/passkey.cpp
@@ -140,7 +140,7 @@ int main(int argc, char ** argv) {
}

if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
}

if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
48 changes: 30 additions & 18 deletions examples/perplexity/perplexity.cpp
@@ -367,17 +367,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history};
}

const int calc_chunk = n_ctx;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), n_ctx);

fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);

if (int(tokens.size()) <= calc_chunk) {
if (int(tokens.size()) <= n_ctx) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
tokens.size(), n_ctx, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}

const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
const int n_chunk_max = (tokens.size() - n_ctx + params.ppl_stride - 1) / params.ppl_stride;

const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
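On the chunk count above: the window of n_ctx tokens is advanced by params.ppl_stride each step, and n_chunk_max is an upper bound on how many strided windows the run will evaluate. A worked example with illustrative numbers (not taken from the code):

```cpp
// Illustrative numbers only: tokens.size() = 5000, n_ctx = 2048, ppl_stride = 512
// n_chunk_max = (5000 - 2048 + 512 - 1) / 512
//             = 3463 / 512
//             = 6            (integer division, i.e. ceil((5000 - 2048) / 512.0))
// so chunks start at offsets 0, 512, 1024, ..., 2560 and each evaluates n_ctx tokens.
```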
@@ -386,13 +384,13 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
int count = 0;
double nll = 0.0;

const int num_batches = (n_ctx + n_batch - 1) / n_batch;

fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.ppl_stride;
const int end = start + calc_chunk;

const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
const int end = start + n_ctx;
//fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);

std::vector<float> logits;
@@ -406,13 +404,27 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);

llama_batch batch = llama_batch_init(batch_size, 0, 1);
for (int k = 0; k < batch_size; ++k) {
const int idx = batch_start + k;
batch.token [k] = tokens[idx];
batch.output [k] = 1;
}
batch.n_tokens = batch_size;
batch.pos = nullptr;
batch.n_seq_id = nullptr;
batch.seq_id = nullptr;
batch.all_pos_0 = j*n_batch;
batch.all_pos_1 = 1;
batch.all_seq_id = 0;

//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
// TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
if (llama_decode(ctx, batch)) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}

llama_batch_free(batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];

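A note on the nullptr fields in the batch built above (an assumption about llama_decode's behaviour, stated here for context; it is not part of this diff): when pos, n_seq_id and seq_id are left null, positions and sequence ids are derived from the all_* fields, the same way llama_batch_get_one() batches were handled before this change.

```cpp
// Assumed fallback inside llama_decode for a batch with pos/seq_id == nullptr
// (context only): token k of the batch above is decoded as if
//     pos[k]    == batch.all_pos_0 + k*batch.all_pos_1   // here: j*n_batch + k
//     seq_id[k] == batch.all_seq_id                      // here: sequence 0
// while batch.output[k] == 1 keeps logits for every token of the chunk.
```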
@@ -601,9 +613,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
batch.pos [idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id [idx][0] = seq;
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
batch.output [idx] = batch.pos[idx] >= first ? 1 : 0;

n_outputs += batch.logits[idx] != 0;
n_outputs += batch.output[idx] != 0;
}
batch.n_tokens += batch_size;

@@ -697,7 +709,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
0, 0, 0, // unused
};

@@ -709,7 +721,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<

int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
n_outputs += batch_view.output[i] != 0;
}

memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
@@ -917,7 +929,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
n_logits += 1;

for (int s = 0; s < 4; ++s) {
@@ -1196,7 +1208,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
n_logits += 1;

for (int s = 0; s < 2; ++s) {
@@ -1565,7 +1577,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
n_logits += 1;

for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
@@ -1794,7 +1806,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}

// TODO: use llama_batch.logits instead of relying on logits_all == true
// TODO: use llama_batch.output instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
@@ -91,7 +91,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
if (!batch.output[i]) {
continue;
}

6 changes: 3 additions & 3 deletions examples/server/server.cpp
@@ -1480,7 +1480,7 @@ struct server_context {
std::vector<float> embd_res(n_embd, 0.0f);

for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) {
continue;
}

@@ -2269,7 +2269,7 @@ struct server_context {
GGML_ASSERT(batch.n_tokens > 0);

// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;

slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
@@ -2341,7 +2341,7 @@ struct server_context {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
0, 0, 0, // unused
};

2 changes: 1 addition & 1 deletion examples/simple/simple.cpp
@@ -93,7 +93,7 @@ int main(int argc, char ** argv) {
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);