Skip to content

Commit

Permalink
examples: Stable Diffusion torch.compile
Browse files Browse the repository at this point in the history
- Tutorial for using torch.compile with Stable Diffusion and
Torch-TensorRT
- Demonstration of output images without need to rerun the whole script
upon initial compilation of the docs
  • Loading branch information
gs-olive committed Nov 3, 2023
1 parent 6887426 commit 68a5892
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docsrc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion

Python API Documenation
------------------------
Expand Down Expand Up @@ -206,4 +207,4 @@ Legacy Further Information (TorchScript)
* `GTC 2021 Fall Talk <https://www.nvidia.com/en-us/on-demand/session/gtcfall21-a31107/>`_
* `PyTorch Ecosystem Day 2021 <https://assets.pytorch.org/pted2021/posters/I6.png>`_
* `PyTorch Developer Conference 2021 <https://s3.amazonaws.com/assets.pytorch.org/ptdd2021/posters/D2.png>`_
* `PyTorch Developer Conference 2022 <https://pytorch.s3.amazonaws.com/posters/ptc2022/C04.pdf>`_
* `PyTorch Developer Conference 2022 <https://pytorch.s3.amazonaws.com/posters/ptc2022/C04.pdf>`_
Binary file added docsrc/tutorials/images/majestic_castle.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
55 changes: 55 additions & 0 deletions examples/dynamo/torch_compile_stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
.. _torch_compile_stable_diffusion:
Torch Compile Stable Diffusion
======================================================
This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a Stable Diffusion model. A sample output is featured below:
.. image:: /tutorials/images/majestic_castle.png
:width: 512px
:height: 512px
:scale: 50 %
:align: right
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from diffusers import DiffusionPipeline

import torch_tensorrt

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda:0"

# Instantiate Stable Diffusion Pipeline with FP16 weights
pipe = DiffusionPipeline.from_pretrained(
model_id, revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(device)

backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT
pipe.unet = torch.compile(
pipe.unet,
backend=backend,
options={
"truncate_long_and_double": True,
"precision": torch.float16,
},
dynamic=False,
)

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

prompt = "a majestic castle in the clouds"
image = pipe(prompt).images[0]

image.save("images/majestic_castle.png")
image.show()

0 comments on commit 68a5892

Please sign in to comment.