Add Pytorch image segmentation example (apache#21766)
yeandy authored and prodriguezdefino committed Jun 21, 2022
1 parent ab256c5 commit 37f43fd
Showing 3 changed files with 285 additions and 8 deletions.
241 changes: 241 additions & 0 deletions sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py
@@ -0,0 +1,241 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A pipeline that uses RunInference API to perform image segmentation."""

import argparse
import io
import os
from typing import Iterable
from typing import Optional
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from PIL import Image
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn

COCO_INSTANCE_CLASSES = [
'__background__',
'person',
'bicycle',
'car',
'motorcycle',
'airplane',
'bus',
'train',
'truck',
'boat',
'traffic light',
'fire hydrant',
'N/A',
'stop sign',
'parking meter',
'bench',
'bird',
'cat',
'dog',
'horse',
'sheep',
'cow',
'elephant',
'bear',
'zebra',
'giraffe',
'N/A',
'backpack',
'umbrella',
'N/A',
'N/A',
'handbag',
'tie',
'suitcase',
'frisbee',
'skis',
'snowboard',
'sports ball',
'kite',
'baseball bat',
'baseball glove',
'skateboard',
'surfboard',
'tennis racket',
'bottle',
'N/A',
'wine glass',
'cup',
'fork',
'knife',
'spoon',
'bowl',
'banana',
'apple',
'sandwich',
'orange',
'broccoli',
'carrot',
'hot dog',
'pizza',
'donut',
'cake',
'chair',
'couch',
'potted plant',
'bed',
'N/A',
'dining table',
'N/A',
'N/A',
'toilet',
'N/A',
'tv',
'laptop',
'mouse',
'remote',
'keyboard',
'cell phone',
'microwave',
'oven',
'toaster',
'sink',
'refrigerator',
'N/A',
'book',
'clock',
'vase',
'scissors',
'teddy bear',
'hair drier',
'toothbrush'
]

CLASS_ID_TO_NAME = dict(enumerate(COCO_INSTANCE_CLASSES))


def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  with FileSystems().open(image_file_name, 'r') as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
    return image_file_name, data


def preprocess_image(data: Image.Image) -> torch.Tensor:
  image_size = (224, 224)
  # Pre-trained PyTorch models expect input images normalized with the
  # below values (see: https://pytorch.org/vision/stable/models.html)
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  transform = transforms.Compose([
      transforms.Resize(image_size),
      transforms.ToTensor(),
      normalize,
  ])
  return transform(data)
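
For reference, a minimal editorial sketch (not part of the committed file) of what this preprocessing yields, assuming the imports at the top of the file; the blank 640x480 image is an arbitrary stand-in for a real photo:

# Editorial sketch: run an arbitrary RGB image through preprocess_image.
sample = Image.new('RGB', (640, 480))  # blank stand-in image
tensor = preprocess_image(sample)
print(tensor.shape)  # torch.Size([3, 224, 224]) after Resize + ToTensor
print(tensor.dtype)  # torch.float32, normalized with the ImageNet stats above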


class PostProcessor(beam.DoFn):
  """Formats each prediction as 'filename;[predicted class names]'."""
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    filename, prediction_result = element
    prediction_labels = prediction_result.inference['labels']
    classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels]
    yield filename + ';' + str(classes)
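
To illustrate the output format, an editorial sketch (not part of the file): the namedtuple below is a local stand-in with the same .inference field that RunInference attaches to each PredictionResult, label ids 1 and 18 come from the COCO list above, and 'img1.jpg' is a hypothetical file name.

# Editorial sketch: feed one fake keyed prediction through PostProcessor.
import collections
FakeResult = collections.namedtuple('FakeResult', ['example', 'inference'])
fake = FakeResult(example=None, inference={'labels': torch.tensor([1, 18])})
print(next(iter(PostProcessor().process(('img1.jpg', fake)))))
# -> img1.jpg;['person', 'dog']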


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path of the output text file to which predictions are saved.')
  parser.add_argument(
      '--model_state_dict_path',
      dest='model_state_dict_path',
      required=True,
      help="Path to the model's state_dict. The state_dict should correspond "
      'to the model class (maskrcnn_resnet50_fpn by default).')
  parser.add_argument(
      '--images_dir',
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file are absolute paths.')
  return parser.parse_known_args(argv)


def run(argv=None, model_class=None, model_params=None, save_main_session=True):
  """Runs the image segmentation inference pipeline.

  Args:
    argv: Command line arguments defined for this example.
    model_class: Reference to the class definition of the model.
      If None, maskrcnn_resnet50_fpn will be used as default.
    model_params: Parameters passed to the constructor of the model_class.
      These will be used to instantiate the model object in the
      RunInference API.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  if not model_class:
    model_class = maskrcnn_resnet50_fpn
    model_params = {'num_classes': 91}

  model_handler = PytorchModelHandlerTensor(
      state_dict_path=known_args.model_state_dict_path,
      model_class=model_class,
      model_params=model_params)

  with beam.Pipeline(options=pipeline_options) as p:
    # Read image file names, then load and preprocess each image, keyed by
    # file name so predictions can be matched back to their images.
    filename_value_pair = (
        p
        | 'ReadImageNames' >> beam.io.ReadFromText(
            known_args.input, skip_header_lines=1)
        | 'ReadImageData' >> beam.Map(
            lambda image_name: read_image(
                image_file_name=image_name, path_to_dir=known_args.images_dir))
        | 'PreprocessImages' >> beam.MapTuple(
            lambda file_name, data: (file_name, preprocess_image(data))))
    predictions = (
        filename_value_pair
        | 'PyTorchRunInference' >> RunInference(KeyedModelHandler(model_handler))
        | 'ProcessOutput' >> beam.ParDo(PostProcessor()))

    _ = predictions | 'WriteOutput' >> beam.io.WriteToText(
        known_args.output,
        shard_name_template='',
        append_trailing_newlines=True)


if __name__ == '__main__':
  run()
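
As a usage sketch (editorial, not part of the commit): run() can be invoked directly with the flags defined in parse_known_args; the gs:// paths below are placeholders, and any extra flags are forwarded to Beam as pipeline options.

# Editorial sketch: launch the pipeline with placeholder paths.
run(argv=[
    '--input=gs://my-bucket/coco_image_names.txt',
    '--output=gs://my-bucket/predictions/result.txt',
    '--model_state_dict_path=gs://my-bucket/maskrcnn_resnet50_fpn.pth',
    '--images_dir=gs://my-bucket/val2017',
])
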
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py
@@ -34,7 +34,7 @@
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
-from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from transformers import BertConfig
@@ -167,7 +167,7 @@ def forward(self, **kwargs):

# TODO: Remove once nested tensors https://github.com/pytorch/nestedtensor
# is officially released.
-class PytorchNoBatchModelHandler(PytorchModelHandler):
+class PytorchNoBatchModelHandler(PytorchModelHandlerKeyedTensor):
  """Wrapper to PytorchModelHandler to limit batch size to 1.
  The tokenized strings generated from BertTokenizer may have different
48 changes: 42 additions & 6 deletions sdks/python/apache_beam/ml/inference/pytorch_inference_it_test.py
@@ -31,8 +31,9 @@

try:
  import torch
-  from apache_beam.examples.inference import pytorch_language_modeling
  from apache_beam.examples.inference import pytorch_image_classification
+  from apache_beam.examples.inference import pytorch_image_segmentation
+  from apache_beam.examples.inference import pytorch_language_modeling
except ImportError as e:
  torch = None

@@ -92,6 +93,42 @@ def test_torch_run_inference_imagenet_mobilenetv2(self):
      filename, prediction = prediction.split(',')
      self.assertEqual(_EXPECTED_OUTPUTS[filename], prediction)

  @pytest.mark.uses_pytorch
  @pytest.mark.it_postcommit
  def test_torch_run_inference_coco_maskrcnn_resnet50_fpn(self):
    test_pipeline = TestPipeline(is_integration_test=True)
    # Text file containing absolute paths to the COCO validation images on GCS.
    file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_coco_validation_inputs.txt'  # pylint: disable=line-too-long
    output_file_dir = 'gs://apache-beam-ml/testing/predictions'
    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])

    model_state_dict_path = 'gs://apache-beam-ml/models/torchvision.models.detection.maskrcnn_resnet50_fpn.pth'
    images_dir = 'gs://apache-beam-ml/datasets/coco/raw-data/val2017'
    extra_opts = {
        'input': file_of_image_names,
        'output': output_file,
        'model_state_dict_path': model_state_dict_path,
        'images_dir': images_dir,
    }
    pytorch_image_segmentation.run(
        test_pipeline.get_full_options_as_args(**extra_opts),
        save_main_session=False)

    self.assertEqual(FileSystems().exists(output_file), True)
    predictions = process_outputs(filepath=output_file)
    actuals_file = 'gs://apache-beam-ml/testing/expected_outputs/test_torch_run_inference_coco_maskrcnn_resnet50_fpn_actuals.txt'
    actuals = process_outputs(filepath=actuals_file)

    predictions_dict = {}
    for prediction in predictions:
      filename, prediction_labels = prediction.split(';')
      predictions_dict[filename] = prediction_labels

    for actual in actuals:
      filename, actual_labels = actual.split(';')
      prediction_labels = predictions_dict[filename]
      self.assertEqual(actual_labels, prediction_labels)

  @pytest.mark.uses_pytorch
  @pytest.mark.it_postcommit
  def test_torch_run_inference_bert_for_masked_lm(self):
@@ -118,13 +155,12 @@ def test_torch_run_inference_bert_for_masked_lm(self):

    predictions_dict = {}
    for prediction in predictions:
-      text, predicted_masked_text, predicted_text = prediction.split(';')
-      predictions_dict[text] = (predicted_masked_text, predicted_text)
+      text, predicted_text = prediction.split(';')
+      predictions_dict[text] = predicted_text

    for actual in actuals:
-      text, actual_masked_text, actual_predicted_text = actual.split(';')
-      predicted_masked_text, predicted_predicted_text = predictions_dict[text]
-      self.assertEqual(actual_masked_text, predicted_masked_text)
+      text, actual_predicted_text = actual.split(';')
+      predicted_predicted_text = predictions_dict[text]
      self.assertEqual(actual_predicted_text, predicted_predicted_text)

