Skip to content

Commit

Permalink
image popups, faster generation for tf, error display bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
divamgupta committed Sep 24, 2022
1 parent d18cb5f commit ae3e284
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 34 deletions.
19 changes: 10 additions & 9 deletions backends/stable_diffusion_tf/diffusionbee_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def process_opt(d, generator):
seed = None
img = generator.generate(
d['prompt'],
img_height=d["H"], img_width=d["W"],
num_steps=d['ddim_steps'],
unconditional_guidance_scale=d['scale'],
temperature=1,
Expand Down Expand Up @@ -109,18 +110,18 @@ def main():
d.update(d_)
print("sdbk inwk") # working on the input

if cur_size != (d['W'] , d['H']):
print("sdbk mltl Loading Model")
generator = Text2Image(img_height= d['H'], img_width=d['W'], jit_compile=False, download_weights=False)
generator.text_encoder .load_weights(p2)
generator.diffusion_model.load_weights(p1)
generator.decoder.load_weights(p3)
print("sdbk mdld")
cur_size = (d['W'] , d['H'])
# if cur_size != (d['W'] , d['H']):
# print("sdbk mltl Loading Model")
# generator = Text2Image(img_height= d['H'], img_width=d['W'], jit_compile=False, download_weights=False)
# generator.text_encoder .load_weights(p2)
# generator.diffusion_model.load_weights(p1)
# generator.decoder.load_weights(p3)
# print("sdbk mdld")
# cur_size = (d['W'] , d['H'])

process_opt(d, generator)
except Exception as e:
print("sbdk errr %s"%(str(e)))
print("sdbk errr %s"%(str(e)))
print("py2b eror " + str(e))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ def __init__(self, img_height=1000, img_width=1000, jit_compile=False, download_
self.img_width = img_width
self.tokenizer = SimpleTokenizer()

text_encoder, diffusion_model, decoder = get_models(img_height, img_width, download_weights=download_weights)
text_encoder, diffusion_model, decoder , text_encoder_f , diffusion_model_f , decoder_f = get_models(img_height, img_width, download_weights=download_weights)
self.text_encoder = text_encoder
self.diffusion_model = diffusion_model
self.decoder = decoder
self.text_encoder_f = text_encoder_f
self.diffusion_model_f = diffusion_model_f
self.decoder_f = decoder_f
if jit_compile:
self.text_encoder.compile(jit_compile=True)
self.diffusion_model.compile(jit_compile=True)
Expand All @@ -32,13 +35,20 @@ def __init__(self, img_height=1000, img_width=1000, jit_compile=False, download_
def generate(
self,
prompt,
img_height, img_width,
batch_size=1,
num_steps=25,
unconditional_guidance_scale=7.5,
temperature=1,
seed=None,
img_id=0,
):

if self.img_height == img_height and self.img_width == img_width:
self.use_eager = False
else:
self.use_eager = True

try:
seed = int(seed)
if seed < 1:
Expand All @@ -52,27 +62,37 @@ def generate(
seed = seed + 1234*img_id
# Tokenize prompt (i.e. starting context)
inputs = self.tokenizer.encode(prompt)
assert len(inputs) < 77, "Prompt is too long (should be < 77 tokens)"
assert len(inputs) < 77, "Prompt is too long!"
phrase = inputs + [49407] * (77 - len(inputs))
phrase = np.array(phrase)[None].astype("int32")
phrase = np.repeat(phrase, batch_size, axis=0)

# Encode prompt tokens (and their positions) into a "context vector"
pos_ids = np.array(list(range(77)))[None].astype("int32")
pos_ids = np.repeat(pos_ids, batch_size, axis=0)
context = self.text_encoder.predict_on_batch([phrase, pos_ids])

if self.use_eager:
context = self.text_encoder_f([phrase, pos_ids])
else:
context = self.text_encoder.predict_on_batch([phrase, pos_ids])

# Encode unconditional tokens (and their positions into an
# "unconditional context vector"
unconditional_tokens = np.array(_UNCONDITIONAL_TOKENS)[None].astype("int32")
unconditional_tokens = np.repeat(unconditional_tokens, batch_size, axis=0)
self.unconditional_tokens = tf.convert_to_tensor(unconditional_tokens)
unconditional_context = self.text_encoder.predict_on_batch(
[self.unconditional_tokens, pos_ids]
)

if self.use_eager:
unconditional_context = self.text_encoder_f(
[self.unconditional_tokens, pos_ids]
)
else:
unconditional_context = self.text_encoder.predict_on_batch(
[self.unconditional_tokens, pos_ids]
)
timesteps = np.arange(1, 1000, 1000 // num_steps)
latent, alphas, alphas_prev = self.get_starting_parameters(
timesteps, batch_size, seed
img_height, img_width , timesteps, batch_size, seed
)

# Diffusion stage
Expand Down Expand Up @@ -102,7 +122,10 @@ def generate(
)

# Decoding stage
decoded = self.decoder.predict_on_batch(latent)
if self.use_eager:
decoded = self.decoder_f(latent)
else:
decoded = self.decoder.predict_on_batch(latent)
decoded = ((decoded + 1) / 2) * 255
return np.clip(decoded, 0, 255).astype("uint8")

Expand All @@ -127,10 +150,17 @@ def get_model_output(
timesteps = np.array([t])
t_emb = self.timestep_embedding(timesteps)
t_emb = np.repeat(t_emb, batch_size, axis=0)
unconditional_latent = self.diffusion_model.predict_on_batch(
[latent, t_emb, unconditional_context]
)
latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])

if self.use_eager:
unconditional_latent = self.diffusion_model_f(
[latent, t_emb, unconditional_context]
)
latent = self.diffusion_model_f([latent, t_emb, context])
else:
unconditional_latent = self.diffusion_model.predict_on_batch(
[latent, t_emb, unconditional_context]
)
latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
return unconditional_latent + unconditional_guidance_scale * (
latent - unconditional_latent
)
Expand All @@ -150,9 +180,9 @@ def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):
x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
return x_prev, pred_x0

def get_starting_parameters(self, timesteps, batch_size, seed):
n_h = self.img_height // 8
n_w = self.img_width // 8
def get_starting_parameters(self, img_height, img_width , timesteps, batch_size, seed):
n_h = img_height // 8
n_w = img_width // 8
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
alphas_prev = [1.0] + alphas[:-1]
latent_np = np.random.RandomState(seed).normal(size=(batch_size, n_h, n_w, 4)).astype('float32')
Expand All @@ -168,21 +198,24 @@ def get_models(img_height, img_width, download_weights=True):
# Create text encoder
input_word_ids = tf.keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32")
input_pos_ids = tf.keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32")
embeds = CLIPTextTransformer()([input_word_ids, input_pos_ids])
text_encoder_f = CLIPTextTransformer()
embeds = text_encoder_f([input_word_ids, input_pos_ids])
text_encoder = tf.keras.models.Model([input_word_ids, input_pos_ids], embeds)

# Creation diffusion UNet
context = tf.keras.layers.Input((MAX_TEXT_LEN, 768))
t_emb = tf.keras.layers.Input((320,))
latent = tf.keras.layers.Input((n_h, n_w, 4))
unet = UNetModel()
diffusion_model_f = unet
diffusion_model = tf.keras.models.Model(
[latent, t_emb, context], unet([latent, t_emb, context])
)

# Create decoder
latent = tf.keras.layers.Input((n_h, n_w, 4))
decoder = Decoder()
decoder_f = decoder
decoder = tf.keras.models.Model(latent, decoder(latent))

if download_weights:
Expand All @@ -203,4 +236,4 @@ def get_models(img_height, img_width, download_weights=True):
text_encoder.load_weights(text_encoder_weights_fpath)
diffusion_model.load_weights(diffusion_model_weights_fpath)
decoder.load_weights(decoder_weights_fpath)
return text_encoder, diffusion_model, decoder
return text_encoder, diffusion_model, decoder, text_encoder_f , diffusion_model_f , decoder_f
8 changes: 6 additions & 2 deletions electron_app/src/components/History.vue
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

<div v-for="img in history_box.imgs" :key="img" style="height:230px; float:left; margin-right: 10px; margin-bottom: 30px;">

<img class="gal_img" v-if="img" :src="'file://' + img" style="height:100%">
<img @click="open_image_popup( img )" class="gal_img" v-if="img" :src="'file://' + img" style="height:100%">
<br>
<div @click="save_image(img)" class="l_button">Save Image</div>
<br>
Expand All @@ -33,6 +33,7 @@
</div>
</template>
<script>
import {open_popup} from "../utils"
import Vue from 'vue'
Expand Down Expand Up @@ -64,7 +65,10 @@ export default {
return
let org_path = generated_image.replaceAll("file://" , "")
window.ipcRenderer.sendSync('save_file', org_path+"||" +out_path);
}
},
open_image_popup(img){
open_popup("file://"+img , undefined);
},
},
}
Expand Down
14 changes: 9 additions & 5 deletions electron_app/src/components/ImgGenerate.vue
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
<b-form-select
style="border-color:rgba(0,0,0,0.1)"
v-model="dif_steps"
:options="[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49 , 50]"
:options="[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49 , 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75]"
required
></b-form-select>
</b-form-group>
Expand Down Expand Up @@ -137,7 +137,7 @@

