Skip to content

Commit

Permalink
inpainting 1.5 model added in tf backend, options box refractored
Browse files Browse the repository at this point in the history
  • Loading branch information
divamgupta committed Oct 28, 2022
1 parent 6219144 commit 4f1d501
Show file tree
Hide file tree
Showing 7 changed files with 1,962 additions and 166 deletions.
96 changes: 77 additions & 19 deletions backends/stable_diffusion_tf/diffusionbee_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,40 +91,108 @@ def process_opt(d, generator):
print("sdbk nwim %s"%(fpath) )


cur_model_id = -1
cur_model = None
def get_sd_model(model_id):
global p1 , p2 , p3 , p4 , p1_15 , p2_15 , p3_15 , p4_15
global cur_model_id , cur_model
if cur_model_id != model_id:

if cur_model is not None:
cur_model = None
time.sleep(1)
if model_id == 0:

print("sdbk gnms Loading SD Model" )
cur_model = StableDiffusion(img_height=512, img_width=512, jit_compile=False, download_weights=False, is_sd_15_inpaint=False)
cur_model.text_encoder .load_weights(p2)
cur_model.diffusion_model.load_weights(p1)
cur_model.decoder.load_weights(p3)
cur_model.encoder.load_weights(p4)
print("sdbk mdvr 1.4tf")
elif model_id == 1:
print("sdbk mdvr 1.5tf_inp")
print("sdbk gnms Loading SD Inpainting Model" )
cur_model = StableDiffusion(img_height=512, img_width=512, jit_compile=False, download_weights=False, is_sd_15_inpaint=True)
cur_model.text_encoder .load_weights(p2_15)
cur_model.diffusion_model.load_weights(p1_15)
cur_model.decoder.load_weights(p3_15)
cur_model.encoder.load_weights(p4_15)
else:
assert False

cur_model_id = model_id

return cur_model


def main():

global p1 , p2 , p3 , p4 , p1_15 , p2_15 , p3_15 , p4_15

print("sdbk mltl Loading Model")

for _ in range(5):
try:
p1 = ProgressBarDownloader(title="Downloading Model 1/4").download(
p1 = ProgressBarDownloader(title="Downloading Model 1/8").download(
url="https://huggingface.co/fchollet/stable-diffusion/resolve/main/diffusion_model.h5",
md5_checksum="72db3d55b60691e1f8a6a68cd9f47ad0",
verify_ssl=False,
extract_zip=False,
)

p2 = ProgressBarDownloader(title="Downloading Model 2/4").download(
p2 = ProgressBarDownloader(title="Downloading Model 2/8").download(
url="https://huggingface.co/fchollet/stable-diffusion/resolve/main/text_encoder.h5",
md5_checksum="9ea30bed7728473b4270a76aabf1836b",
verify_ssl=False,
extract_zip=False,
)


p3 = ProgressBarDownloader(title="Downloading Model 3/4").download(
p3 = ProgressBarDownloader(title="Downloading Model 3/8").download(
url="https://huggingface.co/fchollet/stable-diffusion/resolve/main/decoder.h5",
md5_checksum="8c86dc2fadfb0da9712a7a06cfa7bf11",
verify_ssl=False,
extract_zip=False,
)

p4 = ProgressBarDownloader(title="Downloading Model 4/4").download(
p4 = ProgressBarDownloader(title="Downloading Model 4/8").download(
url="https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/encoder_newW.h5",
md5_checksum="bef951ed69aa5a7a3acae0ab0308b630",
verify_ssl=False,
extract_zip=False,
)

p1_15 = ProgressBarDownloader(title="Downloading Model 5/8").download(
url="https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/diffusion_model_15_inpaint.h5",
md5_checksum="fd5868208a33dc4594559433bc493334",
verify_ssl=False,
extract_zip=False,
)

p2_15 = ProgressBarDownloader(title="Downloading Model 6/8").download(
url="https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/text_encoder_15_inpaint.h5",
md5_checksum="859cc286026b9c1a510d87f85295b4a4",
verify_ssl=False,
extract_zip=False,
)


p3_15 = ProgressBarDownloader(title="Downloading Model 7/8").download(
url="https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/decoder_15_inpaint.h5",
md5_checksum="aecfa5cbf18a06158e0dde99d6d2fadf",
verify_ssl=False,
extract_zip=False,
)

p4_15 = ProgressBarDownloader(title="Downloading Model 8/8").download(
url="https://huggingface.co/divamgupta/stable-diffusion-tensorflow/resolve/main/encoder_15_inpaint.h5",
md5_checksum="f73e95b6d5e1ed32e9a15fe31b1ede70",
verify_ssl=False,
extract_zip=False,
)


break
except Exception as e:
pass
Expand All @@ -140,15 +208,12 @@ def main():


cur_size = (512 , 512)
generator = StableDiffusion(img_height=512, img_width=512, jit_compile=False, download_weights=False)
generator.text_encoder .load_weights(p2)
generator.diffusion_model.load_weights(p1)
generator.decoder.load_weights(p3)
generator.encoder.load_weights(p4)
generator = get_sd_model(0)


default_d = { "W" : 512 , "H" : 512, "num_imgs":1 , "ddim_steps" : 25 ,
"scale" : 7.5, "batch_size":1 , "input_image" : None, "img_strength": 0.5
, "negative_prompt" : "" , "mask_image" : None,}
, "negative_prompt" : "" , "mask_image" : None, "model_id": 0 }


print("sdbk mdld")
Expand All @@ -169,15 +234,8 @@ def main():
d = copy.deepcopy(default_d)
d.update(d_)
print("sdbk inwk") # working on the input

# if cur_size != (d['W'] , d['H']):
# print("sdbk mltl Loading Model")
# generator = StableDiffusion(img_height= d['H'], img_width=d['W'], jit_compile=False, download_weights=False)
# generator.text_encoder .load_weights(p2)
# generator.diffusion_model.load_weights(p1)
# generator.decoder.load_weights(p3)
# print("sdbk mdld")
# cur_size = (d['W'] , d['H'])
generator = None
generator = get_sd_model(d['model_id'])

process_opt(d, generator)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion backends/stable_diffusion_tf/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def download(self, url, out_fname=None, md5_checksum=None,
if time.time() - last_time > 0.1:
last_time = time.time()
print("sdbk mlpr %d"%int(done_percentage) ) # model loading percentage
print("sdbk mlms \"%s\""%("%.2fMB out of %.2fMB"%(dl/1000000 , total_length/1000000) ))
print("sdbk mlms %s"%("%.2fMB out of %.2fMB"%(dl/1000000 , total_length/1000000) ))

print("sdbk mlpr %d"%int(-1) )
print("sdbk mltl Checking Model")
Expand Down
Loading

0 comments on commit 4f1d501

Please sign in to comment.