From e648d3d5ac7ca8dba47d5f6b32414d37da53b229 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Thu, 3 Nov 2022 17:42:28 +0800 Subject: [PATCH 01/22] add dpmsolver discrete pytorch scheduler --- src/diffusers/__init__.py | 1 + src/diffusers/schedulers/__init__.py | 1 + .../scheduling_dpmsolver_discrete.py | 377 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + 4 files changed, 394 insertions(+) create mode 100644 src/diffusers/schedulers/scheduling_dpmsolver_discrete.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3f3f3b56a253..e3eb5b2e80dd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -42,6 +42,7 @@ from .schedulers import ( DDIMScheduler, DDPMScheduler, + DPMSolverDiscreteScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 1be541ba8b66..5fb00c288213 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,6 +19,7 @@ if is_torch_available(): from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler + from .scheduling_dpmsolver_discrete import DPMSolverDiscreteScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py new file mode 100644 index 000000000000..51ed10e465f9 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -0,0 +1,377 @@ +# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import Optional, Tuple, Union, List + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + DPM-Solver. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + + """ + + _compatible_classes = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + ] + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + steps_offset: int = 0, + solver_order=3, + predict_x0=True, + thresholding=False, + sample_max_value=1.0, + solver_type="dpm_solver", + denoise_final=True, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + self.solver_order = solver_order + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.sample_max_value = sample_max_value + self.denoise_final = denoise_final + if solver_type in ["dpm_solver", "taylor"]: + self.solver_type = solver_type + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None,] * self.solver_order + self.lower_order_nums = 0 + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.model_outputs = [None,] * self.solver_order + self.lower_order_nums = 0 + + def convert_model_output( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + TODO + """ + if self.predict_x0: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + if self.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), p, dim=1) + s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[(...,) + (None,)*(x0_pred.ndim - 1)] + x0_pred = torch.clamp(x0_pred, -s, s) / s + return x0_pred + else: + return model_output + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + TODO + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.predict_x0: + x_t = ( + (sigma_t / sigma_s) * sample + - (alpha_t * (torch.exp(-h) - 1.)) * model_output + ) + else: + x_t = ( + (alpha_t / alpha_s) * sample + - (sigma_t * (torch.exp(h) - 1.)) * model_output + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + TODO + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1. / r0) * (m0 - m1) + if self.predict_x0: + if self.solver_type == 'dpm_solver': + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.)) * D1 + ) + elif self.solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.)) * D1 + ) + else: + if self.solver_type == 'dpm_solver': + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.)) * D1 + ) + elif self.solver_type == 'taylor': + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.)) * D0 + - (sigma_t * ((torch.exp(h) - 1.) / h - 1.)) * D1 + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + TODO + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1. / r0) * (m0 - m1), (1. / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.)) * D1 + - (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5)) * D2 + ) + else: + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.)) * D0 + - (sigma_t * ((torch.exp(h) - 1.) / h - 1.)) * D1 + - (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + denoise_final = (step_index == len(self.timesteps) - 1) and self.denoise_final + denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final + + model_output = self.convert_model_output(model_output, timestep, sample) + self.model_outputs.append(model_output) + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.solver_order == 1 or self.lower_order_nums < 1 or denoise_final: + prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) + elif self.solver_order == 2 or self.lower_order_nums < 2 or denoise_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, timestep_list, prev_timestep, sample) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, timestep_list, prev_timestep, sample) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 25aa82d6c5b2..b7b910281cd4 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -302,6 +302,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DPMSolverDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EulerAncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] From 2fe70df840e0aff05cc531a7e055681f847c723b Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Thu, 3 Nov 2022 22:16:52 +0800 Subject: [PATCH 02/22] fix some typos in dpm-solver pytorch --- .../scheduling_dpmsolver_discrete.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index 51ed10e465f9..4ebb51f76c6e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -104,13 +104,12 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - steps_offset: int = 0, - solver_order=3, - predict_x0=True, - thresholding=False, - sample_max_value=1.0, - solver_type="dpm_solver", - denoise_final=True, + solver_order: int = 3, + predict_x0: bool = True, + thresholding: bool = False, + sample_max_value: float = 1.0, + solver_type: str = "dpm_solver", + denoise_final: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -150,7 +149,7 @@ def __init__( # setable values self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None,] * self.solver_order self.lower_order_nums = 0 @@ -185,7 +184,7 @@ def convert_model_output( x0_pred = (sample - sigma_t * model_output) / alpha_t if self.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487). s = torch.quantile(torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), p, dim=1) s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[(...,) + (None,)*(x0_pred.ndim - 1)] x0_pred = torch.clamp(x0_pred, -s, s) / s @@ -338,7 +337,6 @@ def step( denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final model_output = self.convert_model_output(model_output, timestep, sample) - self.model_outputs.append(model_output) for i in range(self.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output From dc2d54858acbb6998de7dd555a41e6339645a423 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 02:17:46 +0800 Subject: [PATCH 03/22] add dpm-solver pytorch in stable-diffusion pipeline --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1ccc87804e68..a66587475531 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -15,6 +15,7 @@ EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, + DPMSolverDiscreteScheduler, ) from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput @@ -59,7 +60,8 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[ - DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler + DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, + DPMSolverDiscreteScheduler, ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, From b71ec000e9147729fcf9287296810b97fe4bfb78 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 14:00:00 +0800 Subject: [PATCH 04/22] add jax/flax version dpm-solver --- src/diffusers/__init__.py | 1 + .../pipeline_flax_stable_diffusion.py | 6 +- src/diffusers/schedulers/__init__.py | 1 + .../scheduling_dpmsolver_discrete.py | 4 +- .../scheduling_dpmsolver_discrete_flax.py | 487 ++++++++++++++++++ src/diffusers/utils/dummy_flax_objects.py | 15 + 6 files changed, 509 insertions(+), 5 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e3eb5b2e80dd..dd03041bb215 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -93,6 +93,7 @@ from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, + FlaxDPMSolverDiscreteScheduler, FlaxKarrasVeScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index fe0e284c6720..a8ed353580dd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -14,7 +14,7 @@ from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...pipeline_flax_utils import FlaxDiffusionPipeline -from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler +from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxDPMSolverDiscreteScheduler from ...utils import logging from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -43,7 +43,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverDiscreteScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. @@ -57,7 +57,7 @@ def __init__( text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, - scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler], + scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, dtype: jnp.dtype = jnp.float32, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 5fb00c288213..8e095d384fca 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -36,6 +36,7 @@ if is_flax_available(): from .scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler + from .scheduling_dpmsolver_discrete_flax import FlaxDPMSolverDiscreteScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index 4ebb51f76c6e..4387516bd447 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -104,12 +104,12 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - solver_order: int = 3, + solver_order: int = 2, predict_x0: bool = True, thresholding: bool = False, sample_max_value: float = 1.0, solver_type: str = "dpm_solver", - denoise_final: bool = True, + denoise_final: bool = False, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py new file mode 100644 index 000000000000..9c1bc68d9acc --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py @@ -0,0 +1,487 @@ +# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union, List + +import flax +import jax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class DPMSolverDiscreteSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + + # running values + model_outputs: Optional[jnp.ndarray] = None + lower_order_nums: Optional[int] = None + step_index: Optional[int] = None + prev_timestep: Optional[int] = None + cur_sample: Optional[jnp.ndarray] = None + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxDPMSolverDiscreteSchedulerOutput(FlaxSchedulerOutput): + state: DPMSolverDiscreteSchedulerState + + +class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): + """ + DPM-Solver. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + + """ + + @property + def has_state(self): + return True + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + solver_order: int = 2, + predict_x0: bool = True, + thresholding: bool = False, + sample_max_value: float = 1.0, + solver_type: str = "dpm_solver", + denoise_final: bool = False, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + # Currently we only support VP-type noise schedule + self.alpha_t = jnp.sqrt(self.alphas_cumprod) + self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod) + self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + self.solver_order = solver_order + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.sample_max_value = sample_max_value + self.denoise_final = denoise_final + if solver_type in ["dpm_solver", "taylor"]: + self.solver_type = solver_type + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + def create_state(self): + return DPMSolverDiscreteSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) + + def set_timesteps(self, state: DPMSolverDiscreteSchedulerState, num_inference_steps: int, shape: Tuple) -> DPMSolverDiscreteSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DPMSolverDiscreteSchedulerState`): + the `FlaxDPMSolverDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + shape (`Tuple`): + the shape of the samples to be generated. + """ + timesteps = jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32) + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + model_outputs=jnp.zeros((self.solver_order,) + shape), + lower_order_nums=0, + step_index=0, + prev_timestep=-1, + cur_sample=jnp.zeros(shape), + ) + + def convert_model_output( + self, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + TODO + """ + if self.predict_x0: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + if self.thresholding: + # A hyperparameter in the paper of Imagen (https://arxiv.org/abs/2205.11487). + p = 0.995 + s = jnp.percentile(jnp.abs(x0_pred), p, axis=tuple(range(1, x0_pred.ndim))) + s = jnp.max(s, self.max_val) + x0_pred = jnp.clip(x0_pred, -s, s) / s + return x0_pred + else: + return model_output + + def dpm_solver_first_order_update( + self, + model_output: jnp.ndarray, + timestep: int, + prev_timestep: int, + sample: jnp.ndarray + ) -> jnp.ndarray: + """ + TODO + """ + t, s0 = prev_timestep, timestep + m0 = model_output + lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0] + h = lambda_t - lambda_s + if self.predict_x0: + x_t = ( + (sigma_t / sigma_s) * sample + - (alpha_t * (jnp.exp(-h) - 1.)) * m0 + ) + else: + x_t = ( + (alpha_t / alpha_s) * sample + - (sigma_t * (jnp.exp(h) - 1.)) * m0 + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + TODO + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1. / r0) * (m0 - m1) + if self.predict_x0: + if self.solver_type == 'dpm_solver': + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.)) * D0 + - 0.5 * (alpha_t * (jnp.exp(-h) - 1.)) * D1 + ) + elif self.solver_type == 'taylor': + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.) / h + 1.)) * D1 + ) + else: + if self.solver_type == 'dpm_solver': + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.)) * D0 + - 0.5 * (sigma_t * (jnp.exp(h) - 1.)) * D1 + ) + elif self.solver_type == 'taylor': + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.) / h - 1.)) * D1 + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: jnp.ndarray, + timestep_list: List[int], + prev_timestep: int, + sample: jnp.ndarray, + ) -> jnp.ndarray: + """ + TODO + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1. / r0) * (m0 - m1), (1. / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (jnp.exp(-h) - 1.)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.) / h + 1.)) * D1 + - (alpha_t * ((jnp.exp(-h) - 1. + h) / h**2 - 0.5)) * D2 + ) + else: + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (jnp.exp(h) - 1.)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.) / h - 1.)) * D1 + - (sigma_t * ((jnp.exp(h) - 1. - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + state: DPMSolverDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxDPMSolverDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DPMSolverDiscreteSchedulerState`): the `FlaxDPMSolverDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverDiscreteSchedulerOutput class + + Returns: + [`FlaxDPMSolverDiscreteSchedulerOutput`] or `tuple`: [`FlaxDPMSolverDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + prev_timestep = jax.lax.cond( + state.step_index == len(state.timesteps) - 1, + lambda _: 0, + lambda _: state.timesteps[state.step_index + 1], + (), + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + + model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) + model_outputs_new = model_outputs_new.at[-1].set(model_output) + state = state.replace( + model_outputs=model_outputs_new, + prev_timestep=prev_timestep, + cur_sample=sample, + ) + + def step_1( + state: DPMSolverDiscreteSchedulerState + ) -> jnp.ndarray: + return self.dpm_solver_first_order_update( + state.model_outputs[-1], + state.timesteps[state.step_index], + state.prev_timestep, + state.cur_sample, + ) + + def step_23( + state: DPMSolverDiscreteSchedulerState + ) -> jnp.ndarray: + + def step_2( + state: DPMSolverDiscreteSchedulerState + ) -> jnp.ndarray: + timestep_list = jnp.array([ + state.timesteps[state.step_index - 1], + state.timesteps[state.step_index] + ]) + return self.multistep_dpm_solver_second_order_update( + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + def step_3( + state: DPMSolverDiscreteSchedulerState + ) -> jnp.ndarray: + timestep_list = jnp.array([ + state.timesteps[state.step_index - 2], + state.timesteps[state.step_index - 1], + state.timesteps[state.step_index] + ]) + return self.multistep_dpm_solver_third_order_update( + state.model_outputs, + timestep_list, + state.prev_timestep, + state.cur_sample, + ) + + if self.solver_order == 2: + return step_2(state) + elif self.denoise_final: + return jax.lax.cond( + state.lower_order_nums < 2, + step_2, + lambda state: jax.lax.cond( + state.step_index == len(state.timesteps) - 2, + step_2, + step_3, + state, + ), + state, + ) + else: + return jax.lax.cond( + state.lower_order_nums < 2, + step_2, + step_3, + state, + ) + + if self.solver_order == 1: + prev_sample = step_1(state) + elif self.denoise_final: + prev_sample = jax.lax.cond( + state.lower_order_nums < 1, + step_1, + lambda state: jax.lax.cond( + state.step_index == len(state.timesteps) - 1, + step_1, + step_23, + state, + ), + state, + ) + else: + prev_sample = jax.lax.cond( + state.lower_order_nums < 1, + step_1, + step_23, + state, + ) + + state = state.replace( + lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.solver_order), + step_index=(state.step_index + 1), + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxDPMSolverDiscreteSchedulerOutput(prev_sample=prev_sample, state=state) + + def scale_model_input( + self, state: DPMSolverDiscreteSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + state (`DPMSolverDiscreteSchedulerState`): the `FlaxDPMSolverDiscreteScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep + + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 708022d85b31..601ea4ed6b38 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -94,6 +94,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) +class FlaxDPMSolverDiscreteScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] From 92ae94905d9a6544068fa93c23316fc3337c4377 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 14:05:15 +0800 Subject: [PATCH 05/22] change code style --- .../pipeline_flax_stable_diffusion.py | 14 +- .../pipeline_stable_diffusion.py | 8 +- .../scheduling_dpmsolver_discrete.py | 97 +++++++------ .../scheduling_dpmsolver_discrete_flax.py | 127 ++++++++---------- 4 files changed, 129 insertions(+), 117 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index a8ed353580dd..86e373efea72 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -14,7 +14,12 @@ from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...pipeline_flax_utils import FlaxDiffusionPipeline -from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxDPMSolverDiscreteScheduler +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverDiscreteScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) from ...utils import logging from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -43,7 +48,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverDiscreteScheduler`]. + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverDiscreteScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. @@ -57,7 +63,9 @@ def __init__( text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, - scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler], + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler + ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, dtype: jnp.dtype = jnp.float32, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a66587475531..72885e6f9180 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -11,11 +11,11 @@ from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, + DPMSolverDiscreteScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, - DPMSolverDiscreteScheduler, ) from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput @@ -60,7 +60,11 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[ - DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, DPMSolverDiscreteScheduler, ], safety_checker: StableDiffusionSafetyChecker, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index 4387516bd447..b2c51bba12be 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -151,7 +151,9 @@ def __init__( self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None,] * self.solver_order + self.model_outputs = [ + None, + ] * self.solver_order self.lower_order_nums = 0 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): @@ -165,16 +167,20 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) self.timesteps = torch.from_numpy(timesteps).to(device) - self.model_outputs = [None,] * self.solver_order + self.model_outputs = [ + None, + ] * self.solver_order self.lower_order_nums = 0 def convert_model_output( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor ) -> torch.FloatTensor: """ TODO @@ -184,9 +190,11 @@ def convert_model_output( x0_pred = (sample - sigma_t * model_output) / alpha_t if self.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487). + p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487). s = torch.quantile(torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), p, dim=1) - s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[(...,) + (None,)*(x0_pred.ndim - 1)] + s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[ + (...,) + (None,) * (x0_pred.ndim - 1) + ] x0_pred = torch.clamp(x0_pred, -s, s) / s return x0_pred else: @@ -207,15 +215,9 @@ def dpm_solver_first_order_update( sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] h = lambda_t - lambda_s if self.predict_x0: - x_t = ( - (sigma_t / sigma_s) * sample - - (alpha_t * (torch.exp(-h) - 1.)) * model_output - ) + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output else: - x_t = ( - (alpha_t / alpha_s) * sample - - (sigma_t * (torch.exp(h) - 1.)) * model_output - ) + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output return x_t def multistep_dpm_solver_second_order_update( @@ -235,32 +237,32 @@ def multistep_dpm_solver_second_order_update( sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h - D0, D1 = m0, (1. / r0) * (m0 - m1) + D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.predict_x0: - if self.solver_type == 'dpm_solver': + if self.solver_type == "dpm_solver": x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.)) * D1 + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) - elif self.solver_type == 'taylor': + elif self.solver_type == "taylor": x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.)) * D1 + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) else: - if self.solver_type == 'dpm_solver': + if self.solver_type == "dpm_solver": x_t = ( (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.)) * D1 + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) - elif self.solver_type == 'taylor': + elif self.solver_type == "taylor": x_t = ( (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.)) * D0 - - (sigma_t * ((torch.exp(h) - 1.) / h - 1.)) * D1 + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) return x_t @@ -276,28 +278,33 @@ def multistep_dpm_solver_third_order_update( """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 - D1_0, D1_1 = (1. / r0) * (m0 - m1), (1. / r1) * (m1 - m2) + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) if self.predict_x0: x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.)) * D1 - - (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5)) * D2 + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) else: x_t = ( (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.)) * D0 - - (sigma_t * ((torch.exp(h) - 1.) / h - 1.)) * D1 - - (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5)) * D2 + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t @@ -336,7 +343,7 @@ def step( denoise_final = (step_index == len(self.timesteps) - 1) and self.denoise_final denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output @@ -345,10 +352,14 @@ def step( prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) elif self.solver_order == 2 or self.lower_order_nums < 2 or denoise_second: timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, timestep_list, prev_timestep, sample) + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, timestep_list, prev_timestep, sample) + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) if self.lower_order_nums < self.solver_order: self.lower_order_nums += 1 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py index 9c1bc68d9acc..d546d2ec0186 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import flax import jax @@ -137,9 +137,7 @@ def __init__( self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = ( - jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 - ) + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -170,7 +168,9 @@ def __init__( def create_state(self): return DPMSolverDiscreteSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps(self, state: DPMSolverDiscreteSchedulerState, num_inference_steps: int, shape: Tuple) -> DPMSolverDiscreteSchedulerState: + def set_timesteps( + self, state: DPMSolverDiscreteSchedulerState, num_inference_steps: int, shape: Tuple + ) -> DPMSolverDiscreteSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -182,7 +182,11 @@ def set_timesteps(self, state: DPMSolverDiscreteSchedulerState, num_inference_st shape (`Tuple`): the shape of the samples to be generated. """ - timesteps = jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32) + timesteps = ( + jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .astype(jnp.int32) + ) return state.replace( num_inference_steps=num_inference_steps, @@ -217,11 +221,7 @@ def convert_model_output( return model_output def dpm_solver_first_order_update( - self, - model_output: jnp.ndarray, - timestep: int, - prev_timestep: int, - sample: jnp.ndarray + self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray ) -> jnp.ndarray: """ TODO @@ -233,15 +233,9 @@ def dpm_solver_first_order_update( sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s if self.predict_x0: - x_t = ( - (sigma_t / sigma_s) * sample - - (alpha_t * (jnp.exp(-h) - 1.)) * m0 - ) + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 else: - x_t = ( - (alpha_t / alpha_s) * sample - - (sigma_t * (jnp.exp(h) - 1.)) * m0 - ) + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 return x_t def multistep_dpm_solver_second_order_update( @@ -261,32 +255,32 @@ def multistep_dpm_solver_second_order_update( sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h - D0, D1 = m0, (1. / r0) * (m0 - m1) + D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.predict_x0: - if self.solver_type == 'dpm_solver': + if self.solver_type == "dpm_solver": x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (jnp.exp(-h) - 1.)) * D0 - - 0.5 * (alpha_t * (jnp.exp(-h) - 1.)) * D1 + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 ) - elif self.solver_type == 'taylor': + elif self.solver_type == "taylor": x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (jnp.exp(-h) - 1.)) * D0 - + (alpha_t * ((jnp.exp(-h) - 1.) / h + 1.)) * D1 + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 ) else: - if self.solver_type == 'dpm_solver': + if self.solver_type == "dpm_solver": x_t = ( (alpha_t / alpha_s0) * sample - - (sigma_t * (jnp.exp(h) - 1.)) * D0 - - 0.5 * (sigma_t * (jnp.exp(h) - 1.)) * D1 + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 ) - elif self.solver_type == 'taylor': + elif self.solver_type == "taylor": x_t = ( (alpha_t / alpha_s0) * sample - - (sigma_t * (jnp.exp(h) - 1.)) * D0 - - (sigma_t * ((jnp.exp(h) - 1.) / h - 1.)) * D1 + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 ) return x_t @@ -302,28 +296,33 @@ def multistep_dpm_solver_third_order_update( """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 - D1_0, D1_1 = (1. / r0) * (m0 - m1), (1. / r1) * (m1 - m2) + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) if self.predict_x0: x_t = ( (sigma_t / sigma_s0) * sample - - (alpha_t * (jnp.exp(-h) - 1.)) * D0 - + (alpha_t * ((jnp.exp(-h) - 1.) / h + 1.)) * D1 - - (alpha_t * ((jnp.exp(-h) - 1. + h) / h**2 - 0.5)) * D2 + - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) else: x_t = ( (alpha_t / alpha_s0) * sample - - (sigma_t * (jnp.exp(h) - 1.)) * D0 - - (sigma_t * ((jnp.exp(h) - 1.) / h - 1.)) * D1 - - (sigma_t * ((jnp.exp(h) - 1. - h) / h**2 - 0.5)) * D2 + - (sigma_t * (jnp.exp(h) - 1.0)) * D0 + - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t @@ -336,8 +335,8 @@ def step( return_dict: bool = True, ) -> Union[FlaxDPMSolverDiscreteSchedulerOutput, Tuple]: """ - Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). + Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process + from the learned model outputs (most often the predicted noise). Args: state (`DPMSolverDiscreteSchedulerState`): the `FlaxDPMSolverDiscreteScheduler` state data class instance. @@ -348,8 +347,8 @@ def step( return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverDiscreteSchedulerOutput class Returns: - [`FlaxDPMSolverDiscreteSchedulerOutput`] or `tuple`: [`FlaxDPMSolverDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. + [`FlaxDPMSolverDiscreteSchedulerOutput`] or `tuple`: [`FlaxDPMSolverDiscreteSchedulerOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ prev_timestep = jax.lax.cond( @@ -359,7 +358,7 @@ def step( (), ) - model_output = self.convert_model_output(model_output, timestep, sample) + model_output = self.convert_model_output(model_output, timestep, sample) model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) model_outputs_new = model_outputs_new.at[-1].set(model_output) @@ -369,9 +368,7 @@ def step( cur_sample=sample, ) - def step_1( - state: DPMSolverDiscreteSchedulerState - ) -> jnp.ndarray: + def step_1(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: return self.dpm_solver_first_order_update( state.model_outputs[-1], state.timesteps[state.step_index], @@ -379,17 +376,9 @@ def step_1( state.cur_sample, ) - def step_23( - state: DPMSolverDiscreteSchedulerState - ) -> jnp.ndarray: - - def step_2( - state: DPMSolverDiscreteSchedulerState - ) -> jnp.ndarray: - timestep_list = jnp.array([ - state.timesteps[state.step_index - 1], - state.timesteps[state.step_index] - ]) + def step_23(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: + def step_2(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]]) return self.multistep_dpm_solver_second_order_update( state.model_outputs, timestep_list, @@ -397,14 +386,14 @@ def step_2( state.cur_sample, ) - def step_3( - state: DPMSolverDiscreteSchedulerState - ) -> jnp.ndarray: - timestep_list = jnp.array([ - state.timesteps[state.step_index - 2], - state.timesteps[state.step_index - 1], - state.timesteps[state.step_index] - ]) + def step_3(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: + timestep_list = jnp.array( + [ + state.timesteps[state.step_index - 2], + state.timesteps[state.step_index - 1], + state.timesteps[state.step_index], + ] + ) return self.multistep_dpm_solver_third_order_update( state.model_outputs, timestep_list, @@ -454,7 +443,7 @@ def step_3( step_1, step_23, state, - ) + ) state = state.replace( lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.solver_order), From c06d554aeec4a43e3defe67aa1d73d1a6973f445 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 14:06:07 +0800 Subject: [PATCH 06/22] change code style --- src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py index d546d2ec0186..5e8c7ef40418 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: From dc3d496f200a333b412ab562630a505a7c78aa2b Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 15:21:06 +0800 Subject: [PATCH 07/22] add docs --- .../scheduling_dpmsolver_discrete.py | 98 +++++++++++++++---- .../scheduling_dpmsolver_discrete_flax.py | 93 +++++++++++++++--- 2 files changed, 155 insertions(+), 36 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index b2c51bba12be..2bd3e449413d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -55,15 +55,24 @@ def alpha_bar(time_step): class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - DPM-Solver. + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note + that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`~ConfigMixin.from_config`] functions. - For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 - Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. @@ -73,17 +82,26 @@ class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - skip_prk_steps (`bool`): - allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required - before plms steps; defaults to `False`. - set_alpha_to_one (`bool`, default `False`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + predict_x0 (`bool`, default `True`): + DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927) + with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with + `predict_x0=True`. We recommend to use `predict_x0=True` and `solver_order=2` for guided sampling (e.g. + stable-diffusion). + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the + dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models + (such as stable-diffusion). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + solver_type (`str`, default `dpm_solver`): + the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly + affects the sample quality, especially for small number of steps. + denoise_final (`bool`, default `False`): + whether to use lower-order solvers in the final steps. """ @@ -183,7 +201,16 @@ def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor ) -> torch.FloatTensor: """ - TODO + Convert the noise prediction model to either the noise or the data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. """ if self.predict_x0: alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] @@ -208,7 +235,17 @@ def dpm_solver_first_order_update( sample: torch.FloatTensor, ) -> torch.FloatTensor: """ - TODO + One step for the first-order DPM-Solver (equivalent to DDIM). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. """ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] @@ -228,7 +265,18 @@ def multistep_dpm_solver_second_order_update( sample: torch.FloatTensor, ) -> torch.FloatTensor: """ - TODO + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. """ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] @@ -274,7 +322,18 @@ def multistep_dpm_solver_third_order_update( sample: torch.FloatTensor, ) -> torch.FloatTensor: """ - TODO + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] @@ -316,8 +375,7 @@ def step( return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ - Step function propagating the sample with the multistep DPM-Solver. This has one forward pass with multiple - times to approximate the solution. + Step function propagating the sample with the multistep DPM-Solver. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py index 5e8c7ef40418..aec0fb38e88c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py @@ -80,7 +80,18 @@ class FlaxDPMSolverDiscreteSchedulerOutput(FlaxSchedulerOutput): class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): """ - DPM-Solver. + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note + that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. @@ -98,17 +109,26 @@ class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - skip_prk_steps (`bool`): - allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required - before plms steps; defaults to `False`. - set_alpha_to_one (`bool`, default `False`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + predict_x0 (`bool`, default `True`): + DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927) + with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with + `predict_x0=True`. We recommend to use `predict_x0=True` and `solver_order=2` for guided sampling (e.g. + stable-diffusion). + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the + dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models + (such as stable-diffusion). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + solver_type (`str`, default `dpm_solver`): + the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly + affects the sample quality, especially for small number of steps. + denoise_final (`bool`, default `False`): + whether to use lower-order solvers in the final steps. """ @@ -205,7 +225,16 @@ def convert_model_output( sample: jnp.ndarray, ) -> jnp.ndarray: """ - TODO + Convert the noise prediction model to either the noise or the data prediction model. + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the converted model output. """ if self.predict_x0: alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] @@ -224,7 +253,17 @@ def dpm_solver_first_order_update( self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray ) -> jnp.ndarray: """ - TODO + One step for the first-order DPM-Solver (equivalent to DDIM). + + Args: + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. """ t, s0 = prev_timestep, timestep m0 = model_output @@ -246,7 +285,18 @@ def multistep_dpm_solver_second_order_update( sample: jnp.ndarray, ) -> jnp.ndarray: """ - TODO + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. """ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] @@ -292,7 +342,18 @@ def multistep_dpm_solver_third_order_update( sample: jnp.ndarray, ) -> jnp.ndarray: """ - TODO + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[jnp.ndarray]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + + Returns: + `jnp.ndarray`: the sample tensor at the previous timestep. """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] From 9fb0acbd46c7ac41211d8276e9e670cca023df62 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 16:07:38 +0800 Subject: [PATCH 08/22] add `add_noise` method for dpmsolver --- .../scheduling_dpmsolver_discrete.py | 23 +++++++++++++++++++ .../scheduling_dpmsolver_discrete_flax.py | 19 ++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index 2bd3e449413d..c5df7171f9cd 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -440,5 +440,28 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py index aec0fb38e88c..322aacc0ed5c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py @@ -23,7 +23,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput +from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -533,5 +533,22 @@ def scale_model_input( """ return sample + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + def __len__(self): return self.config.num_train_timesteps From d657d4390b857be4c86f69efbcd530d95dc63cff Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 18:53:14 +0800 Subject: [PATCH 09/22] add pytorch unit test for dpmsolver --- .../scheduling_dpmsolver_discrete.py | 6 +- tests/test_config.py | 21 ++- tests/test_scheduler.py | 168 ++++++++++++++++++ 3 files changed, 193 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index c5df7171f9cd..dc7ea2c5a474 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -396,7 +396,11 @@ def step( if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] denoise_final = (step_index == len(self.timesteps) - 1) and self.denoise_final denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final diff --git a/tests/test_config.py b/tests/test_config.py index 7a9f270af364..83e1879df49a 100755 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,7 +19,14 @@ import unittest import diffusers -from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging +from diffusers import ( + DDIMScheduler, + DPMSolverDiscreteScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + PNDMScheduler, + logging, +) from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils.testing_utils import CaptureLogger @@ -283,3 +290,15 @@ def test_load_pndm(self): assert pndm.__class__ == PNDMScheduler # no warning should be thrown assert cap_logger.out == "" + + def test_load_dpmsolver(self): + logger = logging.get_logger("diffusers.configuration_utils") + + with CaptureLogger(logger) as cap_logger: + pndm = DPMSolverDiscreteScheduler.from_config( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) + + assert pndm.__class__ == DPMSolverDiscreteScheduler + # no warning should be thrown + assert cap_logger.out == "" diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 29186aaac99b..4b2d0c3b3074 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -24,6 +24,7 @@ from diffusers import ( DDIMScheduler, DDPMScheduler, + DPMSolverDiscreteScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, @@ -549,6 +550,173 @@ def test_full_loop_with_no_set_alpha_to_one(self): assert abs(result_mean.item() - 0.1941) < 1e-3 +class DPMSolverDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (DPMSolverDiscreteScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + "predict_x0": True, + "thresholding": False, + "sample_max_value": 1.0, + "solver_type": "dpm_solver", + "denoise_final": False, + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order] + + output, new_output = sample, sample + for t in range(time_step, time_step + scheduler.solver_order + 1): + output = scheduler.step(residual, t, output, **kwargs).prev_sample + new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + pass + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] + scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order] + + time_step_0 = scheduler.timesteps[5] + time_step_1 = scheduler.timesteps[6] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["dpm_solver", "taylor"]: + for threshold in [0.5, 1.0, 2.0]: + self.check_over_configs( + thresholding=True, + sample_max_value=threshold, + predict_x0=True, + solver_order=order, + solver_type=solver_type, + ) + + def test_solver_order_and_type(self): + for solver_type in ["dpm_solver", "taylor"]: + for order in [1, 2, 3]: + for predict_x0 in [True, False]: + self.check_over_configs(solver_order=order, solver_type=solver_type, predict_x0=predict_x0) + sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0) + assert not torch.isnan(sample).any(), "Samples have nan numbers" + + def test_denoise_final(self): + self.check_over_configs(denoise_final=True) + self.check_over_configs(denoise_final=False) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.3301) < 1e-3 + + class PNDMSchedulerTest(SchedulerCommonTest): scheduler_classes = (PNDMScheduler,) forward_default_kwargs = (("num_inference_steps", 50),) From 00e86324bf47e7beef643f7266ecad8002f21349 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Fri, 4 Nov 2022 20:37:27 +0800 Subject: [PATCH 10/22] add dummy object for pytorch dpmsolver --- .../dummy_torch_and_accelerate_objects.py | 467 ++++++++++++++++++ 1 file changed, 467 insertions(+) create mode 100644 src/diffusers/utils/dummy_torch_and_accelerate_objects.py diff --git a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py new file mode 100644 index 000000000000..22a9ff8d34cb --- /dev/null +++ b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py @@ -0,0 +1,467 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class ModelMixin(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class AutoencoderKL(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class Transformer2DModel(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class UNet1DModel(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class UNet2DModel(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class VQModel(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +def get_constant_schedule(*args, **kwargs): + requires_backends(get_constant_schedule, ["torch", "accelerate"]) + + +def get_constant_schedule_with_warmup(*args, **kwargs): + requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"]) + + +def get_cosine_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"]) + + +def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"]) + + +def get_linear_schedule_with_warmup(*args, **kwargs): + requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"]) + + +def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): + requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"]) + + +def get_scheduler(*args, **kwargs): + requires_backends(get_scheduler, ["torch", "accelerate"]) + + +class DiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class DanceDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class DDIMPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class DDPMPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class KarrasVePipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class LDMPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class PNDMPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class RePaintPipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class ScoreSdeVePipeline(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class DDIMScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class DDPMScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class DPMSolverDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class EulerAncestralDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class EulerDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class IPNDMScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class KarrasVeScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class PNDMScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class RePaintScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class SchedulerMixin(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class ScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class VQDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + +class EMAModel(metaclass=DummyObject): + _backends = ["torch", "accelerate"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "accelerate"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "accelerate"]) From 6a1f83424b3fc01443d4e39335fc7dc7d4ede429 Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Sat, 5 Nov 2022 12:00:31 +0800 Subject: [PATCH 11/22] Update src/diffusers/schedulers/scheduling_dpmsolver_discrete.py Co-authored-by: Suraj Patil --- src/diffusers/schedulers/scheduling_dpmsolver_discrete.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index dc7ea2c5a474..7de53a53c74a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -169,9 +169,7 @@ def __init__( self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [ - None, - ] * self.solver_order + self.model_outputs = [None] * self.solver_order self.lower_order_nums = 0 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): From 4843fc34e5392692a37d24b834d0d374516fb523 Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Sat, 5 Nov 2022 12:14:35 +0800 Subject: [PATCH 12/22] Update tests/test_config.py Co-authored-by: Suraj Patil --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 83e1879df49a..a19ab9261b8a 100755 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -295,7 +295,7 @@ def test_load_dpmsolver(self): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - pndm = DPMSolverDiscreteScheduler.from_config( + dpm = DPMSolverDiscreteScheduler.from_config( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) From 31ed11005e8c4c0a5dd62c78a0e22f238be4f818 Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Sat, 5 Nov 2022 12:14:47 +0800 Subject: [PATCH 13/22] Update tests/test_config.py Co-authored-by: Suraj Patil --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index a19ab9261b8a..51bb7ca8c5c2 100755 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -299,6 +299,6 @@ def test_load_dpmsolver(self): "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) - assert pndm.__class__ == DPMSolverDiscreteScheduler + assert dpm.__class__ == DPMSolverDiscreteScheduler # no warning should be thrown assert cap_logger.out == "" From 864d0bb1364a84d04c4b634e3c1a6a21c8366cb4 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sat, 5 Nov 2022 12:44:59 +0800 Subject: [PATCH 14/22] resolve the code comments --- .../scheduling_dpmsolver_discrete.py | 71 ++++++++++--------- .../scheduling_dpmsolver_discrete_flax.py | 66 +++++++++-------- tests/test_scheduler.py | 12 ++-- 3 files changed, 80 insertions(+), 69 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py index 7de53a53c74a..99ba5864e1e6 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py @@ -95,12 +95,15 @@ class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin): For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. solver_type (`str`, default `dpm_solver`): the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly affects the sample quality, especially for small number of steps. - denoise_final (`bool`, default `False`): + denoise_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. """ @@ -125,9 +128,10 @@ def __init__( solver_order: int = 2, predict_x0: bool = True, thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, solver_type: str = "dpm_solver", - denoise_final: bool = False, + denoise_final: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -155,21 +159,14 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - self.solver_order = solver_order - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.sample_max_value = sample_max_value - self.denoise_final = denoise_final - if solver_type in ["dpm_solver", "taylor"]: - self.solver_type = solver_type - else: + if solver_type not in ["dpm_solver", "taylor"]: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * self.solver_order + self.model_outputs = [None] * solver_order self.lower_order_nums = 0 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): @@ -192,7 +189,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [ None, - ] * self.solver_order + ] * self.config.solver_order self.lower_order_nums = 0 def convert_model_output( @@ -210,17 +207,19 @@ def convert_model_output( Returns: `torch.FloatTensor`: the converted model output. """ - if self.predict_x0: + if self.config.predict_x0: alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t - if self.thresholding: + if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487). - s = torch.quantile(torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), p, dim=1) - s = torch.maximum(s, self.sample_max_value * torch.ones_like(s).to(s.device))[ - (...,) + (None,) * (x0_pred.ndim - 1) - ] - x0_pred = torch.clamp(x0_pred, -s, s) / s + dynamic_max_val = torch.quantile( + torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 + ) + dynamic_max_val = torch.maximum( + dynamic_max_val, + self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), + )[(...,) + (None,) * (x0_pred.ndim - 1)] + x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred else: return model_output @@ -235,6 +234,8 @@ def dpm_solver_first_order_update( """ One step for the first-order DPM-Solver (equivalent to DDIM). + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. @@ -249,7 +250,7 @@ def dpm_solver_first_order_update( alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] h = lambda_t - lambda_s - if self.predict_x0: + if self.config.predict_x0: x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output else: x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output @@ -284,27 +285,29 @@ def multistep_dpm_solver_second_order_update( h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.predict_x0: - if self.solver_type == "dpm_solver": + if self.config.predict_x0: + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "dpm_solver": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) - elif self.solver_type == "taylor": + elif self.config.solver_type == "taylor": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) else: - if self.solver_type == "dpm_solver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "dpm_solver": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) - elif self.solver_type == "taylor": + elif self.config.solver_type == "taylor": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 @@ -349,7 +352,8 @@ def multistep_dpm_solver_third_order_update( D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.predict_x0: + if self.config.predict_x0: + # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 @@ -357,6 +361,7 @@ def multistep_dpm_solver_third_order_update( - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) else: + # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 @@ -400,17 +405,17 @@ def step( else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - denoise_final = (step_index == len(self.timesteps) - 1) and self.denoise_final - denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final + denoise_final = (step_index == len(self.timesteps) - 1) and self.config.denoise_final + denoise_second = (step_index == len(self.timesteps) - 2) and self.config.denoise_final model_output = self.convert_model_output(model_output, timestep, sample) - for i in range(self.solver_order - 1): + for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - if self.solver_order == 1 or self.lower_order_nums < 1 or denoise_final: + if self.config.solver_order == 1 or self.lower_order_nums < 1 or denoise_final: prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) - elif self.solver_order == 2 or self.lower_order_nums < 2 or denoise_second: + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or denoise_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample @@ -421,7 +426,7 @@ def step( self.model_outputs, timestep_list, prev_timestep, sample ) - if self.lower_order_nums < self.solver_order: + if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 if not return_dict: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py index 322aacc0ed5c..90110c99e6cd 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py @@ -122,12 +122,15 @@ class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. solver_type (`str`, default `dpm_solver`): the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly affects the sample quality, especially for small number of steps. - denoise_final (`bool`, default `False`): + denoise_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. """ @@ -147,9 +150,10 @@ def __init__( solver_order: int = 2, predict_x0: bool = True, thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, solver_type: str = "dpm_solver", - denoise_final: bool = False, + denoise_final: bool = True, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -175,14 +179,7 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - self.solver_order = solver_order - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.sample_max_value = sample_max_value - self.denoise_final = denoise_final - if solver_type in ["dpm_solver", "taylor"]: - self.solver_type = solver_type - else: + if solver_type not in ["dpm_solver", "taylor"]: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") def create_state(self): @@ -211,7 +208,7 @@ def set_timesteps( return state.replace( num_inference_steps=num_inference_steps, timesteps=timesteps, - model_outputs=jnp.zeros((self.solver_order,) + shape), + model_outputs=jnp.zeros((self.config.solver_order,) + shape), lower_order_nums=0, step_index=0, prev_timestep=-1, @@ -236,15 +233,18 @@ def convert_model_output( Returns: `jnp.ndarray`: the converted model output. """ - if self.predict_x0: + if self.config.predict_x0: alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t - if self.thresholding: - # A hyperparameter in the paper of Imagen (https://arxiv.org/abs/2205.11487). - p = 0.995 - s = jnp.percentile(jnp.abs(x0_pred), p, axis=tuple(range(1, x0_pred.ndim))) - s = jnp.max(s, self.max_val) - x0_pred = jnp.clip(x0_pred, -s, s) / s + if self.config.thresholding: + # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + dynamic_max_val = jnp.percentile( + jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + ) + dynamic_max_val = jnp.maximum( + dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + ) + x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred else: return model_output @@ -255,6 +255,8 @@ def dpm_solver_first_order_update( """ One step for the first-order DPM-Solver (equivalent to DDIM). + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + Args: model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. @@ -271,7 +273,7 @@ def dpm_solver_first_order_update( alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s - if self.predict_x0: + if self.config.predict_x0: x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 else: x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 @@ -306,27 +308,29 @@ def multistep_dpm_solver_second_order_update( h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.predict_x0: - if self.solver_type == "dpm_solver": + if self.config.predict_x0: + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "dpm_solver": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 ) - elif self.solver_type == "taylor": + elif self.config.solver_type == "taylor": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 ) else: - if self.solver_type == "dpm_solver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "dpm_solver": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 ) - elif self.solver_type == "taylor": + elif self.config.solver_type == "taylor": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 @@ -371,7 +375,8 @@ def multistep_dpm_solver_third_order_update( D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.predict_x0: + if self.config.predict_x0: + # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 @@ -379,6 +384,7 @@ def multistep_dpm_solver_third_order_update( - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) else: + # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 @@ -462,9 +468,9 @@ def step_3(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: state.cur_sample, ) - if self.solver_order == 2: + if self.config.solver_order == 2: return step_2(state) - elif self.denoise_final: + elif self.config.denoise_final: return jax.lax.cond( state.lower_order_nums < 2, step_2, @@ -484,9 +490,9 @@ def step_3(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: state, ) - if self.solver_order == 1: + if self.config.solver_order == 1: prev_sample = step_1(state) - elif self.denoise_final: + elif self.config.denoise_final: prev_sample = jax.lax.cond( state.lower_order_nums < 1, step_1, @@ -507,7 +513,7 @@ def step_3(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: ) state = state.replace( - lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.solver_order), + lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), step_index=(state.step_index + 1), ) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 4b2d0c3b3074..98ded72eb162 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -583,17 +583,17 @@ def check_over_configs(self, time_step=0, **config): scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals - scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals - new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order] + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] output, new_output = sample, sample - for t in range(time_step, time_step + scheduler.solver_order + 1): + for t in range(time_step, time_step + scheduler.config.solver_order + 1): output = scheduler.step(residual, t, output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample @@ -615,7 +615,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals (must be after setting timesteps) - scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) @@ -624,7 +624,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residual (must be after setting timesteps) - new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order] + new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample @@ -666,7 +666,7 @@ def test_step_shape(self): # copy over dummy past residuals (must be done after set_timesteps) dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10] - scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order] + scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order] time_step_0 = scheduler.timesteps[5] time_step_1 = scheduler.timesteps[6] From e9f0fbcb606ae2d7db112d04a554238fb156224a Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sat, 5 Nov 2022 14:08:21 +0800 Subject: [PATCH 15/22] rename the file --- src/diffusers/schedulers/__init__.py | 4 ++-- ...pmsolver_discrete.py => scheduling_dpmsolver_multistep.py} | 0 ...iscrete_flax.py => scheduling_dpmsolver_multistep_flax.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename src/diffusers/schedulers/{scheduling_dpmsolver_discrete.py => scheduling_dpmsolver_multistep.py} (100%) rename src/diffusers/schedulers/{scheduling_dpmsolver_discrete_flax.py => scheduling_dpmsolver_multistep_flax.py} (100%) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 8e095d384fca..3bbd16e3c058 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,7 +19,7 @@ if is_torch_available(): from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler - from .scheduling_dpmsolver_discrete import DPMSolverDiscreteScheduler + from .scheduling_dpmsolver_multistep import DPMSolverDiscreteScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler @@ -36,7 +36,7 @@ if is_flax_available(): from .scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler - from .scheduling_dpmsolver_discrete_flax import FlaxDPMSolverDiscreteScheduler + from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverDiscreteScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py similarity index 100% rename from src/diffusers/schedulers/scheduling_dpmsolver_discrete.py rename to src/diffusers/schedulers/scheduling_dpmsolver_multistep.py diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py similarity index 100% rename from src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py rename to src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py From bc2afd5f05f199e659e21c60c1c360f931832a75 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sat, 5 Nov 2022 14:11:20 +0800 Subject: [PATCH 16/22] change class name --- src/diffusers/__init__.py | 4 +- .../pipeline_flax_stable_diffusion.py | 6 +-- .../pipeline_stable_diffusion.py | 4 +- src/diffusers/schedulers/__init__.py | 4 +- .../scheduling_dpmsolver_multistep.py | 2 +- .../scheduling_dpmsolver_multistep_flax.py | 42 +++++++++---------- src/diffusers/utils/dummy_flax_objects.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- .../dummy_torch_and_accelerate_objects.py | 2 +- tests/test_config.py | 6 +-- tests/test_scheduler.py | 6 +-- 11 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index dd03041bb215..da56dc888138 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -42,7 +42,7 @@ from .schedulers import ( DDIMScheduler, DDPMScheduler, - DPMSolverDiscreteScheduler, + DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, @@ -93,7 +93,7 @@ from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, - FlaxDPMSolverDiscreteScheduler, + FlaxDPMSolverMultistepScheduler, FlaxKarrasVeScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 86e373efea72..5a910f8453ce 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -16,7 +16,7 @@ from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import ( FlaxDDIMScheduler, - FlaxDPMSolverDiscreteScheduler, + FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) @@ -49,7 +49,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or - [`FlaxDPMSolverDiscreteScheduler`]. + [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. @@ -64,7 +64,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, scheduler: Union[ - FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 72885e6f9180..094841f9778e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -11,7 +11,7 @@ from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, - DPMSolverDiscreteScheduler, + DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, @@ -65,7 +65,7 @@ def __init__( LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, - DPMSolverDiscreteScheduler, + DPMSolverMultistepScheduler, ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 3bbd16e3c058..6217bfcd6985 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,7 +19,7 @@ if is_torch_available(): from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler - from .scheduling_dpmsolver_multistep import DPMSolverDiscreteScheduler + from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler @@ -36,7 +36,7 @@ if is_flax_available(): from .scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler - from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverDiscreteScheduler + from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 99ba5864e1e6..b28e842d753a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -53,7 +53,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class DPMSolverDiscreteScheduler(SchedulerMixin, ConfigMixin): +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 90110c99e6cd..fa933718825a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -56,7 +56,7 @@ def alpha_bar(time_step): @flax.struct.dataclass -class DPMSolverDiscreteSchedulerState: +class DPMSolverMultistepSchedulerState: # setable values num_inference_steps: Optional[int] = None timesteps: Optional[jnp.ndarray] = None @@ -74,11 +74,11 @@ def create(cls, num_train_timesteps: int): @dataclass -class FlaxDPMSolverDiscreteSchedulerOutput(FlaxSchedulerOutput): - state: DPMSolverDiscreteSchedulerState +class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): + state: DPMSolverMultistepSchedulerState -class FlaxDPMSolverDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): +class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality @@ -183,17 +183,17 @@ def __init__( raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") def create_state(self): - return DPMSolverDiscreteSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) + return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) def set_timesteps( - self, state: DPMSolverDiscreteSchedulerState, num_inference_steps: int, shape: Tuple - ) -> DPMSolverDiscreteSchedulerState: + self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + ) -> DPMSolverMultistepSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: - state (`DPMSolverDiscreteSchedulerState`): - the `FlaxDPMSolverDiscreteScheduler` state data class instance. + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. shape (`Tuple`): @@ -395,26 +395,26 @@ def multistep_dpm_solver_third_order_update( def step( self, - state: DPMSolverDiscreteSchedulerState, + state: DPMSolverMultistepSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, return_dict: bool = True, - ) -> Union[FlaxDPMSolverDiscreteSchedulerOutput, Tuple]: + ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: - state (`DPMSolverDiscreteSchedulerState`): the `FlaxDPMSolverDiscreteScheduler` state data class instance. + state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverDiscreteSchedulerOutput class + return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class Returns: - [`FlaxDPMSolverDiscreteSchedulerOutput`] or `tuple`: [`FlaxDPMSolverDiscreteSchedulerOutput`] if + [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -435,7 +435,7 @@ def step( cur_sample=sample, ) - def step_1(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: + def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: return self.dpm_solver_first_order_update( state.model_outputs[-1], state.timesteps[state.step_index], @@ -443,8 +443,8 @@ def step_1(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: state.cur_sample, ) - def step_23(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: - def step_2(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: + def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: + def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]]) return self.multistep_dpm_solver_second_order_update( state.model_outputs, @@ -453,7 +453,7 @@ def step_2(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: state.cur_sample, ) - def step_3(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: + def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: timestep_list = jnp.array( [ state.timesteps[state.step_index - 2], @@ -520,17 +520,17 @@ def step_3(state: DPMSolverDiscreteSchedulerState) -> jnp.ndarray: if not return_dict: return (prev_sample, state) - return FlaxDPMSolverDiscreteSchedulerOutput(prev_sample=prev_sample, state=state) + return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) def scale_model_input( - self, state: DPMSolverDiscreteSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: - state (`DPMSolverDiscreteSchedulerState`): the `FlaxDPMSolverDiscreteScheduler` state data class instance. + state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. sample (`jnp.ndarray`): input sample timestep (`int`, optional): current timestep diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 601ea4ed6b38..8e308bb41bea 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -94,7 +94,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) -class FlaxDPMSolverDiscreteScheduler(metaclass=DummyObject): +class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b7b910281cd4..9d296d29977d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -302,7 +302,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DPMSolverDiscreteScheduler(metaclass=DummyObject): +class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py index 22a9ff8d34cb..c05abb7bc762 100644 --- a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py +++ b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py @@ -302,7 +302,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "accelerate"]) -class DPMSolverDiscreteScheduler(metaclass=DummyObject): +class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch", "accelerate"] def __init__(self, *args, **kwargs): diff --git a/tests/test_config.py b/tests/test_config.py index 51bb7ca8c5c2..5084769def41 100755 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -21,7 +21,7 @@ import diffusers from diffusers import ( DDIMScheduler, - DPMSolverDiscreteScheduler, + DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, @@ -295,10 +295,10 @@ def test_load_dpmsolver(self): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - dpm = DPMSolverDiscreteScheduler.from_config( + dpm = DPMSolverMultistepScheduler.from_config( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) - assert dpm.__class__ == DPMSolverDiscreteScheduler + assert dpm.__class__ == DPMSolverMultistepScheduler # no warning should be thrown assert cap_logger.out == "" diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 98ded72eb162..ecd4b6c4f94b 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -24,7 +24,7 @@ from diffusers import ( DDIMScheduler, DDPMScheduler, - DPMSolverDiscreteScheduler, + DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, @@ -550,8 +550,8 @@ def test_full_loop_with_no_set_alpha_to_one(self): assert abs(result_mean.item() - 0.1941) < 1e-3 -class DPMSolverDiscreteSchedulerTest(SchedulerCommonTest): - scheduler_classes = (DPMSolverDiscreteScheduler,) +class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): + scheduler_classes = (DPMSolverMultistepScheduler,) forward_default_kwargs = (("num_inference_steps", 25),) def get_scheduler_config(self, **kwargs): From 7c7c2ec3bdecca26e11ac0a301359308fe6e14b3 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sat, 5 Nov 2022 14:13:25 +0800 Subject: [PATCH 17/22] fix code style --- .../schedulers/scheduling_dpmsolver_multistep_flax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index fa933718825a..6930ae5a16b8 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -406,7 +406,8 @@ def step( from the learned model outputs (most often the predicted noise). Args: - state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): @@ -530,7 +531,8 @@ def scale_model_input( current timestep. Args: - state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. + state (`DPMSolverMultistepSchedulerState`): + the `FlaxDPMSolverMultistepScheduler` state data class instance. sample (`jnp.ndarray`): input sample timestep (`int`, optional): current timestep From a6efda1c63a773f3116e73f88a87c020f6b163f2 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sat, 5 Nov 2022 14:22:06 +0800 Subject: [PATCH 18/22] add auto docs for dpmsolver multistep --- docs/source/api/schedulers.mdx | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index f073f6b37912..12575a5ecae2 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502). [[autodoc]] DDPMScheduler +#### Multistep DPM-Solver + +Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver). + +[[autodoc]] DPMSolverMultistepScheduler + #### Variance exploding, stochastic sampling from Karras et. al Original paper can be found [here](https://arxiv.org/abs/2006.11239). From 5566a2b7e1fae6174d398a83eadbf90e80804760 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sun, 6 Nov 2022 18:52:38 +0800 Subject: [PATCH 19/22] add more explanations for the stabilizing trick (for steps < 15) --- .../scheduling_dpmsolver_multistep.py | 19 ++++++++++++------- .../scheduling_dpmsolver_multistep_flax.py | 11 ++++++----- tests/test_scheduler.py | 8 ++++---- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index b28e842d753a..fb15cef7e92a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -103,8 +103,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): solver_type (`str`, default `dpm_solver`): the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly affects the sample quality, especially for small number of steps. - denoise_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. """ @@ -131,7 +132,7 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, solver_type: str = "dpm_solver", - denoise_final: bool = True, + lower_order_final: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -405,17 +406,21 @@ def step( else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - denoise_final = (step_index == len(self.timesteps) - 1) and self.config.denoise_final - denoise_second = (step_index == len(self.timesteps) - 2) and self.config.denoise_final + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output - if self.config.solver_order == 1 or self.lower_order_nums < 1 or denoise_final: + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or denoise_second: + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 6930ae5a16b8..4afb0a01f5e7 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -130,8 +130,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): solver_type (`str`, default `dpm_solver`): the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly affects the sample quality, especially for small number of steps. - denoise_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. """ @@ -153,7 +154,7 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, solver_type: str = "dpm_solver", - denoise_final: bool = True, + lower_order_final: bool = True, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -471,7 +472,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: if self.config.solver_order == 2: return step_2(state) - elif self.config.denoise_final: + elif self.config.lower_order_final: return jax.lax.cond( state.lower_order_nums < 2, step_2, @@ -493,7 +494,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: if self.config.solver_order == 1: prev_sample = step_1(state) - elif self.config.denoise_final: + elif self.config.lower_order_final: prev_sample = jax.lax.cond( state.lower_order_nums < 1, step_1, diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index ecd4b6c4f94b..ef6938bc1d53 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -565,7 +565,7 @@ def get_scheduler_config(self, **kwargs): "thresholding": False, "sample_max_value": 1.0, "solver_type": "dpm_solver", - "denoise_final": False, + "lower_order_final": False, } config.update(**kwargs) @@ -702,9 +702,9 @@ def test_solver_order_and_type(self): sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0) assert not torch.isnan(sample).any(), "Samples have nan numbers" - def test_denoise_final(self): - self.check_over_configs(denoise_final=True) - self.check_over_configs(denoise_final=False) + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) def test_inference_steps(self): for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: From 3ac6ab4ac94a9bf7091609d7e28b4f9ffa91a75b Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sun, 6 Nov 2022 20:38:00 +0800 Subject: [PATCH 20/22] delete the dummy file --- .../dummy_torch_and_accelerate_objects.py | 467 ------------------ 1 file changed, 467 deletions(-) delete mode 100644 src/diffusers/utils/dummy_torch_and_accelerate_objects.py diff --git a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py deleted file mode 100644 index c05abb7bc762..000000000000 --- a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py +++ /dev/null @@ -1,467 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa - -from ..utils import DummyObject, requires_backends - - -class ModelMixin(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class AutoencoderKL(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class Transformer2DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet1DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet2DConditionModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet2DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class VQModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -def get_constant_schedule(*args, **kwargs): - requires_backends(get_constant_schedule, ["torch", "accelerate"]) - - -def get_constant_schedule_with_warmup(*args, **kwargs): - requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_cosine_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_linear_schedule_with_warmup(*args, **kwargs): - requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): - requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_scheduler(*args, **kwargs): - requires_backends(get_scheduler, ["torch", "accelerate"]) - - -class DiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DanceDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDIMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDPMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class KarrasVePipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class LDMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class PNDMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class RePaintPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class ScoreSdeVePipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDIMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDPMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DPMSolverMultistepScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EulerAncestralDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EulerDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class IPNDMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class KarrasVeScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class PNDMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class RePaintScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class SchedulerMixin(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class ScoreSdeVeScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class VQDiffusionScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EMAModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) From f54cf995f4106e835513073c75e38b26da3eea01 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sun, 6 Nov 2022 23:05:35 +0800 Subject: [PATCH 21/22] change the API name of predict_epsilon, algorithm_type and solver_type --- .../scheduling_dpmsolver_multistep.py | 93 +++++++++++------- .../scheduling_dpmsolver_multistep_flax.py | 97 ++++++++++++------- tests/test_scheduler.py | 46 ++++++--- 3 files changed, 152 insertions(+), 84 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index fb15cef7e92a..d166354809b0 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -65,8 +65,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note - that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. @@ -85,24 +86,30 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - predict_x0 (`bool`, default `True`): - DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927) - with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with - `predict_x0=True`. We recommend to use `predict_x0=True` and `solver_order=2` for guided sampling (e.g. - stable-diffusion). + predict_epsilon (`bool`, default `True`): + we currently support both the noise prediction model and the data prediction model. If the model predicts + the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set + `predict_epsilon` to `False`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the - dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models - (such as stable-diffusion). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - solver_type (`str`, default `dpm_solver`): - the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly - affects the sample quality, especially for small number of steps. + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. @@ -127,11 +134,12 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, - predict_x0: bool = True, + predict_epsilon: bool = True, thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - solver_type: str = "dpm_solver", + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", lower_order_final: bool = True, ): if trained_betas is not None: @@ -160,7 +168,9 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if solver_type not in ["dpm_solver", "taylor"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") # setable values @@ -197,7 +207,14 @@ def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor ) -> torch.FloatTensor: """ - Convert the noise prediction model to either the noise or the data prediction model. + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. @@ -208,9 +225,13 @@ def convert_model_output( Returns: `torch.FloatTensor`: the converted model output. """ - if self.config.predict_x0: - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.predict_epsilon: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + else: + x0_pred = model_output if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 dynamic_max_val = torch.quantile( @@ -222,8 +243,14 @@ def convert_model_output( )[(...,) + (None,) * (x0_pred.ndim - 1)] x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred - else: - return model_output + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.predict_epsilon: + return model_output + else: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon def dpm_solver_first_order_update( self, @@ -251,9 +278,9 @@ def dpm_solver_first_order_update( alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] h = lambda_t - lambda_s - if self.config.predict_x0: + if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - else: + elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output return x_t @@ -286,29 +313,29 @@ def multistep_dpm_solver_second_order_update( h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.predict_x0: + if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "dpm_solver": + if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) - elif self.config.solver_type == "taylor": + elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) - else: + elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "dpm_solver": + if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) - elif self.config.solver_type == "taylor": + elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 @@ -353,7 +380,7 @@ def multistep_dpm_solver_third_order_update( D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.predict_x0: + if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample @@ -361,7 +388,7 @@ def multistep_dpm_solver_third_order_update( + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) - else: + elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (alpha_t / alpha_s0) * sample diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 4afb0a01f5e7..c9a6d1cd5c0b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -90,8 +90,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note - that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. @@ -112,24 +113,30 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - predict_x0 (`bool`, default `True`): - DPM-Solver is designed for both the noise prediction model (DPM-Solver, https://arxiv.org/abs/2206.00927) - with `predict_x0=False` and the data prediction model (DPM-Solver++, https://arxiv.org/abs/2211.01095) with - `predict_x0=True`. We recommend to use `predict_x0=True` and `solver_order=2` for guided sampling (e.g. - stable-diffusion). + predict_epsilon (`bool`, default `True`): + we currently support both the noise prediction model and the data prediction model. If the model predicts + the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set + `predict_epsilon` to `False`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the - dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models - (such as stable-diffusion). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - solver_type (`str`, default `dpm_solver`): - the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly - affects the sample quality, especially for small number of steps. + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the + algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in + https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided + sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. @@ -149,11 +156,12 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, solver_order: int = 2, - predict_x0: bool = True, + predict_epsilon: bool = True, thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - solver_type: str = "dpm_solver", + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", lower_order_final: bool = True, ): if trained_betas is not None: @@ -180,7 +188,9 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if solver_type not in ["dpm_solver", "taylor"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + if solver_type not in ["midpoint", "heun"]: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") def create_state(self): @@ -223,7 +233,14 @@ def convert_model_output( sample: jnp.ndarray, ) -> jnp.ndarray: """ - Convert the noise prediction model to either the noise or the data prediction model. + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. Args: model_output (`jnp.ndarray`): direct output from learned diffusion model. @@ -234,9 +251,13 @@ def convert_model_output( Returns: `jnp.ndarray`: the converted model output. """ - if self.config.predict_x0: - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type == "dpmsolver++": + if self.config.predict_epsilon: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + else: + x0_pred = model_output if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 dynamic_max_val = jnp.percentile( @@ -247,8 +268,14 @@ def convert_model_output( ) x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred - else: - return model_output + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type == "dpmsolver": + if self.config.predict_epsilon: + return model_output + else: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon def dpm_solver_first_order_update( self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray @@ -274,9 +301,9 @@ def dpm_solver_first_order_update( alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s - if self.config.predict_x0: + if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 - else: + elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 return x_t @@ -309,29 +336,29 @@ def multistep_dpm_solver_second_order_update( h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.predict_x0: + if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "dpm_solver": + if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 ) - elif self.config.solver_type == "taylor": + elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 ) - else: + elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "dpm_solver": + if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 ) - elif self.config.solver_type == "taylor": + elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 @@ -376,7 +403,7 @@ def multistep_dpm_solver_third_order_update( D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.predict_x0: + if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample @@ -384,7 +411,7 @@ def multistep_dpm_solver_third_order_update( + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) - else: + elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (alpha_t / alpha_s0) * sample @@ -472,7 +499,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: if self.config.solver_order == 2: return step_2(state) - elif self.config.lower_order_final: + elif self.config.lower_order_final and len(state.timesteps) < 15: return jax.lax.cond( state.lower_order_nums < 2, step_2, @@ -494,7 +521,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: if self.config.solver_order == 1: prev_sample = step_1(state) - elif self.config.lower_order_final: + elif self.config.lower_order_final and len(state.timesteps) < 15: prev_sample = jax.lax.cond( state.lower_order_nums < 1, step_1, diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index ef6938bc1d53..056f723835ba 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -561,10 +561,11 @@ def get_scheduler_config(self, **kwargs): "beta_end": 0.02, "beta_schedule": "linear", "solver_order": 2, - "predict_x0": True, + "predict_epsilon": True, "thresholding": False, "sample_max_value": 1.0, - "solver_type": "dpm_solver", + "algorithm_type": "dpmsolver++", + "solver_type": "midpoint", "lower_order_final": False, } @@ -684,23 +685,36 @@ def test_timesteps(self): def test_thresholding(self): self.check_over_configs(thresholding=False) for order in [1, 2, 3]: - for solver_type in ["dpm_solver", "taylor"]: + for solver_type in ["midpoint", "heun"]: for threshold in [0.5, 1.0, 2.0]: - self.check_over_configs( - thresholding=True, - sample_max_value=threshold, - predict_x0=True, - solver_order=order, - solver_type=solver_type, - ) + for predict_epsilon in [True, False]: + self.check_over_configs( + thresholding=True, + predict_epsilon=predict_epsilon, + sample_max_value=threshold, + algorithm_type="dpmsolver++", + solver_order=order, + solver_type=solver_type, + ) def test_solver_order_and_type(self): - for solver_type in ["dpm_solver", "taylor"]: - for order in [1, 2, 3]: - for predict_x0 in [True, False]: - self.check_over_configs(solver_order=order, solver_type=solver_type, predict_x0=predict_x0) - sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0) - assert not torch.isnan(sample).any(), "Samples have nan numbers" + for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for solver_type in ["midpoint", "heun"]: + for order in [1, 2, 3]: + for predict_epsilon in [True, False]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + predict_epsilon=predict_epsilon, + algorithm_type=algorithm_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + predict_epsilon=predict_epsilon, + algorithm_type=algorithm_type, + ) + assert not torch.isnan(sample).any(), "Samples have nan numbers" def test_lower_order_final(self): self.check_over_configs(lower_order_final=True) From dee238fb418cf9f7906dbcc8d0e941a463834cdf Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sun, 6 Nov 2022 23:17:21 +0800 Subject: [PATCH 22/22] add compatible lists --- src/diffusers/schedulers/scheduling_ddim.py | 1 + src/diffusers/schedulers/scheduling_ddpm.py | 1 + src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py | 1 + src/diffusers/schedulers/scheduling_euler_discrete.py | 1 + src/diffusers/schedulers/scheduling_lms_discrete.py | 1 + src/diffusers/schedulers/scheduling_pndm.py | 1 + 6 files changed, 6 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 62ee9c0244ae..8d4407c16c30 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -115,6 +115,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): "LMSDiscreteScheduler", "EulerDiscreteScheduler", "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", ] @register_to_config diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 114a86b4320e..171c9598eba2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -108,6 +108,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): "LMSDiscreteScheduler", "EulerDiscreteScheduler", "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", ] @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index fe45b3d591f5..7f44067325cf 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -73,6 +73,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): "LMSDiscreteScheduler", "PNDMScheduler", "EulerDiscreteScheduler", + "DPMSolverMultistepScheduler", ] @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 0cb31a451272..50a1bd89f839 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -74,6 +74,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): "LMSDiscreteScheduler", "PNDMScheduler", "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", ] @register_to_config diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 8d633267c607..d636fe6fe87f 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -73,6 +73,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): "PNDMScheduler", "EulerDiscreteScheduler", "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", ] @register_to_config diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 0082ede787b8..eec18af8d382 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -94,6 +94,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): "LMSDiscreteScheduler", "EulerDiscreteScheduler", "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", ] @register_to_config