Implementing strided computation of perplexity
Kawrakow committed Aug 22, 2023
1 parent 8e4364f commit b791d1f
Showing 3 changed files with 124 additions and 0 deletions.
6 changes: 6 additions & 0 deletions common/common.cpp
@@ -418,6 +418,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
            params.antiprompt.push_back(argv[i]);
        } else if (arg == "--perplexity") {
            params.perplexity = true;
        } else if (arg == "--ppl-stride") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.ppl_stride = std::stoi(argv[i]);
        } else if (arg == "--hellaswag") {
            params.hellaswag = true;
        } else if (arg == "--hellaswag-tasks") {
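With this flag wired into the argument parser, a strided run could be launched along the lines of `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw --ppl-stride 32` (the stride value of 32 is purely illustrative; any positive value selects the new code path, while the default of 0 keeps the pre-existing behaviour).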
2 changes: 2 additions & 0 deletions common/common.h
@@ -64,6 +64,8 @@ struct gpt_params {
    std::string lora_adapter = ""; // lora adapter path
    std::string lora_base = "";    // base model path for the lora adapter

    int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
    //
    bool hellaswag = false;       // compute HellaSwag score over random tasks from datafile supplied in prompt
    size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score

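For reference, both code paths report the same quantity: perplexity = exp( (1/N) * sum_i -log p(token_i | preceding tokens) ), where N is the number of scored positions (`count` in the code below). A non-zero `ppl_stride` changes which positions are scored per pass and how much preceding context each of them sees, not the formula itself.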
116 changes: 116 additions & 0 deletions examples/perplexity/perplexity.cpp
@@ -27,7 +27,117 @@ std::vector<float> softmax(const std::vector<float>& logits) {
    return probs;
}

void perplexity_v2(llama_context * ctx, const gpt_params & params) {

    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

    if (params.ppl_stride <= 0) {
        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
        return;
    }
    auto tokens = ::llama_tokenize(ctx, params.prompt, true);

    const int calc_chunk = params.n_ctx;

    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);

    if (int(tokens.size()) <= calc_chunk) {
        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                tokens.size(), params.n_ctx, params.ppl_stride);
        return;
    }

    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_vocab = llama_n_vocab(ctx);
    const int n_batch = params.n_batch;

    int count = 0;
    double nll = 0.0;

    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

    for (int i = 0; i < n_chunk; ++i) {
        const int start = i * params.ppl_stride;
        const int end = start + calc_chunk;

        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
        //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);

        std::vector<float> logits;

        const auto t_start = std::chrono::high_resolution_clock::now();

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size = std::min(end - batch_start, n_batch);

            //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                //fprintf(stderr, "%s : failed to eval\n", __func__);
                return;
            }

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (j == 0) {
                tokens[batch_start] = llama_token_bos(ctx);
            }

            const auto batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);

            if (j == 0) {
                tokens[batch_start] = token_org;
            }
        }

        const auto t_end = std::chrono::high_resolution_clock::now();

        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                fprintf(stderr, "%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
        }

        //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
        for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {

            // Calculate probability of next token, given the previous ones.
            const std::vector<float> tok_logits(
                logits.begin() + (j + 0) * n_vocab,
                logits.begin() + (j + 1) * n_vocab);

            const float prob = softmax(tok_logits)[tokens[start + j + 1]];

            nll += -std::log(prob);
            ++count;
        }
        // perplexity is e^(average negative log-likelihood)
        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
        fflush(stdout);
    }
    printf("\n");
}

void perplexity(llama_context * ctx, const gpt_params & params) {

    if (params.ppl_stride > 0) {
        perplexity_v2(ctx, params);
        return;
    }

    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
@@ -369,6 +479,12 @@ int main(int argc, char ** argv) {
    params.perplexity = true;
    params.n_batch = std::min(params.n_batch, params.n_ctx);

    if (params.ppl_stride > 0) {
        fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
                params.n_ctx, params.n_ctx + params.ppl_stride/2);
        params.n_ctx += params.ppl_stride/2;
    }

    if (params.n_ctx > 2048) {
        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
                "expect poor results\n", __func__, params.n_ctx);
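To make the windowing in perplexity_v2 concrete: each pass feeds a full window of calc_chunk (= params.n_ctx) tokens to the model, only the last ppl_stride predictions of that window are accumulated into the running negative log-likelihood, and the window then advances by ppl_stride tokens. The following standalone sketch (not part of the commit; token count, context size and stride are illustrative) prints which positions each pass evaluates and which it scores:

// Standalone sketch of the strided window layout used by perplexity_v2 above.
// All sizes are illustrative; this only mirrors the index arithmetic.
#include <cstdio>

int main() {
    const int n_tokens   = 1200; // length of the tokenized text (assumed)
    const int n_ctx      = 512;  // calc_chunk: tokens evaluated per pass
    const int ppl_stride = 32;   // how far the window advances per pass

    // same chunk-count formula as in perplexity_v2
    const int n_chunk = (n_tokens - n_ctx + ppl_stride - 1) / ppl_stride;

    for (int i = 0; i < n_chunk; ++i) {
        const int start = i * ppl_stride; // first token fed to the model this pass
        const int end   = start + n_ctx;  // one past the last token fed this pass
        // only the last ppl_stride predictions are scored, i.e. the targets are
        // the tokens at positions [end - ppl_stride, end)
        printf("pass %3d: eval tokens [%4d, %4d), scored targets [%4d, %4d)\n",
               i, start, end, end - ppl_stride, end);
    }
    return 0;
}

One more detail from main(): the context is enlarged by ppl_stride/2 before evaluation, so with a default context of 512 and --ppl-stride 32 (again illustrative numbers) perplexity_v2 would actually work with calc_chunk = 512 + 32/2 = 528 tokens per window, while the window still advances by the full 32-token stride each pass.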
