Skip to content

Commit

Permalink
Add VAEImageDecoder for StableDiffusionV3 (#1796)
Browse files Browse the repository at this point in the history
* Add `VAEImageDecoder` for StableDiffusionV3

* Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention`
  • Loading branch information
james77777778 authored and mattdangerw committed Sep 13, 2024
1 parent b10c410 commit 9feb2d8
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 0 deletions.
126 changes: 126 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/vae_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from keras import layers
from keras import ops

from keras_nlp.src.utils.keras_utils import standardize_data_format


class VAEAttention(layers.Layer):
    """Single-head self-attention over spatial feature maps.

    Used inside the VAE decoder. Normalizes the input, projects it to
    query/key/value with 1x1 convolutions, runs scaled dot-product
    attention over the flattened spatial positions, and adds the result
    back to the input as a residual.

    Args:
        filters: int. The number of channels of the input (and output)
            feature map.
        groups: int. The number of groups for the group normalization.
            Defaults to `32`.
        data_format: string, either `"channels_last"` or
            `"channels_first"`. Resolved via `standardize_data_format`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.groups = groups
        self.data_format = standardize_data_format(data_format)
        norm_axis = 1 if self.data_format == "channels_first" else -1

        self.group_norm = layers.GroupNormalization(
            groups=groups,
            axis=norm_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = self._pointwise_conv("query_conv2d")
        self.key_conv2d = self._pointwise_conv("key_conv2d")
        self.value_conv2d = self._pointwise_conv("value_conv2d")
        # Softmax is kept in float32 for numerical stability under
        # mixed-precision policies.
        self.softmax = layers.Softmax(dtype="float32")
        self.output_conv2d = self._pointwise_conv("output_conv2d")

        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))

    def _pointwise_conv(self, name):
        # Helper: a 1x1, stride-1 convolution with `self.filters` channels.
        return layers.Conv2D(
            self.filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name=name,
        )

    def build(self, input_shape):
        # All sub-layers operate on the full (unflattened) feature map.
        for sublayer in (
            self.group_norm,
            self.query_conv2d,
            self.key_conv2d,
            self.value_conv2d,
            self.output_conv2d,
        ):
            sublayer.build(input_shape)

    def call(self, inputs, training=None):
        h = self.group_norm(inputs)
        q = self.query_conv2d(h)
        k = self.key_conv2d(h)
        v = self.value_conv2d(h)

        if self.data_format == "channels_first":
            # Move channels last so the flattening below is uniform.
            q = ops.transpose(q, (0, 2, 3, 1))
            k = ops.transpose(k, (0, 2, 3, 1))
            v = ops.transpose(v, (0, 2, 3, 1))
        input_shape = ops.shape(inputs)
        batch = input_shape[0]
        # Collapse the spatial dimensions into a single sequence axis.
        q = ops.reshape(q, (batch, -1, self.filters))
        k = ops.reshape(k, (batch, -1, self.filters))
        v = ops.reshape(v, (batch, -1, self.filters))

        # Scaled dot-product attention.
        q = ops.multiply(q, ops.cast(self._inverse_sqrt_filters, q.dtype))
        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
        scores = ops.einsum("abc,adc->abd", q, k)
        scores = ops.cast(self.softmax(scores), self.compute_dtype)
        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
        attended = ops.einsum("abc,adb->adc", v, scores)
        h = ops.reshape(attended, input_shape)

        h = self.output_conv2d(h)
        if self.data_format == "channels_first":
            h = ops.transpose(h, (0, 3, 1, 2))
        # Residual connection around the whole attention block.
        return ops.add(h, inputs)

    def get_config(self):
        config = super().get_config()
        config["filters"] = self.filters
        config["groups"] = self.groups
        return config

    def compute_output_shape(self, input_shape):
        # Attention plus the residual add preserves the input shape.
        return input_shape
177 changes: 177 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import layers

from keras_nlp.src.models.stable_diffusion_v3.vae_attention import VAEAttention
from keras_nlp.src.utils.keras_utils import standardize_data_format


class VAEImageDecoder(keras.Model):
    """Decoder of the VAE for StableDiffusionV3.

    Maps a latent tensor back to image space: an input projection, an
    attention-augmented input block, a series of ResNet stacks with
    upsampling between them, and a final normalization + projection to
    `output_channels`.

    Args:
        stackwise_num_filters: list of ints. The number of filters for
            each stack, ordered from the first (deepest) stack to the
            last.
        stackwise_num_blocks: list of ints. The number of ResNet blocks
            in each stack; must be the same length as
            `stackwise_num_filters`.
        output_channels: int. The number of channels of the decoded
            image. Defaults to `3`.
        latent_shape: tuple. The shape of the latent input (excluding
            the batch dimension). Defaults to `(None, None, 16)`.
        data_format: string, either `"channels_last"` or
            `"channels_first"`. Resolved via `standardize_data_format`.
        dtype: the dtype policy used for all layers in the model.
        **kwargs: other keyword arguments passed to `keras.Model`.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling in the last block.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=dtype,
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                # Bug fix: this was serialized as "image_shape", which
                # `__init__` does not accept, breaking
                # `from_config(get_config())` round-trips.
                "latent_shape": self.latent_shape,
            }
        )
        return config


def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    """Apply a pre-activation ResNet block to `x`.

    Two (GroupNorm -> swish -> 3x3 Conv2D) sub-blocks with a residual
    connection. When the input channel count differs from `filters`,
    the residual is projected with a 1x1 convolution.

    Args:
        x: the input feature map tensor.
        filters: int. The number of output channels of the block.
        data_format: string, either `"channels_last"` or
            `"channels_first"`. Resolved via `standardize_data_format`.
        dtype: the dtype policy for all layers in the block.
        name: string. Prefix for the layer names in this block.

    Returns:
        The output feature map tensor with `filters` channels.
    """
    data_format = standardize_data_format(data_format)
    gn_axis = -1 if data_format == "channels_last" else 1
    input_filters = x.shape[gn_axis]

    residual = x
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
    )(x)
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv1",
    )(x)
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
    )(x)
    # Bug fix: this activation was missing `dtype=dtype`, so under a
    # mixed-precision policy it ran with a different dtype than every
    # other layer in the block.
    x = layers.Activation("swish", dtype=dtype)(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv2",
    )(x)
    if input_filters != filters:
        # Project the residual so channel counts match before the add.
        residual = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(residual)
    x = layers.Add(dtype=dtype)([residual, x])
    return x

0 comments on commit 9feb2d8

Please sign in to comment.