From 37f43fd1bc03701b3fcb6638f12a576650f502e4 Mon Sep 17 00:00:00 2001
From: Andy Ye
Date: Wed, 15 Jun 2022 15:09:37 -0400
Subject: [PATCH] Add Pytorch image segmentation example (#21766)

Add an example pipeline that reads image names from a text file, loads
and preprocesses each image, runs it through maskrcnn_resnet50_fpn with
the RunInference API, and writes one 'filename;[class names]' line per
image.

Also switch the language modeling example to
PytorchModelHandlerKeyedTensor, add a post-commit integration test for
the new example, and update the BERT integration test to expect
'text;predicted_text' output lines.

Review notes on the new example:
* read_image passes only the path to FileSystems.open; its second
  positional parameter is a MIME type, not a file mode.
* Help strings no longer run sentences together, and the
  --model_state_dict_path help no longer mentions a default value for a
  required argument.
* The new test uses a well-formed pylint directive for its long line.
---
 .../inference/pytorch_image_segmentation.py   | 246 ++++++++++++++++++
 .../inference/pytorch_language_modeling.py    |   4 +-
 .../ml/inference/pytorch_inference_it_test.py |  48 +++-
 3 files changed, 290 insertions(+), 8 deletions(-)
 create mode 100644 sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py
new file mode 100644
index 0000000000000..e0e2e676052f1
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py
@@ -0,0 +1,246 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A pipeline that uses RunInference API to perform image segmentation."""
+
+import argparse
+import io
+import os
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from PIL import Image
+from torchvision import transforms
+from torchvision.models.detection import maskrcnn_resnet50_fpn
+
+COCO_INSTANCE_CLASSES = [
+    '__background__',
+    'person',
+    'bicycle',
+    'car',
+    'motorcycle',
+    'airplane',
+    'bus',
+    'train',
+    'truck',
+    'boat',
+    'traffic light',
+    'fire hydrant',
+    'N/A',
+    'stop sign',
+    'parking meter',
+    'bench',
+    'bird',
+    'cat',
+    'dog',
+    'horse',
+    'sheep',
+    'cow',
+    'elephant',
+    'bear',
+    'zebra',
+    'giraffe',
+    'N/A',
+    'backpack',
+    'umbrella',
+    'N/A',
+    'N/A',
+    'handbag',
+    'tie',
+    'suitcase',
+    'frisbee',
+    'skis',
+    'snowboard',
+    'sports ball',
+    'kite',
+    'baseball bat',
+    'baseball glove',
+    'skateboard',
+    'surfboard',
+    'tennis racket',
+    'bottle',
+    'N/A',
+    'wine glass',
+    'cup',
+    'fork',
+    'knife',
+    'spoon',
+    'bowl',
+    'banana',
+    'apple',
+    'sandwich',
+    'orange',
+    'broccoli',
+    'carrot',
+    'hot dog',
+    'pizza',
+    'donut',
+    'cake',
+    'chair',
+    'couch',
+    'potted plant',
+    'bed',
+    'N/A',
+    'dining table',
+    'N/A',
+    'N/A',
+    'toilet',
+    'N/A',
+    'tv',
+    'laptop',
+    'mouse',
+    'remote',
+    'keyboard',
+    'cell phone',
+    'microwave',
+    'oven',
+    'toaster',
+    'sink',
+    'refrigerator',
+    'N/A',
+    'book',
+    'clock',
+    'vase',
+    'scissors',
+    'teddy bear',
+    'hair drier',
+    'toothbrush'
+]
+
+CLASS_ID_TO_NAME = dict(enumerate(COCO_INSTANCE_CLASSES))
+
+
+def read_image(image_file_name: str,
+               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
+  """Loads an image as RGB and returns it paired with its file name."""
+  if path_to_dir is not None:
+    image_file_name = os.path.join(path_to_dir, image_file_name)
+  with FileSystems.open(image_file_name) as file:
+    data = Image.open(io.BytesIO(file.read())).convert('RGB')
+    return image_file_name, data
+
+
+def preprocess_image(data: Image.Image) -> torch.Tensor:
+  """Resizes the image to 224x224 and returns it as a normalized tensor."""
+  image_size = (224, 224)
+  # Pre-trained PyTorch models expect input images normalized with the
+  # below values (see: https://pytorch.org/vision/stable/models.html)
+  normalize = transforms.Normalize(
+      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+  transform = transforms.Compose([
+      transforms.Resize(image_size),
+      transforms.ToTensor(),
+      normalize,
+  ])
+  return transform(data)
+
+
+class PostProcessor(beam.DoFn):
+  """Formats each keyed prediction as 'filename;[predicted class names]'."""
+  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
+    filename, prediction_result = element
+    prediction_labels = prediction_result.inference['labels']
+    classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels]
+    yield filename + ';' + str(classes)
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      required=True,
+      help='Path to the text file containing image names.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to the text file where output predictions are saved.')
+  parser.add_argument(
+      '--model_state_dict_path',
+      dest='model_state_dict_path',
+      required=True,
+      help="Path to the model's state_dict. "
+      "The model defaults to maskrcnn_resnet50_fpn.")
+  parser.add_argument(
+      '--images_dir',
+      help='Path to the directory where images are stored. '
+      'Not required if image names in the input file have absolute path.')
+  return parser.parse_known_args(argv)
+
+
+def run(argv=None, model_class=None, model_params=None, save_main_session=True):
+  """Builds and runs the image segmentation pipeline.
+
+  Args:
+    argv: Command line arguments defined for this example.
+    model_class: Reference to the class definition of the model.
+      If None, maskrcnn_resnet50_fpn will be used as default.
+    model_params: Parameters passed to the constructor of the model_class.
+      These will be used to instantiate the model object in the
+      RunInference API.
+    save_main_session: Value assigned to the save_main_session option of
+      SetupOptions.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  if not model_class:
+    model_class = maskrcnn_resnet50_fpn
+    model_params = {'num_classes': 91}
+
+  model_handler = PytorchModelHandlerTensor(
+      state_dict_path=known_args.model_state_dict_path,
+      model_class=model_class,
+      model_params=model_params)
+
+  with beam.Pipeline(options=pipeline_options) as p:
+    filename_value_pair = (
+        p
+        | 'ReadImageNames' >> beam.io.ReadFromText(
+            known_args.input, skip_header_lines=1)
+        | 'ReadImageData' >> beam.Map(
+            lambda image_name: read_image(
+                image_file_name=image_name, path_to_dir=known_args.images_dir))
+        | 'PreprocessImages' >> beam.MapTuple(
+            lambda file_name, data: (file_name, preprocess_image(data))))
+    predictions = (
+        filename_value_pair
+        |
+        'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler))
+        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+
+    _ = predictions | "WriteOutput" >> beam.io.WriteToText(
+        known_args.output,
+        shard_name_template='',
+        append_trailing_newlines=True)
+
+
+if __name__ == '__main__':
+  run()
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py
index 4152317a908d8..6d7b36f01560e 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py
@@ -34,7 +34,7 @@
 from apache_beam.ml.inference.base import KeyedModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.ml.inference.base import RunInference
-from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
 from transformers import BertConfig
@@ -167,7 +167,7 @@ def forward(self, **kwargs):
 
   # TODO: Remove once nested tensors https://github.com/pytorch/nestedtensor
   # is officially released.
-  class PytorchNoBatchModelHandler(PytorchModelHandler):
+  class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor):
     """Wrapper to PytorchModelHandler to limit batch size to 1.
 
     The tokenized strings generated from BertTokenizer may have different
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
index a4231f404347c..784182e18407f 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
@@ -31,8 +31,9 @@
 
 try:
   import torch
-  from apache_beam.examples.inference import pytorch_language_modeling
   from apache_beam.examples.inference import pytorch_image_classification
+  from apache_beam.examples.inference import pytorch_image_segmentation
+  from apache_beam.examples.inference import pytorch_language_modeling
 except ImportError as e:
   torch = None
 
@@ -92,6 +93,42 @@ def test_torch_run_inference_imagenet_mobilenetv2(self):
       filename, prediction = prediction.split(',')
       self.assertEqual(_EXPECTED_OUTPUTS[filename], prediction)
 
+  @pytest.mark.uses_pytorch
+  @pytest.mark.it_postcommit
+  def test_torch_run_inference_coco_maskrcnn_resnet50_fpn(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    # text files containing absolute path to the coco validation data on GCS
+    file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_coco_validation_inputs.txt'  # pylint: disable=line-too-long
+    output_file_dir = 'gs://apache-beam-ml/testing/predictions'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+
+    model_state_dict_path = 'gs://apache-beam-ml/models/torchvision.models.detection.maskrcnn_resnet50_fpn.pth'
+    images_dir = 'gs://apache-beam-ml/datasets/coco/raw-data/val2017'
+    extra_opts = {
+        'input': file_of_image_names,
+        'output': output_file,
+        'model_state_dict_path': model_state_dict_path,
+        'images_dir': images_dir,
+    }
+    pytorch_image_segmentation.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+
+    self.assertEqual(FileSystems().exists(output_file), True)
+    predictions = process_outputs(filepath=output_file)
+    actuals_file = 'gs://apache-beam-ml/testing/expected_outputs/test_torch_run_inference_coco_maskrcnn_resnet50_fpn_actuals.txt'
+    actuals = process_outputs(filepath=actuals_file)
+
+    predictions_dict = {}
+    for prediction in predictions:
+      filename, prediction_labels = prediction.split(';')
+      predictions_dict[filename] = prediction_labels
+
+    for actual in actuals:
+      filename, actual_labels = actual.split(';')
+      prediction_labels = predictions_dict[filename]
+      self.assertEqual(actual_labels, prediction_labels)
+
   @pytest.mark.uses_pytorch
   @pytest.mark.it_postcommit
   def test_torch_run_inference_bert_for_masked_lm(self):
@@ -118,13 +155,12 @@ def test_torch_run_inference_bert_for_masked_lm(self):
 
     predictions_dict = {}
     for prediction in predictions:
-      text, predicted_masked_text, predicted_text = prediction.split(';')
-      predictions_dict[text] = (predicted_masked_text, predicted_text)
+      text, predicted_text = prediction.split(';')
+      predictions_dict[text] = predicted_text
 
     for actual in actuals:
-      text, actual_masked_text, actual_predicted_text = actual.split(';')
-      predicted_masked_text, predicted_predicted_text = predictions_dict[text]
-      self.assertEqual(actual_masked_text, predicted_masked_text)
+      text, actual_predicted_text = actual.split(';')
+      predicted_predicted_text = predictions_dict[text]
       self.assertEqual(actual_predicted_text, predicted_predicted_text)
 
 