Skip to content

Commit

Permalink
support scheduler selection in hires fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Mar 24, 2024
1 parent 755d2cb commit 9aa9e98
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 44 deletions.
3 changes: 3 additions & 0 deletions modules/infotext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler"

if "Hires schedule type" not in res:
res["Hires schedule type"] = "Use same scheduler"

if "Hires checkpoint" not in res:
res["Hires checkpoint"] = "Use same checkpoint"

Expand Down
6 changes: 6 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_sampler_name: str = None
hr_scheduler: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
force_task_id: str = None
Expand Down Expand Up @@ -1203,6 +1204,11 @@ def init(self, all_prompts, all_seeds, all_subseeds):
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py

if self.hr_scheduler is None:
self.hr_scheduler = self.scheduler

self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and self.latent_scale_mode is None:
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
Expand Down
38 changes: 2 additions & 36 deletions modules/processing_scripts/sampler.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,10 @@
import gradio as gr
import functools

from modules import scripts, sd_samplers, sd_schedulers, shared
from modules.infotext_utils import PasteField
from modules.ui_components import FormRow, FormGroup


def get_sampler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]


def get_scheduler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]


@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
default_sampler = sd_samplers.samplers[0]
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])

name = sampler_name or default_sampler.name

for scheduler in sd_schedulers.schedulers:
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]

for name_option in name_options:
if name.endswith(" " + name_option):
found_scheduler = scheduler
name = name[0:-(len(name_option) + 1)]
break

sampler = sd_samplers.all_samplers_map.get(name, default_sampler)

# revert back to Automatic if it's the default scheduler for the selected sampler
if sampler.options.get('scheduler', None) == found_scheduler.name:
found_scheduler = sd_schedulers.schedulers[0]

return sampler.name, found_scheduler.label


class ScriptSampler(scripts.ScriptBuiltinUI):
section = "sampler"

Expand Down Expand Up @@ -67,8 +33,8 @@ def ui(self, is_img2img):

self.infotext_fields = [
PasteField(self.steps, "Steps", api="steps"),
PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"),
PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"),
PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
]

return self.steps, self.sampler_name, self.scheduler
Expand Down
60 changes: 59 additions & 1 deletion modules/sd_samplers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
import functools

from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers

# imports for functions that previously were here and are used by other modules
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
Expand Down Expand Up @@ -64,4 +66,60 @@ def visible_samplers():
return [x for x in samplers if x.name not in samplers_hidden]


def get_sampler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]


def get_scheduler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]


def get_hr_sampler_and_scheduler(d: dict):
hr_sampler = d.get("Hires sampler", "Use same sampler")
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler

hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler

sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)

sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"

return sampler, scheduler


def get_hr_sampler_from_infotext(d: dict):
return get_hr_sampler_and_scheduler(d)[0]


def get_hr_scheduler_from_infotext(d: dict):
return get_hr_sampler_and_scheduler(d)[1]


@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
default_sampler = samplers[0]
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])

name = sampler_name or default_sampler.name

for scheduler in sd_schedulers.schedulers:
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]

for name_option in name_options:
if name.endswith(" " + name_option):
found_scheduler = scheduler
name = name[0:-(len(name_option) + 1)]
break

sampler = all_samplers_map.get(name, default_sampler)

# revert back to Automatic if it's the default scheduler for the selected sampler
if sampler.options.get('scheduler', None) == found_scheduler.name:
found_scheduler = sd_schedulers.schedulers[0]

return sampler.name, found_scheduler.label


set_samplers()
6 changes: 4 additions & 2 deletions modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_sigmas(self, p, steps):

steps += 1 if discard_next_to_last_sigma else 0

scheduler_name = p.scheduler or 'Automatic'
scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
if scheduler_name == 'Automatic':
scheduler_name = self.config.options.get('scheduler', None)

Expand All @@ -95,8 +95,10 @@ def get_sigmas(self, p, steps):
else:
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}

if scheduler.label != 'Automatic':
if scheduler.label != 'Automatic' and not p.is_hr_pass:
p.extra_generation_params["Schedule type"] = scheduler.label
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
p.extra_generation_params["Hires schedule type"] = scheduler.label

if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
sigmas_kwargs['sigma_min'] = opts.sigma_min
Expand Down
3 changes: 2 additions & 1 deletion modules/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import gradio as gr


def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
override_settings = create_override_settings_dict(override_settings_texts)

if force_enable_hr:
Expand All @@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
Expand Down
11 changes: 7 additions & 4 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,11 @@ def create_ui():

with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:

hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")

hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")

This comment has been minimized.

Copy link
@light-and-ray

light-and-ray Mar 27, 2024

Contributor

You have changed the label, and it breaks ui load saves
изображение

This comment has been minimized.

Copy link
@light-and-ray

light-and-ray Mar 27, 2024

Contributor

Made PR #15394

hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")

with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
Expand Down Expand Up @@ -394,6 +395,7 @@ def create_ui():
hr_resize_y,
hr_checkpoint_name,
hr_sampler_name,
hr_scheduler,
hr_prompt,
hr_negative_prompt,
override_settings,
Expand Down Expand Up @@ -456,8 +458,9 @@ def create_ui():
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
Expand Down

0 comments on commit 9aa9e98

Please sign in to comment.