Skip to content

Commit

Permalink
add more explanations for the stabilizing trick (for steps < 15)
Browse files Browse the repository at this point in the history
  • Loading branch information
LuChengTHU committed Nov 6, 2022
1 parent f65012c commit a497362
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 16 deletions.
19 changes: 12 additions & 7 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_type (`str`, default `dpm_solver`):
the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
affects the sample quality, especially for small number of steps.
denoise_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps.
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
"""

Expand All @@ -131,7 +132,7 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
solver_type: str = "dpm_solver",
denoise_final: bool = True,
lower_order_final: bool = True,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
Expand Down Expand Up @@ -405,17 +406,21 @@ def step(
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
denoise_final = (step_index == len(self.timesteps) - 1) and self.config.denoise_final
denoise_second = (step_index == len(self.timesteps) - 2) and self.config.denoise_final
lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)

model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output

if self.config.solver_order == 1 or self.lower_order_nums < 1 or denoise_final:
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or denoise_second:
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
Expand Down
11 changes: 6 additions & 5 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_type (`str`, default `dpm_solver`):
the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
affects the sample quality, especially for small number of steps.
denoise_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps.
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
"""

Expand All @@ -153,7 +154,7 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
solver_type: str = "dpm_solver",
denoise_final: bool = True,
lower_order_final: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
Expand Down Expand Up @@ -471,7 +472,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:

if self.config.solver_order == 2:
return step_2(state)
elif self.config.denoise_final:
elif self.config.lower_order_final:
return jax.lax.cond(
state.lower_order_nums < 2,
step_2,
Expand All @@ -493,7 +494,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:

if self.config.solver_order == 1:
prev_sample = step_1(state)
elif self.config.denoise_final:
elif self.config.lower_order_final:
prev_sample = jax.lax.cond(
state.lower_order_nums < 1,
step_1,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def get_scheduler_config(self, **kwargs):
"thresholding": False,
"sample_max_value": 1.0,
"solver_type": "dpm_solver",
"denoise_final": False,
"lower_order_final": False,
}

config.update(**kwargs)
Expand Down Expand Up @@ -702,9 +702,9 @@ def test_solver_order_and_type(self):
sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0)
assert not torch.isnan(sample).any(), "Samples have nan numbers"

def test_denoise_final(self):
self.check_over_configs(denoise_final=True)
self.check_over_configs(denoise_final=False)
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)

def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
Expand Down

0 comments on commit a497362

Please sign in to comment.