diff --git a/panorama.py b/panorama.py
index 33d49ec..3bdab30 100644
--- a/panorama.py
+++ b/panorama.py
@@ -1,5 +1,8 @@
+from tqdm.auto import tqdm
+import copy
 from transformers import CLIPTextModel, CLIPTokenizer, logging
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler
+
 
 # suppress partial model loading warning
 logging.set_verbosity_error()
@@ -52,7 +55,7 @@ def __init__(self, device, sd_version='2.0', hf_key=None):
             model_key = "runwayml/stable-diffusion-v1-5"
         else:
             raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
-
+        # model_key = "../js-sd-svc/models/runway-sd-1.5"
         # Create model
         self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
         self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
@@ -60,6 +63,8 @@ def __init__(self, device, sd_version='2.0', hf_key=None):
         self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
 
         self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+        print("now you can use whatever scheduler you want from diffusers~~~")
+        self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_key, subfolder="scheduler")
 
         print(f'[INFO] loaded stable diffusion!')
@@ -90,7 +95,7 @@ def decode_latents(self, latents):
         return imgs
 
     @torch.no_grad()
-    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
+    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=35,
                       guidance_scale=7.5):
 
         if isinstance(prompts, str):
@@ -110,32 +115,47 @@ def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, nu
         self.scheduler.set_timesteps(num_inference_steps)
 
-        with torch.autocast('cuda'):
+        with torch.autocast('cuda'), tqdm(total=len(self.scheduler.timesteps)) as pbar:
             for i, t in enumerate(self.scheduler.timesteps):
                 count.zero_()
                 value.zero_()
-
-                for h_start, h_end, w_start, w_end in views:
+                # -- core enhancement: fuse noise predictions across views, then step once -----
+                noise_pred_all = torch.zeros(latent.shape, device=self.device)  # noise accumulator, same shape as the full latent
+                # for h_start, h_end, w_start, w_end in views:
+                for index in range(len(views)):
+                    h_start, h_end, w_start, w_end = views[index]
+                    # TODO: we could batch several views and pass them through the unet at once
                     latent_view = latent[:, :, h_start:h_end, w_start:w_end]
 
                     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                     latent_model_input = torch.cat([latent_view] * 2)
-
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                     # predict the noise residual
                     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
-
+
                     # perform guidance
                     noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
-                    # compute the denoising step with the reference model
-                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
-                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                     count[:, :, h_start:h_end, w_start:w_end] += 1
+                    # accumulate this view's noise prediction at its position in the full latent
+                    noise_pred_all[:, :, h_start:h_end, w_start:w_end] += noise_pred
+
+                # # compute the denoising step with the reference model -- no longer needed
+                # latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
+                # value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                # count[:, :, h_start:h_end, w_start:w_end] += 1
+
+                noise_pred = torch.where(count > 0, noise_pred_all / count, noise_pred_all)
+                latent = self.scheduler.step(noise_pred, t, latent)['prev_sample']
+
+                pbar.update()
+
                 # take the MultiDiffusion step
-                latent = torch.where(count > 0, value / count, value)
+                # latent = torch.where(count > 0, value / count, value)
 
         # Img latents -> imgs
         imgs = self.decode_latents(latent)  # [1, 3, 512, 512]
@@ -152,7 +172,7 @@ def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, nu
     parser.add_argument('--H', type=int, default=512)
    parser.add_argument('--W', type=int, default=4096)
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--steps', type=int, default=50)
+    parser.add_argument('--steps', type=int, default=40)
     opt = parser.parse_args()
     seed_everything(opt.seed)
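
The constructor now simply overwrites the DDIM scheduler with EulerAncestralDiscreteScheduler. A minimal sketch of how that swap could be made configurable instead of hardcoded; load_scheduler and the name map are illustrative assumptions, not part of the patch:

from diffusers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler

# hypothetical short-name -> class map; extend with any other diffusers scheduler
SCHEDULERS = {
    'ddim': DDIMScheduler,
    'ddpm': DDPMScheduler,
    'euler_a': EulerAncestralDiscreteScheduler,
}

def load_scheduler(model_key, name='euler_a'):
    # every diffusers scheduler can be built from the checkpoint's scheduler config
    return SCHEDULERS[name].from_pretrained(model_key, subfolder="scheduler")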
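
The heart of the patch: the original MultiDiffusion loop called scheduler.step() once per view and averaged the denoised latents, which breaks stateful samplers because step() advances internal state (step index, sigmas) on every call. The patch instead averages the per-view noise predictions and steps exactly once per timestep on the full latent, which is why any diffusers scheduler now works. A self-contained sketch of that fusion step, assuming views is a list of (h_start, h_end, w_start, w_end) windows and predict_noise stands in for the guided UNet call:

import torch

def fused_noise_step(scheduler, latent, t, views, predict_noise):
    noise_sum = torch.zeros_like(latent)  # accumulated per-view noise predictions
    count = torch.zeros_like(latent)      # how many views cover each latent pixel
    for h_start, h_end, w_start, w_end in views:
        eps = predict_noise(latent[:, :, h_start:h_end, w_start:w_end], t)
        noise_sum[:, :, h_start:h_end, w_start:w_end] += eps
        count[:, :, h_start:h_end, w_start:w_end] += 1
    # average where windows overlap, then take a single step on the full latent
    eps_full = torch.where(count > 0, noise_sum / count, noise_sum)
    return scheduler.step(eps_full, t, latent)['prev_sample']

Deterministic schedulers like DDIM behave the same either way; for ancestral samplers the single step also means fresh noise is drawn once for the whole panorama rather than independently per view, avoiding seams.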
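
The TODO about batching views is feasible because every view in this sliding-window setup has the same spatial size. A hedged sketch under that assumption; predict_noise_batched and mini_batch are illustrative names, and the [uncond, cond] layout of text_embeds follows the surrounding code:

import torch

def predict_noise_batched(unet, scheduler, latent, t, views, text_embeds,
                          guidance_scale=7.5, mini_batch=8):
    uncond, cond = text_embeds.chunk(2)  # each [1, 77, dim] for a single prompt
    preds = []
    for i in range(0, len(views), mini_batch):
        chunk = views[i:i + mini_batch]
        view_latents = torch.cat(
            [latent[:, :, h0:h1, w0:w1] for h0, h1, w0, w1 in chunk], dim=0)
        # classifier-free guidance: duplicate the mini-batch for uncond/cond passes
        model_in = scheduler.scale_model_input(torch.cat([view_latents] * 2), t)
        embeds = torch.cat([uncond.expand(len(chunk), -1, -1),
                            cond.expand(len(chunk), -1, -1)])
        eps = unet(model_in, t, encoder_hidden_states=embeds)['sample']
        eps_uncond, eps_cond = eps.chunk(2)
        preds.append(eps_uncond + guidance_scale * (eps_cond - eps_uncond))
    return torch.cat(preds)  # one guided noise prediction per view, in input order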