Skip to content

Commit

Permalink
Merge branch 'develop' into sdxl
Browse files Browse the repository at this point in the history
  • Loading branch information
qiacheng committed Sep 7, 2023
2 parents 2971758 + 38df761 commit 021beec
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion scripts/openvino_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self):
self.vae_ckpt = "None"
self.refiner_ckpt = ""


model_state = ModelState()

DEFAULT_OPENVINO_PYTHON_CONFIG = MappingProxyType(
Expand Down Expand Up @@ -157,6 +158,7 @@ def _call(*args):

if inputs_reversed:
example_inputs.reverse()

model = make_fx(subgraph)(*example_inputs)
for node in model.graph.nodes:
if node.target == torch.ops.aten.mul_.Tensor:
Expand Down Expand Up @@ -618,6 +620,7 @@ def get_diffusers_sd_model(model_config, vae_ckpt, sampler_name, enable_caching,
sd_model = StableDiffusionXLControlNetPipeline(**sd_model.components, controlnet=controlnet)
sd_model.controlnet = torch.compile(sd_model.controlnet, backend="openvino_fx")
else:

if model_config != "None":
local_config_file = os.path.join(curr_dir_path, 'configs', model_config)
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=local_config_file, use_safetensors=True)
Expand All @@ -633,6 +636,7 @@ def get_diffusers_sd_model(model_config, vae_ckpt, sampler_name, enable_caching,
sd_model.controlnet = torch.compile(sd_model.controlnet, backend="openvino_fx")

#load lora

if ('lora' in modules.extra_networks.extra_network_registry):
import lora
if lora.loaded_loras:
Expand All @@ -647,7 +651,6 @@ def get_diffusers_sd_model(model_config, vae_ckpt, sampler_name, enable_caching,
controlnet = ControlNetModel.from_pretrained("lllyasviel/" + model_state.cn_model)
sd_model = StableDiffusionControlNetPipeline(**sd_model.components, controlnet=controlnet)
sd_model.controlnet = torch.compile(sd_model.controlnet, backend="openvino_fx")

sd_model.sd_checkpoint_info = checkpoint_info
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
#sd_model.safety_checker = None
Expand Down Expand Up @@ -781,6 +784,7 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt, sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac) -> Processed:

"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if (mode == 0 and p.enable_hr):
return process_images(p)
Expand Down Expand Up @@ -899,6 +903,7 @@ def infotext(iteration=0, position_in_batch=0):
shared.sd_refiner_model = get_diffusers_sd_refiner_model(model_config, vae_ckpt, sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
shared.sd_refiner_model.scheduler = set_scheduler(shared.sd_refiner_model, sampler_name)


if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

Expand Down Expand Up @@ -945,6 +950,7 @@ def callback(iter, t, latents):
'height': p.height,
})


if refiner_ckpt != "None" and is_xl_ckpt is True:
print("here")
base_output_type = "latent"
Expand Down

0 comments on commit 021beec

Please sign in to comment.