Strided perplexity #2714

Merged · 2 commits · Aug 23, 2023

Changes from all commits
12 changes: 12 additions & 0 deletions common/common.cpp
@@ -418,6 +418,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
            params.antiprompt.push_back(argv[i]);
        } else if (arg == "--perplexity") {
            params.perplexity = true;
        } else if (arg == "--ppl-stride") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.ppl_stride = std::stoi(argv[i]);
        } else if (arg == "--ppl-output-type") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.ppl_output_type = std::stoi(argv[i]);
Comment on lines +421 to +432
Collaborator:
These two parameters do not appear in the --help. I assume this is a simple oversight?

Contributor Author:

I did not add it to the help on purpose. As you can see from the table above, the current implementation is very inefficient. Hence, I decided that for now it is better to have this option kind of hidden from general usage and available only to those who pay attention to the commits. Later, when a better handling of RoPE becomes available (I discussed this with @ggerganov and this is on his radar), the implementation can be improved to be (almost) as efficient as the original llama.cpp implementation. My plan was to add the option to the help at that point.

Owner:

Yes, as we discussed - we need more efficient KV cache reuse.
Will track this in the existing issue on the roadmap: #2060

Collaborator:
At this point, now that the KV changes got merged, it should be possible to update the new perplexity calculation to perform about the same as the original. Right?

Owner:
Yes

        } else if (arg == "--hellaswag") {
            params.hellaswag = true;
        } else if (arg == "--hellaswag-tasks") {
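For reference, a strided run can be invoked the same way as the existing example in perplexity.cpp; the stride and output-type values here are purely illustrative:

    ./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw --ppl-stride 32 --ppl-output-type 1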
4 changes: 4 additions & 0 deletions common/common.h
@@ -64,6 +64,10 @@ struct gpt_params {
    std::string lora_adapter = ""; // lora adapter path
    std::string lora_base    = ""; // base model path for the lora adapter

    int ppl_stride      = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
    int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                             //      (which is more convenient to use for plotting)

    bool   hellaswag       = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
    size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score

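To illustrate the two output modes (the numbers below are made up; the formats follow the printf calls in the diff): with ppl_output_type = 0 the running perplexity is printed as a single comma-separated stream, e.g.

    [1]4.5970,[2]5.1807,[3]5.4915,

while with ppl_output_type = 1 each chunk emits a num_tokens/ppl pair on its own line, which is more convenient to feed to a plotting tool:

       0 4.5970
     512 5.1807
    1024 5.4915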
126 changes: 125 additions & 1 deletion examples/perplexity/perplexity.cpp
@@ -27,7 +27,121 @@ std::vector<float> softmax(const std::vector<float>& logits) {
    return probs;
}

void perplexity_v2(llama_context * ctx, const gpt_params & params) {
    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

    if (params.ppl_stride <= 0) {
        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n", __func__, params.ppl_stride);
        return;
    }

    auto tokens = ::llama_tokenize(ctx, params.prompt, true);

    const int calc_chunk = params.n_ctx;

    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);

    if (int(tokens.size()) <= calc_chunk) {
        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",
                __func__, tokens.size(), params.n_ctx, params.ppl_stride);
        return;
    }

    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_vocab = llama_n_vocab(ctx);
    const int n_batch = params.n_batch;

    int count = 0;
    double nll = 0.0;

    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

    for (int i = 0; i < n_chunk; ++i) {
        const int start = i * params.ppl_stride;
        const int end   = start + calc_chunk;

        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
        //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);

        std::vector<float> logits;

        const auto t_start = std::chrono::high_resolution_clock::now();

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n", j, batch_start, batch_size, j * n_batch);
            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                //fprintf(stderr, "%s : failed to eval\n", __func__);
                return;
            }

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (j == 0) {
                tokens[batch_start] = llama_token_bos(ctx);
            }

            const auto batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);

            if (j == 0) {
                tokens[batch_start] = token_org;
            }
        }

        const auto t_end = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                fprintf(stderr, "%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
        }

        //fprintf(stderr, "%s: using tokens %d...%d\n", __func__, params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
        for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
            // Calculate probability of next token, given the previous ones.
            const std::vector<float> tok_logits(
                logits.begin() + (j + 0) * n_vocab,
                logits.begin() + (j + 1) * n_vocab);

            const float prob = softmax(tok_logits)[tokens[start + j + 1]];

            nll += -std::log(prob);
            ++count;
        }

        // perplexity is e^(average negative log-likelihood)
        if (params.ppl_output_type == 0) {
            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
        } else {
            printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
        }
        fflush(stdout);
    }
    printf("\n");
}
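A note on the windowing above (an editor's sketch, not part of the diff): for chunk i the function evaluates the full n_ctx-token window starting at i*ppl_stride, but only the last ppl_stride positions contribute to the negative log-likelihood, so every scored token is conditioned on at least n_ctx - ppl_stride tokens of preceding context. In C++ terms:

    // Sketch of which tokens chunk i scores (n_ctx and stride as in the code above).
    const int start        = i * stride;              // first token fed to llama_eval
    const int first_scored = start + n_ctx - stride;  // predictions before this index are discarded
    // log-probabilities are accumulated for tokens [first_scored, start + n_ctx),
    // so consecutive chunks score consecutive, non-overlapping stride-sized spans.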

void perplexity(llama_context * ctx, const gpt_params & params) {
    if (params.ppl_stride > 0) {
        perplexity_v2(ctx, params);
        return;
    }

    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
@@ -116,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
            ++count;
        }
        // perplexity is e^(average negative log-likelihood)
-       printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+       if (params.ppl_output_type == 0) {
+           printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+       } else {
+           printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+       }
        fflush(stdout);
    }
    printf("\n");
@@ -369,6 +487,12 @@ int main(int argc, char ** argv) {
    params.perplexity = true;
    params.n_batch = std::min(params.n_batch, params.n_ctx);

    if (params.ppl_stride > 0) {
        fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
                params.n_ctx, params.n_ctx + params.ppl_stride/2);
        params.n_ctx += params.ppl_stride/2;
    }

    if (params.n_ctx > 2048) {
        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified); "
                "expect poor results\n", __func__, params.n_ctx);
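A worked example of the adjustment above, with illustrative numbers: running with -c 512 and --ppl-stride 32 bumps n_ctx to 512 + 32/2 = 528. Each chunk then evaluates a 528-token window and scores its last 32 positions, so every scored token sees at least 528 - 32 = 496 tokens of context, instead of the 480 it would get without the adjustment.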