
Add support for all kinds of schedulers #18


Open · wants to merge 1 commit into base: master
panorama.py: 46 changes (33 additions & 13 deletions)
@@ -1,5 +1,8 @@
+from tqdm.auto import tqdm
+import copy
 from transformers import CLIPTextModel, CLIPTokenizer, logging
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler
+
 
 # suppress partial model loading warning
 logging.set_verbosity_error()
@@ -52,14 +55,16 @@ def __init__(self, device, sd_version='2.0', hf_key=None):
             model_key = "runwayml/stable-diffusion-v1-5"
         else:
             raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
 
+        # model_key = "../js-sd-svc/models/runway-sd-1.5"
         # Create model
         self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
         self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
         self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
         self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
 
-        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+        print("now you can use whatever scheduler you want from diffusers")
+        self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_key, subfolder="scheduler")
 
         print(f'[INFO] loaded stable diffusion!')
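The new scheduler is hard-coded above, but with the added imports any diffusers scheduler can be built the same way from the checkpoint's scheduler config. A minimal sketch of selecting one by name; the SCHEDULERS table and make_scheduler helper are hypothetical, not part of this PR:

    from diffusers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler

    # map short names to scheduler classes; extend with any other diffusers scheduler
    SCHEDULERS = {
        'ddim': DDIMScheduler,
        'ddpm': DDPMScheduler,
        'euler_a': EulerAncestralDiscreteScheduler,
    }

    def make_scheduler(name, model_key):
        # every diffusers scheduler exposes from_pretrained(..., subfolder="scheduler")
        return SCHEDULERS[name].from_pretrained(model_key, subfolder="scheduler")

For example, self.scheduler = make_scheduler('euler_a', model_key) reproduces the line added above.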

@@ -90,7 +95,7 @@ def decode_latents(self, latents):
         return imgs
 
     @torch.no_grad()
-    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
+    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=35,
                       guidance_scale=7.5):
 
         if isinstance(prompts, str):
@@ -110,32 +115,47 @@ def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=35,
 
         self.scheduler.set_timesteps(num_inference_steps)
 
-        with torch.autocast('cuda'):
+        with torch.autocast('cuda'), tqdm(total=len(self.scheduler.timesteps)) as pbar:
             for i, t in enumerate(self.scheduler.timesteps):
                 count.zero_()
                 value.zero_()
 
-                for h_start, h_end, w_start, w_end in views:
+                # -- core enhancement -----
+                noise_pred_all = torch.zeros(latent.shape, device=self.device)  # full-latent accumulator, zeroed each step
+                # for h_start, h_end, w_start, w_end in views:
+                for index in range(len(views)):
+                    h_start, h_end, w_start, w_end = views[index]
 
                     # TODO we can support batches, and pass multiple views at once to the unet
                     latent_view = latent[:, :, h_start:h_end, w_start:w_end]
 
                     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                     latent_model_input = torch.cat([latent_view] * 2)
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                     # predict the noise residual
                     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
 
                     # perform guidance
                     noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
 
-                    # compute the denoising step with the reference model
-                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
-                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
-                    count[:, :, h_start:h_end, w_start:w_end] += 1
+                    # accumulate the noise prediction into the region it belongs to
+                    noise_pred_all[:, :, h_start:h_end, w_start:w_end] += noise_pred
+                    count[:, :, h_start:h_end, w_start:w_end] += 1
+
+                    # # compute the denoising step with the reference model -- no longer needed
+                    # latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
+                    # value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                    # count[:, :, h_start:h_end, w_start:w_end] += 1
+
+                # average overlapping noise predictions, then denoise the full latent in one step
+                noise_pred = torch.where(count > 0, noise_pred_all / count, noise_pred_all)
+                latent = self.scheduler.step(noise_pred, t, latent)['prev_sample']
+
+                pbar.update()
 
         # take the MultiDiffusion step
-        latent = torch.where(count > 0, value / count, value)
+        # latent = torch.where(count > 0, value / count, value)
 
         # Img latents -> imgs
         imgs = self.decode_latents(latent)  # [1, 3, height, width]
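This reordering is what unlocks stochastic schedulers. DDIM is deterministic, so the original code could run scheduler.step on each view separately and average the denoised latents. DDPM and Euler Ancestral draw fresh noise inside step, so per-view stepping would inject independent noise into overlapping regions and the average would wash it out. Averaging the noise predictions first and calling step once on the full latent means a single noise draw is shared by all views. A condensed sketch of the loop body, assuming a hypothetical predict(view, t) helper that wraps the CFG-guided UNet call:

    import torch

    def fused_step(scheduler, latent, t, views, predict):
        noise_sum = torch.zeros_like(latent)
        count = torch.zeros_like(latent)
        for h0, h1, w0, w1 in views:
            # accumulate each view's guided noise prediction into its region
            noise_sum[:, :, h0:h1, w0:w1] += predict(latent[:, :, h0:h1, w0:w1], t)
            count[:, :, h0:h1, w0:w1] += 1
        # average overlaps, then take one scheduler step so any injected noise is shared
        noise_pred = torch.where(count > 0, noise_sum / count, noise_sum)
        return scheduler.step(noise_pred, t, latent)['prev_sample']

The per-view count increment is still required here: it is what turns the sum over overlapping views into an average.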
@@ -152,7 +172,7 @@ def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=35,
     parser.add_argument('--H', type=int, default=512)
     parser.add_argument('--W', type=int, default=4096)
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--steps', type=int, default=50)
+    parser.add_argument('--steps', type=int, default=40)
     opt = parser.parse_args()
 
     seed_everything(opt.seed)
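One consequence of moving to an ancestral scheduler: scheduler.step now draws fresh noise at every timestep, so --seed governs the whole sampling trajectory, not just the initial latent. seed_everything is defined or imported elsewhere in the script (not shown in this diff); a sketch of what such a helper conventionally does, assumed rather than taken from the source:

    import random

    import numpy as np
    import torch

    def seed_everything(seed):
        # seed every RNG the sampling loop might touch (assumed implementation)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)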