Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diffusion Fast Example #2902

Merged
merged 16 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/large_models/diffusion_fast/Download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float32,
use_safetensors=True,
)
pipeline.save_pretrained("./Base_Diffusion_model")
64 changes: 64 additions & 0 deletions examples/large_models/diffusion_fast/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@

## Diffusion-Fast

[Diffusion fast](https://github.com/huggingface/diffusion-fast) is a simple and efficient pytorch-native way of optimizing Stable Diffusion XL (SDXL).

It features:
* Running with the bfloat16 precision
* scaled_dot_product_attention (SDPA)
* torch.compile
* Combining q,k,v projections for attention computation
* Dynamic int8 quantization

Details about the optimizations and various results can be found in this [blog](https://pytorch.org/blog/accelerating-generative-ai-3/).
The example has been tested on A10, A100 as well as H100.


#### Pre-requisites

`cd` to the example folder `examples/image_generation/diffusion_fast`

Install dependencies and upgrade torch to nightly build (currently required)
```
git clone https://github.com/huggingface/diffusion-fast.git
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed -y
pip install accelerate transformers peft
pip install --no-cache-dir git+https://github.com/pytorch-labs/ao@54bcd5a10d0abbe7b0c045052029257099f83fd9
pip install pandas matplotlib seaborn
```
### Step 1: Download the Stable diffusion model

```bash
python Download_model.py
```
This saves the model in `Base_Diffusion_model`

### Step 1: Generate model archive
At this stage we're creating the model archive which includes the configuration of our model in [model_config.yaml](./model_config.yaml).
It's also the point where we need to decide if we want to deploy our model on a single or multiple GPUs.
For the single GPU case we can use the default configuration that can be found in [model_config.yaml](./model_config.yaml).

```
torch-model-archiver --model-name diffusion_fast --version 1.0 --handler diffusion_fast_handler.py --config-file model_config.yaml --extra-files "diffusion-fast/utils/pipeline_utils.py" --archive-format no-archive
mv Base_Diffusion_model diffusion_fast/
```

### Step 2: Add the model archive to model store

```
mkdir model_store
mv diffusion_fast model_store
```

### Step 3: Start torchserve

```
torchserve --start --ts-config config.properties --model-store model_store --models diffusion_fast
```

### Step 4: Run inference

```
python query.py --url "http://localhost:8080/predictions/diffusion_fast" --prompt "a photo of an astronaut riding a horse on mars"
```
The image generated will be written to a file `output-<>.jpg`
4 changes: 4 additions & 0 deletions examples/large_models/diffusion_fast/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
max_response_size=655350000
133 changes: 133 additions & 0 deletions examples/large_models/diffusion_fast/diffusion_fast_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import logging
import os
from pathlib import Path

import numpy as np
import torch
from pipeline_utils import load_pipeline

from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class DiffusionFastHandler(BaseHandler):
"""
Diffusion-Fast handler class for text to image generation.
"""

def __init__(self):
super().__init__()
self.initialized = False

def initialize(self, ctx):
"""In this initialize function, the Diffusion Fast model is loaded and
initialized here.
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
"""
self.context = ctx
self.manifest = ctx.manifest
properties = ctx.system_properties
model_dir = properties.get("model_dir")

if torch.cuda.is_available() and properties.get("gpu_id") is not None:
self.map_location = "cuda"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
)
else:
self.map_location = "cpu"
self.device = torch.device(self.map_location)

self.num_inference_steps = ctx.model_yaml_config["handler"][
"num_inference_steps"
]

# Parameters for the model
compile_unet = ctx.model_yaml_config["handler"]["compile_unet"]
compile_vae = ctx.model_yaml_config["handler"]["compile_vae"]
compile_mode = ctx.model_yaml_config["handler"]["compile_mode"]
enable_fused_projections = ctx.model_yaml_config["handler"][
"enable_fused_projections"
]
do_quant = ctx.model_yaml_config["handler"]["do_quant"]
change_comp_config = ctx.model_yaml_config["handler"]["change_comp_config"]
no_sdpa = ctx.model_yaml_config["handler"]["no_sdpa"]
no_bf16 = ctx.model_yaml_config["handler"]["no_bf16"]
upcast_vae = ctx.model_yaml_config["handler"]["upcast_vae"]

# Load model weights
model_path = Path(ctx.model_yaml_config["handler"]["model_path"])
ckpt = os.path.join(model_dir, model_path)

self.pipeline = load_pipeline(
ckpt=ckpt,
compile_unet=compile_unet,
compile_vae=compile_vae,
compile_mode=compile_mode,
enable_fused_projections=enable_fused_projections,
do_quant=do_quant,
change_comp_config=change_comp_config,
no_bf16=no_bf16,
no_sdpa=no_sdpa,
upcast_vae=upcast_vae,
)

logger.info("Diffusion Fast model loaded successfully")

self.initialized = True

@timed
def preprocess(self, requests):
"""Basic text preprocessing, of the user's prompt.
Args:
requests (str): The Input data in the form of text is passed on to the preprocess
function.
Returns:
list : The preprocess function returns a list of prompts.
"""

assert (
len(requests) == 1
), "Diffusion Fast is currently only supported with batch_size=1"

inputs = []
for _, data in enumerate(requests):
input_text = data.get("data")
if input_text is None:
input_text = data.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
inputs.append(input_text)
return inputs

@timed
def inference(self, inputs):
"""Generates the image relevant to the received text.
Args:
input_batch (list): List of Text from the pre-process function is passed here
Returns:
list : It returns a list of the generate images for the input text
"""
# Handling inference for sequence_classification.
inferences = self.pipeline(
inputs, num_inference_steps=self.num_inference_steps, height=768, width=768
).images

return inferences

@timed
def postprocess(self, inference_output):
"""Post Process Function converts the generated image into Torchserve readable format.
Args:
inference_output (list): It contains the generated image of the input text.
Returns:
(list): Returns a list of the images.
"""
images = []
for image in inference_output:
images.append(np.array(image).tolist())
return images
18 changes: 18 additions & 0 deletions examples/large_models/diffusion_fast/model_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 3600
deviceType: "gpu"
handler:
model_path: "Base_Diffusion_model"
num_inference_steps: 30
compile_unet: true
compile_mode: "max-autotune"
compile_vae: true
enable_fused_projections: true
do_quant: "int8dynamic"
change_comp_config: true
no_sdpa: false
no_bf16: false
upcast_vae: false
profile: true
27 changes: 27 additions & 0 deletions examples/large_models/diffusion_fast/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import argparse
import json
from datetime import datetime

import numpy as np
import requests
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument(
"--url", type=str, required=True, help="Torchserve inference endpoint"
)
parser.add_argument(
"--prompt", type=str, required=True, help="Prompt for image generation"
)
parser.add_argument(
"--filename",
type=str,
default="output-{}.jpg".format(str(datetime.now().strftime("%Y%m%d%H%M%S"))),
help="Filename of output image",
)
args = parser.parse_args()

response = requests.post(args.url, data=args.prompt)
# Contruct image from response
image = Image.fromarray(np.array(json.loads(response.text), dtype="uint8"))
image.save(args.filename)
3 changes: 3 additions & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,9 @@ compilable
nightlies
torchexportaotcompile
autotune
SDXL
SDPA
bfloat
bb
babyllama
libbabyllama
Expand Down
Loading