diff --git a/sdks/python/apache_beam/ml/inference/pytorch_test.py b/sdks/python/apache_beam/ml/inference/pytorch_test.py index 6c374dbf1830a..9314a23400db6 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_test.py @@ -42,6 +42,11 @@ except ImportError: raise unittest.SkipTest('PyTorch dependencies are not installed') +try: + from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem +except ImportError: + GCSFileSystem = None # type: ignore + def _compare_prediction_result(a, b): return ( @@ -166,6 +171,7 @@ def test_pipeline_local_model(self): predictions, equal_to(expected_predictions, equals_fn=_compare_prediction_result)) + @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed') def test_pipeline_gcs_model(self): with TestPipeline() as pipeline: examples = torch.from_numpy( @@ -178,7 +184,8 @@ def test_pipeline_gcs_model(self): for example in examples]).reshape(-1, 1)) ] - gs_pth = 'gs://apache-beam-ml/pytorch_lin_reg_model_2x+0.5_state_dict.pth' + gs_pth = 'gs://apache-beam-ml/models/' \ + 'pytorch_lin_reg_model_2x+0.5_state_dict.pth' model_loader = PytorchModelLoader( state_dict_path=gs_pth, model_class=PytorchLinearRegression,