Add hooked transformer generate stream #908


Open · wants to merge 6 commits into base: dev from add-HookedTransformer-generate_stream

Conversation

@anthonyduong9 (Contributor) commented Apr 11, 2025

Description

Adds a new method, HookedTransformer.generate_stream(). We wanted to add this in #847 but hadn't added tests, and we also want to complete hijohnnylin/neuronpedia#51. @hijohnnylin said to open a PR; if we merge this, neuronpedia can replace its fork of TransformerLens with the latest transformer-lens release as a dependency.
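For context, a streaming generate method like this is typically consumed as a generator, with each yielded value available immediately (e.g. for a UI). The sketch below is illustrative only: the stand-in generate_stream below is a toy function, not the actual TransformerLens signature added in this PR.

```python
# Illustrative sketch of how a streaming generate method is consumed.
# `generate_stream` here is a toy stand-in, NOT the real
# HookedTransformer.generate_stream() added in this PR.

from typing import Iterator

def generate_stream(prompt: str, max_new_tokens: int = 3) -> Iterator[str]:
    """Toy stand-in: yields the progressively extended output string."""
    output = prompt
    for i in range(max_new_tokens):
        output += f" tok{i}"
        yield output

# A caller streams partial results as they arrive; the last yielded
# value is the full completion (the invariant the PR's test checks).
final = None
for partial in generate_stream("Hello"):
    final = partial

print(final)  # Hello tok0 tok1 tok2
```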

Fixes # (issue)

Type of change

Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)

Screenshots

Please attach before and after screenshots of the change if applicable.

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@anthonyduong9 anthonyduong9 force-pushed the add-HookedTransformer-generate_stream branch 2 times, most recently from 9fcf45a to b86b035 Compare April 11, 2025 23:15
@anthonyduong9 anthonyduong9 force-pushed the add-HookedTransformer-generate_stream branch from b86b035 to b7bce69 Compare April 11, 2025 23:16
@anthonyduong9 anthonyduong9 marked this pull request as ready for review April 12, 2025 00:29
@hijohnnylin

Before we merge this:

  • What is this new test? Was it just merged from a different branch? test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream doesn't seem relevant.

  • Make outstanding issues to resolve afterwards: @anthonyduong9, can you create new issue(s) for these longer-term fixes?

  1. Reducing the duplicated code between generate and generate_stream
  2. Adding tests for generate_stream

@bryce13950 (Collaborator)

@hijohnnylin The test is from this commit in this PR b7bce69

@anthonyduong9 (Contributor, Author) commented May 6, 2025

> Before we merge this:
>
>   • What was this new test - was it just merged from a different branch? It doesn't seem relevant test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream
>   • Make outstanding issues to resolve afterwards - @anthonyduong9 can you create new issue(s) for longer term fixes for this?
>   1. Reducing the duplicated code between generate and generate_stream
>   2. Adding tests for generate_stream

I added the test in this PR. It checks that the last value yielded by generate_stream() matches the output of AutoModelForCausalLM.generate(). It's analogous to the existing test for generate():

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
    tf_model = HookedTransformer.from_pretrained(
        "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
    )
    hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    # `text` is the prompt string defined elsewhere in the test module
    output_tf = tf_model.generate(
        text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
    )
    output_hf_tokens = hf_model.generate(
        hf_tokenizer(text, return_tensors="pt").input_ids,
        do_sample=False,
        max_new_tokens=10,
    )
    output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)
    assert output_tf == output_hf_str
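The streaming variant of that test rests on one invariant: the final value yielded by the stream equals the non-streaming output. The sketch below illustrates that invariant with stub functions (fake_generate and fake_generate_stream are hypothetical stand-ins, since the real test loads bloom-560m and is too heavy to inline here).

```python
# Illustrative-only stubs: demonstrates the invariant the stream test
# relies on, i.e. the last yielded chunk equals the full generate() output.

def fake_generate(prompt: str) -> str:
    """Stand-in for a non-streaming generate(): returns the full output."""
    return prompt + " a b c"

def fake_generate_stream(prompt: str):
    """Stand-in for generate_stream(): yields progressively longer outputs."""
    words = fake_generate(prompt).split()
    for i in range(len(prompt.split()), len(words) + 1):
        yield " ".join(words[:i])

def test_stream_last_value_matches_generate():
    prompt = "hello"
    *_, last = fake_generate_stream(prompt)
    assert last == fake_generate(prompt)

test_stream_last_value_matches_generate()
print("ok")
```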

So I think that completes item 2. You and I talked about item 1 in person shortly after I opened this PR. I'm not sure it's worth the effort: I spent a lot of time trying to deduplicate code between the two functions for this PR, and not only did the functions stop working as expected, but extracting abstractions was awkward. This is probably because generate() is now much less similar to generate_stream() than when you first wrote the latter (after #820).

3 participants