Added repetition penalty to PPLM example #2436
Conversation
Codecov Report
@@           Coverage Diff           @@
##           master    #2436   +/-  ##
=======================================
  Coverage   73.24%   73.24%
=======================================
  Files          87       87
  Lines       14989    14989
=======================================
  Hits        10979    10979
  Misses       4010     4010

Continue to review the full report at Codecov.
What do you think @w4nderlust @mimosavvy?

[IWillPull here, writing from a personal account] Do not merge yet. I think it's best to explain in the help text that this was not in the original paper, and to change the default value to 1.0 so it doesn't influence anything by default.
We wanted to add the repetition penalty ourselves, but you were faster :) I added two minor considerations. If you were getting awful results before, it's likely because of sub-optimal parameter choices, as we obtained good results without the repetition penalty; still, having it will only make things better.
Also, I would add a comment clarifying that this parameter was not described in the paper and thus it is 1 by default, just to avoid confusion for people looking at the code.
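To make the "off by default" behavior concrete, here is a minimal sketch of how such a CLI flag might be declared, with the clarifying help text suggested above. The flag name and help wording are illustrative, not the exact code in the PR:

```python
import argparse

parser = argparse.ArgumentParser()
# Default of 1.0 means "no penalty", so generation matches the paper's
# behavior unless the user explicitly opts in.
parser.add_argument(
    "--repetition_penalty",
    type=float,
    default=1.0,
    help="Penalize repeated tokens (not described in the original PPLM "
         "paper; 1.0 disables it).",
)
args = parser.parse_args(["--repetition_penalty", "1.2"])
```

With no flag given, `parser.parse_args([])` yields the neutral value 1.0.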
examples/pplm/run_pplm.py
Outdated
@@ -508,6 +518,13 @@ def generate_text_pplm(
    pert_logits, past, pert_all_hidden = model(last, past=pert_past)
    pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

    for j in set(output_so_far[0].tolist()):
Would rename j to token_idx for readability.
Fixed
examples/pplm/run_pplm.py
Outdated
    for j in set(output_so_far[0].tolist()):
        if pert_logits[0, j] < 0:
            pert_logits[0, j] *= repetition_penalty
In the definition at the end of section 4.1 in https://einstein.ai/presentations/ctrl.pdf there is no distinction between logits < 0 and > 0. I understand why logits < 0 should be treated separately: dividing them by a penalty would actually increase the value, and thus increase the probability after softmax is computed. Still, I'd like a reference for this if possible, since in the < 0 case values may become really small. It is true that we take an exponential later, so differences between even distant values on the negative side are negligible, but if there's a reference, it would be great to add it as a comment in the code.
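The sign-aware behavior under discussion can be sketched in isolation. This is a standalone illustration (plain Python, hypothetical function and variable names, not the PR's torch code): positive logits are divided by the penalty and negative logits are multiplied, so in both cases the score moves toward lower probability after softmax.

```python
def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Discourage already-generated tokens, CTRL-style, with the
    sign-aware variant discussed in this thread."""
    for token_id in set(generated_ids):
        if logits[token_id] < 0:
            logits[token_id] *= penalty  # more negative -> less likely
        else:
            logits[token_id] /= penalty  # smaller positive -> less likely
    return logits

# Example: tokens 0 and 1 were already generated.
scores = apply_repetition_penalty([2.0, -1.0, 0.5], [0, 1], penalty=2.0)
```

Here `scores` becomes `[1.0, -2.0, 0.5]`: both penalized logits moved toward lower probability, while the unseen token 2 is untouched. Naively dividing everywhere would instead map -1.0 to -0.5, *raising* that token's probability.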
Of course. I took this idea from the transformers library itself, as the repetition penalty was broken with the GPT-2 model.
Here is the relevant PR with discussion:
#2303
Got it, thank you for pointing it out, I replied to that thread too: #2303 (comment)
Thank you for your time reviewing this. May I ask why the code quality check fails? What did I miss?
Could you share your optimal parameters?
Can you run
The ones we reported in the paper work in most cases, but for some BOWs other values may be better because of the size of the BOW and the specific words it contains (whether they are really common or less common); in general, though, the reported ones are pretty consistent.
@julien-c Thank you. I missed reading the guidelines before opening this PR; should I open a new one with proper branching?
LGTM, thanks!
@julien-c It didn't look entirely good to me. I explained my argument, which goes beyond the repetition penalty for PPLM and is a general argument about repetition penalty (so it applies to CTRL too), here: #2303 (comment)
Aarg, I misunderstood your comment then @w4nderlust, I'll ask for a more explicit greenlight next time! @IWillPull, can you please open a new PR to fix/improve the remaining points? Thanks!
No problem @julien-c! The repetition penalty as implemented in this PR is fine, in the sense that it works exactly like the CTRL one, and that has worked for people so far.
@julien-c Sure! I will just wait for consensus from you, @w4nderlust, and others, so as not to make a mess of this.
The GPT-2 LM itself and the discriminators are different from what is reported in the paper. I think you need ~1.5 times the step-size/iterations for this version of the GPT-2 LM/attribute models; other parameters should work as is. If you are using the GPT-2 LM from the paper (which corresponds to a previous version of the Hugging Face GPT-2 LM) and the discriminators from the paper, the parameters listed in the Appendix work quite well. Code/models for what's in the paper: https://github.com/uber-research/PPLM/tree/master/paper_code

Also, if repetition is a huge problem, Table S19 from the paper might be relevant; I think it would be an easy fix that helps with the "awful" repetitions. Repetitions also don't seem to be an issue when using the discriminator, so I think a large part of the problem lies with the simple "BoW" loss rather than with the decoding scheme.
It was giving awful results, so I added a repetition penalty, which improved things.