GPT model `generate()` function not correctly skipping the padding tokens indicated by `attention_mask` #14521
Maybe of interest to @patrickvonplaten @Narsil
Update: I changed my experiment code from right padding to left padding and the performance is greatly improved.
I just checked, and the `attention_mask` is correctly sent back to the model. Looking at the code, you can check that the `attention_mask` adds a very large negative number: https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt_neo/modeling_gpt_neo.py#L200

I am not familiar enough with the internals to know if that's enough, but it definitely seems to be doing what it should. I even tried a much smaller example:

input_str_1 = "This is a test of<|endoftext|>"
input_str_2 = "This is a test<|endoftext|> of"

Now I checked that the ids are actually correct (which is not necessarily the case with extra spaces etc.).
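A minimal sketch of that id check (the exact snippet isn't preserved here; the checkpoint name is an assumption):

```python
from transformers import GPT2Tokenizer

# GPT-Neo reuses the GPT-2 tokenizer; the checkpoint name below is an assumption.
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

input_str_1 = "This is a test of<|endoftext|>"
input_str_2 = "This is a test<|endoftext|> of"

# Print ids and tokens to confirm both strings tokenize into the same pieces,
# just with <|endoftext|> in a different position (no stray spaces merged in).
for s in (input_str_1, input_str_2):
    ids = tokenizer(s).input_ids
    print(ids, tokenizer.convert_ids_to_tokens(ids))
```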
And then both generate exactly the same thing. Is there a possibility that the issue comes from slightly twisted inputs on your side (extra spaces, unexpected ids, etc.)?
Hi @Narsil, thanks a lot for the reply! Yeah, I can see that code as well and it seems to be doing the correct thing, but the results I am getting suggest otherwise. It is possibly related, however, to how GPT-Neo handles those position ids internally. With the smaller example here, though the generated sequences are the same, the logits are actually different, which is why it exhibits the incorrect behavior in longer sequences. Here is the code to reproduce:
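(A minimal sketch of such a reproduction; the checkpoint name and the way the `<|endoftext|>` position is masked out are assumptions, not necessarily the exact original code.)

```python
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM

model_name = "EleutherAI/gpt-neo-125M"  # assumed checkpoint
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPTNeoForCausalLM.from_pretrained(model_name)
model.eval()

def masked_logits(text):
    enc = tokenizer(text, return_tensors="pt")
    # Treat <|endoftext|> as padding: mask it out so no other token attends to it.
    enc["attention_mask"][enc["input_ids"] == tokenizer.eos_token_id] = 0
    with torch.no_grad():
        return model(**enc).logits

logits_1 = masked_logits("This is a test of<|endoftext|>")
logits_2 = masked_logits("This is a test<|endoftext|> of")

# Compare the output logits of the last real token (" of") in both sequences;
# if padding were skipped perfectly, this difference would be exactly zero.
print((logits_1[0, -2] - logits_2[0, -1]).abs().sum())
```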
The output I got is:
The output scores only differ by a very small amount since the sequence is short and the position of the padding token is only off by one, but they are still different.
Tagging @patil-suraj, if you have more information on how the position ids are handled here. Just for reference, I also checked the outputs, and indeed there's variance (even more than in your post); I get:
Jumping in the conversation here to maybe solve some problems. One thing to remember is that the `attention_mask` only prevents the other tokens from attending to the padding token; the model still computes an output at the padding position itself, and that output is meaningless. This means one should never look at the output of the padding token, i.e. in @Narsil's example:

this means that the last row of the first logits and the previous-to-last row of the second logits are useless (they correspond to padding tokens). What we should instead compare here is the previous-to-last row of the first logits to the last row of the second logits (both corresponding to the output logits of the token " of").

Now, as a conclusion: for padded inputs to GPT-like models, one should always use left padding (`tokenizer.padding_side = "left"`).
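For reference, a minimal sketch of the left-padding setup this conclusion implies (the checkpoint name is an assumption):

```python
from transformers import GPT2Tokenizer, GPTNeoForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token  # GPT-like models ship without a pad token
tokenizer.padding_side = "left"            # pad on the left so the last position is always a real token

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

batch = tokenizer(
    ["This is a test", "A somewhat longer second prompt"],
    return_tensors="pt",
    padding=True,
)
out = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_new_tokens=10,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```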
@patrickvonplaten thanks a lot for the clarification! It confirms what I found in the experiments -- right padding for GPT-like models is incorrect and leads to performance degradation. However, I do think the problem of not correctly skipping the padding tokens still exists in general. If sampling from the padding token leads to incorrect results, then in the following examples, the logits for the generated tokens should be the same, since the last token is not a padding token anymore:
However, the output I've been getting is:
Notice that they look the same, but when doing subtraction and summation, we can see they have different values. In principle, if the padding tokens are correctly skipped everywhere, then it would not matter even if I have input like this:
Or am I understanding it incorrectly? The full code snippet I used to generate the output is pasted below:
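A sketch of the kind of comparison being described here, assuming the padding token sits in the middle of the prompt so that the last prompt token is a real token in both cases (checkpoint name and prompts are assumptions):

```python
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM

model_name = "EleutherAI/gpt-neo-125M"  # assumed checkpoint
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPTNeoForCausalLM.from_pretrained(model_name)

def generation_scores(text):
    enc = tokenizer(text, return_tensors="pt")
    # Treat <|endoftext|> inside the prompt as padding and mask it out.
    enc["attention_mask"][enc["input_ids"] == tokenizer.eos_token_id] = 0
    out = model.generate(
        **enc,
        max_new_tokens=5,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
    )
    return out.scores  # one (1, vocab_size) tensor per generated step

scores_1 = generation_scores("This is a test of")
scores_2 = generation_scores("This is a test<|endoftext|> of")

# If the masked-out padding token were skipped everywhere, the scores of the
# first generated token should match exactly between the two prompts.
print((scores_1[0] - scores_2[0]).abs().sum())
```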
Hey @niansong1996, I think your understanding is very much correct here. If I understand your example correctly, you are seeing (very) small differences in the output logits that shouldn't be there.

Now taking this into account for your example, it means the following: I think your reasoning is 100% correct, and I think those small differences in what values are used for padding could be the explanation - you could maybe try to replace all padding tokens with a different token id and check whether those small differences change.
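A sketch of that experiment, assuming the idea is to swap the token id sitting at the masked-out position and see whether the small differences move (checkpoint name assumed):

```python
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM

model_name = "EleutherAI/gpt-neo-125M"  # assumed checkpoint
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPTNeoForCausalLM.from_pretrained(model_name)

enc = tokenizer("This is a test<|endoftext|> of", return_tensors="pt")
pad_positions = enc["input_ids"] == tokenizer.eos_token_id
enc["attention_mask"][pad_positions] = 0

with torch.no_grad():
    ref = model(**enc).logits

# Replace the masked-out token with an arbitrary other id; if the mask were
# fully effective, the logits at the non-masked positions would not change.
enc["input_ids"][pad_positions] = tokenizer("hello").input_ids[0]
with torch.no_grad():
    alt = model(**enc).logits

print((ref[0, -1] - alt[0, -1]).abs().max())
```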
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I found this issue extremely helpful for my experiment. I was wondering why pretrained decoder-only LMs were failing to generate anything with right-padded inputs.
Why does sampling from the padding token lead to incorrect results?
According to #7552, the padding tokens will be skipped when calculating the `position_ids` during `generate()`, if the corresponding positions are masked out in `attention_mask`. If I understand this correctly, this would mean that the appearance of padding tokens does not matter as long as they are not attended to. However, I found that this is not exactly the case; do I miss something here? Check the following code for reproduction:
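(The reproduction code itself is not preserved here; below is a rough sketch in the same spirit, comparing a right-padded batched `generate()` call with an unpadded one - the checkpoint and prompts are assumptions.)

```python
from transformers import GPT2Tokenizer, GPTNeoForCausalLM

model_name = "EleutherAI/gpt-neo-125M"  # assumed checkpoint
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # default padding_side is "right"
model = GPTNeoForCausalLM.from_pretrained(model_name)

prompt = "The quick brown fox"

# Generate from the prompt alone (no padding involved).
enc_single = tokenizer(prompt, return_tensors="pt")
out_single = model.generate(**enc_single, max_new_tokens=10, do_sample=False)

# Generate from the same prompt, right-padded inside a batch with a longer prompt.
enc_batch = tokenizer(
    [prompt, "A considerably longer second prompt goes right here"],
    return_tensors="pt",
    padding=True,
)
out_batch = model.generate(
    input_ids=enc_batch["input_ids"],
    attention_mask=enc_batch["attention_mask"],
    max_new_tokens=10,
    do_sample=False,
)

# If the padding tokens indicated by attention_mask were skipped correctly,
# both continuations of the short prompt should be identical.
print(tokenizer.decode(out_single[0], skip_special_tokens=True))
print(tokenizer.decode(out_batch[0], skip_special_tokens=True))
```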