Fix todo: avoid relying on logits_all == true in perplexity_v2 #9102

Open

Septa2112 wants to merge 6 commits into master (base)
Conversation

@Septa2112 (Contributor) commented Aug 20, 2024


Changes

From #9037 (comment) and some related comments in perplexity.cpp, it looks like logits_all is being deprecated.

So I made the following changes:

  1. Avoid relying on logits_all == true in the perplexity_v2 function by setting llama_batch.logits explicitly (see the sketch right after this list).
  2. Completed a TODO along the way: renamed llama_batch.logits to llama_batch.output.
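
For reference, here is a minimal sketch of the idea in point 1, not the exact diff in this PR. It assumes batch_size, tokens, batch_start, j, and n_batch from the surrounding perplexity.cpp code, and the output-request field is written as batch.logits, i.e. before the rename in point 2:

// allocate a batch and request the output for every token explicitly,
// instead of setting logits_all on the context
llama_batch batch = llama_batch_init(batch_size, /*embd*/ 0, /*n_seq_max*/ 1);

batch.n_tokens = batch_size;
for (int k = 0; k < batch_size; ++k) {
    batch.token   [k]    = tokens[batch_start + k];
    batch.pos     [k]    = j*n_batch + k;
    batch.n_seq_id[k]    = 1;
    batch.seq_id  [k][0] = 0;
    batch.logits  [k]    = true; // ask for this token's logits
}

if (llama_decode(ctx, batch)) {
    fprintf(stderr, "%s : failed to eval\n", __func__);
    llama_batch_free(batch);
    return;
}

// ... read results via llama_get_logits(ctx) / llama_get_logits_ith(ctx, k) ...

llama_batch_free(batch);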

Test Platform

Linux 6.5.0-41-generic #41~22.04.2-Ubuntu SMP PREEMPT_DYNAMIC Mon Jun 3 11:32:55 UTC 2 x86_64 x86_64 x86_64 GNU/Linux

  • CPU: i7-11700k

Test Command

./build/bin/llama-perplexity -m ../gguf_models/llama-2-7b.Q2_K.gguf -f ../prompt/article.txt -t 16 --ppl-stride 220

Test Results

Before

perplexity_v2: tokenizing the input ..
perplexity_v2: have 4394 tokens. Calculation chunk = 2176
perplexity_v2: calculating perplexity over 11 chunks, batch_size=2048
perplexity_v2: 68.55 seconds per pass - ETA 12.57 minutes
[1]1.0052,[2]1.0334,[3]1.0239,[4]1.0191,[5]1.0160,[6]1.3106,[7]1.5908,[8]1.5034,[9]1.4402,[10]1.3894,[11]1.3493,
 
llama_print_timings:        load time =     469.82 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =  807898.38 ms / 23936 tokens (   33.75 ms per token,    29.63 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =  810532.89 ms / 23937 tokens

After

perplexity_v2: tokenizing the input ..
perplexity_v2: have 4394 tokens. Calculation chunk = 2176
perplexity_v2: calculating perplexity over 11 chunks, batch_size=2048
perplexity_v2: 68.83 seconds per pass - ETA 12.62 minutes
[1]1.0052,[2]1.0334,[3]1.0239,[4]1.0191,[5]1.0160,[6]1.3106,[7]1.5908,[8]1.5034,[9]1.4402,[10]1.3894,[11]1.3493,
 
llama_print_timings:        load time =     465.04 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =  795179.21 ms / 23936 tokens (   33.22 ms per token,    30.10 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =  797794.73 ms / 23937 tokens

Others

If these changes are acceptable to the community, I'd like to modify other places that depend on logits_all == true, such as

// TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
    fprintf(stderr, "%s : failed to eval\n", __func__);
    return;
}

// TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
    fprintf(stderr, "%s : failed to eval\n", __func__);
    return false;
}

github-actions bot added the android (Issues specific to Android), examples, and server labels on Aug 20, 2024
@compilade (Collaborator) commented:

@Septa2112

If these changes are acceptable to the community, I'd like to modify other places that depend on logits_all == true, such as

// TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
    fprintf(stderr, "%s : failed to eval\n", __func__);
    return false;
}

Hi, I wrote that comment in #6122, but after studying imatrix.cpp more, I think that for imatrix specifically, always using all logits is fine: otherwise the hidden state going through the FFN of the last layer would get trimmed, which would not be good for the importance matrix calculation.

Though it's true that imatrix could avoid using logits_all while still specifying that all logits should be output, that is already what happens internally when logits_all is used (still, removing it could be worthwhile).

For the strided perplexity calculation (aka perplexity_v2), which is modified here, the main reason it was not updated to use an allocated llama_batch is that I was not sure which outputs should be kept. For the perplexity() function, it's simply the second half of the logits; for this one, I'm not sure.
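
As a rough illustration of the perplexity() case above (batch and n_ctx are placeholders here, not the actual variables in perplexity.cpp), requesting only the outputs that get scored would look something like:

// keep only the outputs for the second half of the chunk, since only those are scored
for (int k = 0; k < n_ctx; ++k) {
    batch.logits[k] = k >= n_ctx/2;
}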

I see in b0c6ad7 that you've simply made it ask for all outputs, same as using logits_all. This is a good start because using an allocated llama_batch is necessary for more control over which outputs to keep.

If you do figure out which outputs to keep, note that the logits in the buffer returned by llama_get_logits are all contiguous (no holes). This means that if some outputs are excluded, the offsets in that buffer need to be corrected. llama_get_logits_ith abstracts this away and uses the original indices from the llama_batch, but it can only be expected to return a pointer to the logits of a single token at a time (although technically it's all the same buffer, so it can also be used simply to get the offset of a contiguous range).
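
To make that indexing concrete, a rough sketch (n_outputs_before and i_batch are placeholders: the number of earlier tokens whose output was requested, and the original index of a token in the llama_batch):

const int n_vocab = llama_n_vocab(llama_get_model(ctx));

// packed ordering: the buffer from llama_get_logits() has no holes, so each
// requested output starts right after the previously requested ones
const float * by_output_order = llama_get_logits(ctx) + n_outputs_before * n_vocab;

// original ordering: llama_get_logits_ith() maps the batch index for you
const float * by_batch_index  = llama_get_logits_ith(ctx, i_batch);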

Hopefully this helps!

@Septa2112 (Contributor, Author) commented Aug 22, 2024

@compilade Thanks a lot for your suggestion! I will carefully consider your advice and propose changes related to imatrix.cpp in a separate PR.

In the current PR, I only plan to modify the perplexity_v2 part. For the remaining two places, I still need time to figure out how to modify them. :)
