Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Experimental) Integrate Metal PjRt plugin #1504

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft

Conversation

jonatanklosko
Copy link
Member

@jonatanklosko jonatanklosko commented Jun 5, 2024

This integrates the PjRt plugin from the jax-metal for running on the Apple GPU. To test it, one can set client: :mps on EXLA backend/compiler. Since the plugin is loaded as a separate dynamic library, it can be tested without any changes to XLA (just make sure to remove the cache/ directory).

Certain computations can already be run, but the plugin is still very much incomplete. This PR is a room for experimentation and is meant to track the plugin progress. I reported a number of issues upstream, comments in the code point to those. In a few places I applied workarounds as temporary solutions or just to avoid VM crashes (segfaults), those are marked with a TODO.

Issues

For tracking purposes, here is a list of the Metal plugin issues reported upstream:

Crucial

Not implemented

Edge cases

All issues: link.


Note: this PR is against the jk-s32 branch, which changes the default integer precision to 32 bits. This is a planned change (#1491), but it's not integrated yet to avoid conflicts with other work in progress.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant