From 255656a37fcb1d2949b8ec801841e7765c03d863 Mon Sep 17 00:00:00 2001 From: Bryan Crampton Date: Thu, 11 Jan 2024 17:07:48 -0800 Subject: [PATCH 1/3] Update pytorch_inference.py Add support for M1/M2 mac --- sdks/python/apache_beam/ml/inference/pytorch_inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 480dc538195c8..6cea8d0b11237 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -246,6 +246,9 @@ def __init__( if device == 'GPU': logging.info("Device is set to CUDA") self._device = torch.device('cuda') + elif device == 'mps': + logging.info("Device is set to mps") + self._device = torch.device('mps') else: logging.info("Device is set to CPU") self._device = torch.device('cpu') From 416c3adb25580fad5ced31c69b98b7cd8681f407 Mon Sep 17 00:00:00 2001 From: Bryan Crampton Date: Fri, 12 Jan 2024 01:21:39 +0000 Subject: [PATCH 2/3] Support both Pytorch classes --- sdks/python/apache_beam/ml/inference/pytorch_inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 6cea8d0b11237..93bdece8e7ac1 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -487,6 +487,9 @@ def __init__( if device == 'GPU': logging.info("Device is set to CUDA") self._device = torch.device('cuda') + elif device == 'mps': + logging.info("Device is set to mps") + self._device = torch.device('mps') else: logging.info("Device is set to CPU") self._device = torch.device('cpu') From 55db8bb0f55dfb24b1be1b88fdc091df23288797 Mon Sep 17 00:00:00 2001 From: Bryan Crampton Date: Mon, 22 Jan 2024 16:41:52 -0700 Subject: [PATCH 3/3] Address comments --- 
.../ml/inference/pytorch_inference.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 93bdece8e7ac1..0a14bf44efaff 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -91,6 +91,12 @@ def _load_model( "Model handler specified a 'GPU' device, but GPUs are not available. " "Switching to CPU.") device = torch.device('cpu') + if device == torch.device('mps') and not (torch.backends.mps.is_available() and + torch.backends.mps.is_built()): + logging.warning( + "Model handler specified a 'MPS' device, but it is not available. " + "Switching to CPU.") + device = torch.device('cpu') try: logging.info( @@ -216,8 +222,9 @@ def __init__( model_params: A dictionary of arguments required to instantiate the model class. device: the device on which you wish to run the model. If - ``device = GPU`` then a GPU device will be used if it is available. - Otherwise, it will be CPU. + ``device = GPU`` then a cuda device will be used if it is available. + If ``device = MPS`` then a mps device will be used if it is available. + Otherwise, it will be cpu. inference_fn: the inference function to use during RunInference. default=_default_tensor_inference_fn torch_script_model_path: Path to the torch script model. @@ -243,10 +250,10 @@ def __init__( with PyTorch 1.9 and 1.10. """ self._state_dict_path = state_dict_path - if device == 'GPU': logging.info("Device is set to CUDA") self._device = torch.device('cuda') - elif device == 'mps': + if device in ['gpu', 'GPU', 'cuda', 'CUDA']: logging.info("Device is set to CUDA") self._device = torch.device('cuda') + elif device in ['mps', 'MPS']: logging.info("Device is set to mps") self._device = torch.device('mps') else: logging.info("Device is set to CPU") self._device = torch.device('cpu') @@ -457,8 +464,9 @@ def __init__( model_params: A dictionary of arguments required to instantiate the model class. 
device: the device on which you wish to run the model. If - ``device = GPU`` then a GPU device will be used if it is available. - Otherwise, it will be CPU. + ``device = GPU`` then a cuda device will be used if it is available. + If ``device = MPS`` then a mps device will be used if it is available. + Otherwise, it will be cpu. inference_fn: the function to invoke on run_inference. default = default_keyed_tensor_inference_fn torch_script_model_path: Path to the torch script model. @@ -484,10 +492,10 @@ def __init__( on torch>=1.9.0,<1.14.0. """ self._state_dict_path = state_dict_path - if device == 'GPU': + if device in ['gpu', 'GPU', 'cuda', 'CUDA']: logging.info("Device is set to CUDA") self._device = torch.device('cuda') - elif device == 'mps': + elif device in ['mps', 'MPS']: logging.info("Device is set to mps") self._device = torch.device('mps') else: