feat: data parallel inference sample
bowang007 committed May 16, 2024
1 parent de81be2 commit dfbf6ea
Showing 5 changed files with 145 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docsrc/index.rst
@@ -111,6 +111,9 @@ Tutorials
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion


Python API Documentation
------------------------
14 changes: 14 additions & 0 deletions examples/distributed_inference/README.md
@@ -0,0 +1,14 @@
# Torch-TensorRT parallelism for distributed inference

The examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend.

1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference)

Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model
is loaded onto each GPU, and a different chunk of the input batch is processed on each device.

See the examples starting with `data_parallel` for more details, and the minimal sketch after this list.

2. Tensor parallel distributed inference

In development.
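
A minimal, self-contained sketch of the data parallel pattern (a toy `nn.Linear` stands in for a real model; assumes two visible GPUs and a launch such as `accelerate launch --num_processes 2 sketch.py`):

```python
import torch
from accelerate import PartialState

import torch_tensorrt  # noqa: F401  (imported to register the "torch_tensorrt" backend)

distributed_state = PartialState()  # one process per device

# Toy stand-in for a real model; load your own model here instead
model = torch.nn.Linear(8, 8).eval().to(distributed_state.device)
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

# Each process receives its own slice of the input list
batches = [torch.randn(1, 8), torch.randn(1, 8)]
with distributed_state.split_between_processes(batches) as inputs:
    output = model(inputs[0].to(distributed_state.device))
    print(distributed_state.process_index, output.shape)
```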
64 changes: 64 additions & 0 deletions examples/distributed_inference/data_parallel_gpt2.py
@@ -0,0 +1,64 @@
"""
.. _data_parallel_gpt2:
Torch-TensorRT Distributed Inference
======================================================
This interactive script is intended as a sample of distributed inference using data
parallelism using Accelerate
library with the Torch-TensorRT workflow on GPT2 model.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

import torch_tensorrt

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set input prompts for different devices
prompt1 = "GPT2 is a model developed by."
prompt2 = "Llama is a model developed by "

input_id1 = tokenizer(prompt1, return_tensors="pt").input_ids
input_id2 = tokenizer(prompt2, return_tensors="pt").input_ids

distributed_state = PartialState()
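# (PartialState initializes the distributed context for this process and exposes
# per-process attributes such as .device and .process_index)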

# Import GPT2 model and load to distributed devices
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)


# Instantiate model with Torch-TensorRT backend
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
    },
    dynamic=False,
)
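
# Note: the first call to generate() below triggers torch.compile tracing and
# TensorRT engine compilation in each process, so the initial run is expected
# to be slow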

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
    cur_input = torch.clone(prompt[0]).to(distributed_state.device)

    gen_tokens = model.generate(
        cur_input,
        do_sample=True,
        temperature=0.9,
        max_length=100,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
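
# Each process decodes the continuation of its own prompt. An assumed invocation
# for two devices would be:
#   accelerate launch --num_processes 2 data_parallel_gpt2.py
print(f"Rank {distributed_state.process_index}: {gen_text}")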
61 changes: 61 additions & 0 deletions examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -0,0 +1,61 @@
"""
.. _data_parallel_stable_diffusion:
Torch-TensorRT Distributed Inference
======================================================
This interactive script is intended as a sample of distributed inference using data
parallelism using Accelerate
library with the Torch-TensorRT workflow on Stable Diffusion model.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

import torch_tensorrt

model_id = "CompVis/stable-diffusion-v1-4"

# Instantiate Stable Diffusion Pipeline with FP16 weights
pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)

distributed_state = PartialState()
pipe = pipe.to(distributed_state.device)

backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT
pipe.unet = torch.compile(
    pipe.unet,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "precision": torch.float16,
        "debug": True,
        "use_python_runtime": True,
    },
    dynamic=False,
)
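# Multi-device safe mode makes the runtime check, on each execution, that the
# active CUDA device matches the device the compiled engine expects; this adds
# a small overhead but avoids cross-device errors when multiple devices are in use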
torch_tensorrt.runtime.set_multi_device_safe_mode(True)


# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipe(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
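
# Each process saves one image, indexed by its process rank. As with the GPT2
# sample, an assumed invocation for two devices would be:
#   accelerate launch --num_processes 2 data_parallel_stable_diffusion.py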
3 changes: 3 additions & 0 deletions examples/distributed_inference/requirement.txt
@@ -0,0 +1,3 @@
accelerate
transformers
diffusers