<div v-if="generated_images.length == 1" >
<center>
<img class="gal_img" v-if="generated_images[0]" :src="'file://' + generated_images[0]" style=" height: calc(100vh - 380px ); margin-top: 60px;">
<img @click="open_image_popup( generated_images[0])" class="gal_img" v-if="generated_images[0]" :src="'file://' + generated_images[0]" style=" height: calc(100vh - 380px ); margin-top: 60px;">
<br>
<div @click="save_image(generated_images[0])" class="l_button">Save Image</div>
</center>
Expand All @@ -149,7 +149,7 @@

<b-col v-for="img in generated_images" :key="img" style="margin-top:80px" md="6" lg="4" xl="3" >
<center>
<img class="gal_img" v-if="img" :src="'file://' + img" style="max-width:85%">
<img @click="open_image_popup( img )" class="gal_img" v-if="img" :src="'file://' + img" style="max-width:85%">
<br>
<div @click="save_image(img)" class="l_button">Save Image</div>
</center>
Expand Down Expand Up @@ -189,7 +189,7 @@
import LoaderModal from '../components_bare/LoaderModal.vue'
import Vue from 'vue'
import {open_popup} from "../utils"
export default {
name: 'ImgGenerate',
Expand All @@ -208,7 +208,7 @@ export default {
dif_steps : 25,
guidence_scale : 7.5 ,
is_adv_options : false ,
seed : 0 ,
seed : "" ,
prompt : "",
num_imgs : 1,
generated_images : [],
Expand Down Expand Up @@ -269,6 +269,10 @@ export default {
this.stable_diffusion.text_to_img(params, callbacks);
} ,
open_image_popup(img){
open_popup("file://"+img , undefined);
},
open_arthub(){
window.ipcRenderer.sendSync('open_url', "https://arthub.ai");
},
Expand Down
63 changes: 62 additions & 1 deletion electron_app/src/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,65 @@ function resolve_asset_illustration(name) {



export { compute_n_cols ,resolve_asset_illustration , simple_hash }

const escapeHtml = (unsafe) => {
return unsafe.replaceAll('&', '&amp;').replaceAll('<', '&lt;').replaceAll('>', '&gt;').replaceAll('"', '&quot;').replaceAll("'", '&#039;');
}


function open_popup( img_url , text ){

let css = `
<style>
img {
width: 100%;
height:100%;
object-fit: contain;
user-drag: none;
}
body{
padding : 0;
margin: 0;
background-color: #F2F2F2;
-webkit-user-select: none;
-webkit-app-region: drag;
user-drag: none;
-webkit-user-drag: none;
user-select: none;
-moz-user-select: none;
-webkit-user-select: none;
-ms-user-select: none;
}
p{
padding:40px;
}
audio{
position: fixed ;
bottom: 20px;
left: 50%;
transform: translateX(-50%);
}
</style>
`
let html = '<html><head>'+css+'</head><body>' ;

if (img_url)
html += '<img src="'+escapeHtml(img_url)+'"> ';

if( text )
html += '<p> '+ escapeHtml(text) +' </p>';

html += '</body></html>'
let uri = "data:text/html," + encodeURIComponent(html);
window.open(uri, '_blank', 'top=100,left=100,frame=false,nodeIntegration=no');


}




export { compute_n_cols ,resolve_asset_illustration , simple_hash , open_popup}

0 comments on commit ae3e284

Please sign in to comment.