Update Paxml patchlist with TE and config patches for improved perf #225

Merged
terrykong merged 4 commits into main from update-pax-patchlist on Oct 4, 2023

ashors1 (Contributor) commented Sep 12, 2023

  • adds Transformer Engine support to Pax
  • updates GPU configs and default XLA flags for improved performance

terrykong (Contributor) commented Sep 19, 2023

Rosetta Pax build/test: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6239209930

The above workflow should be sufficient to validate this change.

terrykong changed the title from "Update Paxml patchlist" to "Update Paxml patchlist with TE and config patches for improved perf" on Sep 19, 2023

terrykong (Contributor) commented

Looks like the build passed (yay!), but only some of the MGMN tests passed. Unit tests are expected to fail.

Re-running to make sure it wasn't a one-off: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6241478583

terrykong (Contributor) commented

To address the unit test failure: #253

ashors1 (Contributor, Author) commented Sep 19, 2023

The failures don't appear to be one-offs. I see errors like this in the logs:

    partitioned_vars = init_fn(prng_key)
  File "/opt/paxml/paxml/trainer_lib.py", line 1707, in call
    return pjitted_fn(*args)
AttributeError: 'NoneType' object has no attribute 'empty'

Investigating now.

ashors1 (Contributor, Author) commented Sep 29, 2023

Google reverted their commit that broke TE: google/paxml@1696411.
Here is the Rosetta build/test with Pax at head: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6354569506. All tests pass except the TP-8 test, whose failure appears to stem from the test itself rather than from TE. I will work on fixing this test, but in the meantime, I think we can go ahead and merge. @terrykong what do you think?

nouiz (Collaborator) commented Sep 29, 2023

Merging with a failing test isn't great. If we do that, can we disable the test until it is fixed?
That way, at least the CI results will be easy to read.

terrykong (Contributor) commented

TP=8 appears to have failed, but it isn't showing up in the metrics pytest check. Let me look into why.

terrykong (Contributor) commented

I've reminded myself that two tests are actually excluded from the perf/loss checks because they were failing some time ago. The two tests are:

  1. 1DP8TP1PP
  2. 2DP2TP4PP

It looks like (2) is working now, so I created an issue to track adding these back in: #272

But regarding @nouiz's comment, the test is already omitted, so I think this is okay to merge.

terrykong merged commit f300efa into main on Oct 4, 2023
154 of 162 checks passed
terrykong deleted the update-pax-patchlist branch on October 4, 2023 at 18:50