
Add support for M1/M2 mac in PytorchModelHandlerTensor #29999

Status: Closed · 3 commits

Conversation

bfcrampton

Add support for M1/M2 mac

Add support for M1/M2 Macs by allowing the `mps` PyTorch device type.
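
A hypothetical usage sketch of what this change enables (the model class and state-dict path are placeholders, not part of this PR; the handler parameters follow the existing `PytorchModelHandlerTensor` signature):

```python
import torch
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class MyModel(torch.nn.Module):
    """Hypothetical stand-in model; not part of this PR."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)


# The state-dict path is a placeholder; passing device='mps' is the
# behavior this PR proposes for Apple-silicon GPUs.
model_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/model_state_dict.pth',
    model_class=MyModel,
    model_params={},
    device='mps')
```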


Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:

  • Mention the appropriate issue in your description (for example: addresses #123), if applicable. This will automatically add a link to the pull request in the issue. If you would like the issue to automatically close on merging the pull request, comment fixes #<ISSUE NUMBER> instead.
    • N/A
  • Update CHANGES.md with noteworthy changes.
    • N/A
  • If this contribution is large, please file an Apache Individual Contributor License Agreement.
    • N/A

See the Contributor Guide for more tips on how to make the review process smoother.

To check the build health, please visit https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md

GitHub Actions Tests Status (on master branch): badges for "Build python source distribution and wheels", "Python tests", "Java tests", and "Go tests".

See CI.md for more information about GitHub Actions CI or the workflows README to see a list of phrases to trigger workflows.

Commit: Add support for M1/M2 mac

@bfcrampton changed the title from "Update pytorch_inference.py" to "Add support for M1/M2 mac in PytorchModelHandlerTensor" on Jan 12, 2024

Assigning reviewers. If you would like to opt out of this review, comment assign to next reviewer:

R: @damccorm for label python.

Available commands:

  • stop reviewer notifications - opt out of the automated review tooling
  • remind me after tests pass - tag the comment author after tests pass
  • waiting on author - shift the attention set back to the author (any comment or push by the author will return the attention set to the reviewers)

The PR bot will only process comments in the main thread (not review comments).

@damccorm (Contributor) left a comment:


Thanks, this looks like a helpful change. I think we need a few more things for completeness though.

  1. Could you please update `_load_model` to add a check for mps, with a clear warning and fallback to CPU if mps isn't available, similar to the existing cuda snippet `if device == torch.device('cuda') and not torch.cuda.is_available():`? (See the sketch after this list.)
  2. Could you please update the pydoc for each of the modified functions to call out that mps is an allowed device type?
  3. We should probably prefer MPS over mps to stay consistent with our GPU parameter. Alternately, we can just make both case-insensitive (this would probably be for the best anyways).
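
A minimal sketch of the check in point 1, assuming it mirrors the existing CUDA fallback in `_load_model` (the helper name is hypothetical; `torch.backends.mps` requires PyTorch 1.12+):

```python
import logging

import torch


def _fall_back_if_mps_unavailable(device: torch.device) -> torch.device:
    # Hypothetical helper mirroring the existing CUDA check: warn and fall
    # back to CPU when 'mps' is requested but the MPS backend isn't usable.
    if device == torch.device('mps') and not torch.backends.mps.is_available():
        logging.warning(
            "Model handler specified an 'mps' device, but MPS is not "
            "available. Switching to CPU.")
        return torch.device('cpu')
    return device
```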


Reminder, please take a look at this pr: @damccorm

@damccorm

waiting on author

@bfcrampton (Author) commented on Jan 22, 2024:

Thanks for the review!

Addressed the comments. On #3 I decided to make everything case-insensitive and also to support specifying cuda or CUDA directly. This is a minor nit I currently have to work around when using this library; everywhere else in my code a device is either cpu, cuda, or mps.

Let me know what you think; happy to just leave the case-insensitivity part.

@bfcrampton force-pushed the patch-1 branch 2 times, most recently from 198f42b to fcb8fa2 on Jan 22, 2024
```diff
@@ -243,9 +250,12 @@ def __init__(
       with PyTorch 1.9 and 1.10.
     """
     self._state_dict_path = state_dict_path
-    if device == 'GPU':
+    if device in ['gpu', 'GPU', 'cuda', 'CUDA']:
```

Hey, sorry for the slow review - this looks mostly good. The only thing I'd ask is that if we go case-insensitive we do it all the way (so device.lower() == 'gpu' here, and the same type of change applied elsewhere in the PR). Otherwise, this looks good to me, thanks!
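
A sketch of the fully case-insensitive check the reviewer describes, using the names from the diff above (illustrative, not the merged code):

```python
# Inside __init__, per the diff above: normalize the user-supplied device
# string once, then compare lowercase everywhere.
device = device.lower()
if device in ('gpu', 'cuda'):
    self._device = torch.device('cuda')
elif device == 'mps':
    self._device = torch.device('mps')
else:
    self._device = torch.device('cpu')
```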

```diff
@@ -91,6 +91,12 @@ def _load_model(
         "Model handler specified a 'GPU' device, but GPUs are not available. "
         "Switching to CPU.")
     device = torch.device('cpu')
+  if device == torch.device('mps') and not (torch.backend.mps.is_available() and
```

Ah, it also looks like `torch.device('mps')` causes older versions of PyTorch to throw. To maintain compatibility with older versions, could we wrap this in a try/except? It should be fine to silently swallow the exception, since we'll throw below if someone tries to set the device to mps and the PyTorch version doesn't support it. (Sketch below.)

This is why the test checks are failing.
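
A hedged sketch of that try/except, assuming it lives in `_load_model`; the exact exception type is an assumption (recent PyTorch raises `RuntimeError` for unrecognized device strings):

```python
try:
    # Older PyTorch versions throw when constructing torch.device('mps');
    # swallow the error here, since an explicit 'mps' request will still
    # fail loudly below on those versions.
    requested_mps = device == torch.device('mps')
except RuntimeError:
    requested_mps = False

if requested_mps and not torch.backends.mps.is_available():
    logging.warning(
        "Model handler specified an 'mps' device, but MPS is not "
        "available. Switching to CPU.")
    device = torch.device('cpu')
```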


This pull request has been marked as stale due to 60 days of inactivity. It will be closed in 1 week if no further activity occurs. If you think that’s incorrect or this pull request requires a review, please simply write any comment. If closed, you can revive the PR at any time and @mention a reviewer or discuss it on the dev@beam.apache.org list. Thank you for your contributions.

@github-actions bot added the stale label on Mar 27, 2024
@github-actions bot commented on Apr 4, 2024:

This pull request has been closed due to lack of activity. If you think that is incorrect, or the pull request requires review, you can revive the PR at any time.

@github-actions bot closed this on Apr 4, 2024