
TPU cannot do simple arithmetic! #9973

Closed
ayaka14732 opened this issue Mar 21, 2022 · 5 comments
Labels
bug Something isn't working

Comments

@ayaka14732 (Collaborator)

I am trying to do simple matrix multiplication on TPU, but it gives a wrong result:

import jax.numpy as np
import numpy as onp

# On CPU
x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]])
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]])
z = onp.dot(x, y)

# On TPU
x_ = np.asarray(x)
y_ = np.asarray(y)
z_ = np.dot(x_, y_)

print('JAX device:', x_.device())

# Compare
print('CPU result:', z)
print('TPU result:', z_)
assert np.allclose(z, z_)

Output:

JAX device: TPU_0(process=0,(0,0,0,0))
CPU result: [[0.17372088 0.09648504]
 [0.21183069 0.10812637]]
TPU result: [[0.17405128 0.09669876]
 [0.21180916 0.10805416]]
Traceback (most recent call last):
  File "/home/ayaka/main.py", line 21, in <module>
    assert np.allclose(z, z_)
AssertionError

Manual calculation of the first element:

0.3744 * 0.3946 + 0.1656 * 0.1569 = 0.17372088

So the result on CPU is correct, while the result on TPU is wrong.

Library versions:

jax                  0.3.4
jaxlib               0.3.2
libtpu-nightly       0.1.dev20220315
ayaka14732 added the bug (Something isn't working) label on Mar 21, 2022
@tomhennigan (Collaborator)

There are two "sharp bits" to be aware of here.

1. float64 - Your NumPy example is running with f64 inputs/outputs but the JAX version is f32. See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

2. Compute precision - XLA supports computing dot products in various precisions. JAX's default is optimised for performance and is therefore relatively low precision (see #9952). This is typically fine for NN training, but can be the wrong choice for other applications.

You can adjust the precision globally using:

jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)

Or you can do so on a per operation basis:

c = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

Here is a copy of your example using HIGHEST precision on TPU (with f32 inputs) and getting a very close result compared to CPU: https://colab.research.google.com/gist/tomhennigan/69844409e46fd71267acf7479e2ba7f4/example-of-changing-default-precision-in-jax.ipynb
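For reference, here is a minimal self-contained variant of the snippet from the original report (a sketch, not the linked notebook) that keeps f32 inputs but requests HIGHEST precision globally via the string form of the flag:

import jax
# Sharp bit 2: ask for full-precision matmuls everywhere.
jax.config.update('jax_default_matmul_precision', 'highest')
# Sharp bit 1 (not needed here, but for f64 work you would also do):
# jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import numpy as onp

x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]], dtype=onp.float32)
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]], dtype=onp.float32)

z_cpu = onp.dot(x, y)                            # NumPy reference (f32)
z_tpu = jnp.dot(jnp.asarray(x), jnp.asarray(y))  # runs on the default JAX device

print('CPU result:', z_cpu)
print('TPU result:', z_tpu)
print(onp.allclose(z_cpu, z_tpu))  # expected to print True at HIGHEST precision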

@ayaka14732 (Collaborator, Author)

ayaka14732 commented Mar 21, 2022

Thank you for the explanation!

I am still wondering:

1. Is there any research indicating that low precision will not affect model performance? As deep learning models grow larger and larger, I would think that different precisions may result in totally different outputs after many layers of operations.

2. How much training time can I save when using lower precision?

@tomhennigan (Collaborator)

Is there any research indicating that low precision will not affect model performance? As deep learning models grow larger and larger, I would think that different precisions may result in totally different outputs after many layers of operations.

It is important to note that while the default precision is low, it is deterministic: if you train a model in low precision and run inference on that trained model in low precision, you should get the expected answer.
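As a quick illustration of the determinism point (a sketch; run it on a TPU, and the exact values depend on the backend and matrix sizes):

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024))
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024))

# Two runs at the default (low) precision give bit-identical results...
z1 = jnp.dot(a, b)
z2 = jnp.dot(a, b)
print(bool((z1 == z2).all()))      # expected: True (deterministic)

# ...even though they can differ from the HIGHEST-precision result.
z_hi = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
print(bool((z1 == z_hi).all()))    # typically False on TPU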

For very large models it is typical to drop the precision: training them requires so many floating-point operations that the performance improvement translates into a very significant saving in training time.

For Gopher (a large language model from DeepMind) we discuss low-precision training with bfloat16 (even lower than JAX's f32 defaults) in Appendix C.2 of our paper: https://arxiv.org/pdf/2112.11446.pdf.

How much training time can I save when using lower precision?

Typically accelerators have special hardware ("tensor cores") for half-precision (e.g. bf16, f16) compute, and you can expect computations to run somewhere between 5-10x faster than full-precision f32 computations.

JAX's default precision for f32 dot products means the actual computation is done in bf16 on TPU, so the performance improvement is significant vs. Precision.HIGH or Precision.HIGHEST.
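To put a rough number on this for a single large matmul, something along these lines (a sketch; the exact ratio depends on TPU generation and matrix size) compares the default and HIGHEST precisions:

import time
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
b = jax.random.normal(jax.random.PRNGKey(1), (8192, 8192))

@jax.jit
def dot_default(a, b):
    return jnp.dot(a, b)  # default precision: computed in bf16 on TPU

@jax.jit
def dot_highest(a, b):
    return jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

def timed(f):
    f(a, b).block_until_ready()           # warm-up / compile
    t0 = time.perf_counter()
    for _ in range(10):
        f(a, b).block_until_ready()
    return (time.perf_counter() - t0) / 10

print('default precision :', timed(dot_default), 's per matmul')
print('HIGHEST precision :', timed(dot_highest), 's per matmul')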

@ayaka14732 (Collaborator, Author)

Thank you for the detailed explanation!

@mattjj (Collaborator)

mattjj commented Mar 21, 2022

I think we can close this issue, but please let me know if not!
