Skip to content

Commit

Permalink
[BEAM-13984] followup Fix precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Abacn committed Jun 1, 2022
1 parent ca33943 commit 145615a
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion sdks/python/apache_beam/ml/inference/pytorch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
except ImportError:
raise unittest.SkipTest('PyTorch dependencies are not installed')

try:
from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
except ImportError:
GCSFileSystem = None # type: ignore


def _compare_prediction_result(a, b):
return (
Expand Down Expand Up @@ -166,6 +171,7 @@ def test_pipeline_local_model(self):
predictions,
equal_to(expected_predictions, equals_fn=_compare_prediction_result))

@unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
def test_pipeline_gcs_model(self):
with TestPipeline() as pipeline:
examples = torch.from_numpy(
Expand All @@ -178,7 +184,8 @@ def test_pipeline_gcs_model(self):
for example in examples]).reshape(-1, 1))
]

gs_pth = 'gs://apache-beam-ml/pytorch_lin_reg_model_2x+0.5_state_dict.pth'
gs_pth = 'gs://apache-beam-ml/models/' \
'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
model_loader = PytorchModelLoader(
state_dict_path=gs_pth,
model_class=PytorchLinearRegression,
Expand Down

0 comments on commit 145615a

Please sign in to comment.