-
Notifications
You must be signed in to change notification settings - Fork 228
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
VAEImageDecoder
for StableDiffusionV3 (#1796)
* Add `VAEImageDecoder` for StableDiffusionV3 * Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention`
- Loading branch information
1 parent
b10c410
commit 9feb2d8
Showing
2 changed files
with
303 additions
and
0 deletions.
There are no files selected for viewing
126 changes: 126 additions & 0 deletions
126
keras_nlp/src/models/stable_diffusion_v3/vae_attention.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import math | ||
|
||
from keras import layers | ||
from keras import ops | ||
|
||
from keras_nlp.src.utils.keras_utils import standardize_data_format | ||
|
||
|
||
class VAEAttention(layers.Layer):
    """Single-head self-attention block for the VAE decoder.

    Applies group normalization, then computes scaled dot-product
    self-attention over the spatial positions of a feature map via 1x1
    convolutions for the query/key/value projections, and adds the result
    back to the input (residual connection).

    Args:
        filters: int. The number of channels of the input/output feature map.
        groups: int. The number of groups for the group normalization.
            Defaults to `32`.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Resolved via `standardize_data_format` when `None`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.data_format = standardize_data_format(data_format)
        gn_axis = -1 if self.data_format == "channels_last" else 1

        self.group_norm = layers.GroupNormalization(
            groups=groups,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        # Softmax runs in float32 for numerical stability under mixed
        # precision; the result is cast back to `compute_dtype` in `call`.
        self.softmax = layers.Softmax(dtype="float32")
        self.output_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

        self.groups = groups
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))

    def build(self, input_shape):
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        x = self.group_norm(inputs)
        query = self.query_conv2d(x)
        key = self.key_conv2d(x)
        value = self.value_conv2d(x)

        if self.data_format == "channels_first":
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        shape = ops.shape(inputs)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention. Scale queries by 1/sqrt(filters).
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C]
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)

        # `attention_output` is laid out as (batch, height * width, filters),
        # i.e. channels-last. In the channels-first case we must restore the
        # spatial layout in channels-last order and transpose back *before*
        # `output_conv2d` (which is configured for `self.data_format`);
        # reshaping straight into the channels-first input shape would
        # scramble the data.
        if self.data_format == "channels_first":
            x = ops.reshape(
                attention_output, (b, shape[2], shape[3], self.filters)
            )
            x = ops.transpose(x, (0, 3, 1, 2))
        else:
            x = ops.reshape(attention_output, shape)

        x = self.output_conv2d(x)
        # Residual connection.
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
177 changes: 177 additions & 0 deletions
177
keras_nlp/src/models/stable_diffusion_v3/vae_image_decoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import keras | ||
from keras import layers | ||
|
||
from keras_nlp.src.models.stable_diffusion_v3.vae_attention import VAEAttention | ||
from keras_nlp.src.utils.keras_utils import standardize_data_format | ||
|
||
|
||
class VAEImageDecoder(keras.Model):
    """Decoder of the VAE for StableDiffusionV3.

    A functional model that maps latents to images: an input projection and
    attention-augmented ResNet blocks, followed by stacks of ResNet blocks
    with upsampling between stacks, and a final normalization/activation/
    projection to the output image channels.

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack, ordered from the input (latent) side to the output side.
        stackwise_num_blocks: list of ints. The number of ResNet blocks for
            each stack. Must have the same length as `stackwise_num_filters`.
        output_channels: int. The number of channels of the output image.
            Defaults to `3`.
        latent_shape: tuple. The shape of the latent input (excluding the
            batch dimension). Defaults to `(None, None, 16)`.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Resolved via `standardize_data_format` when `None`.
        dtype: the dtype policy used for all layers.
        **kwargs: other keyword arguments passed to `keras.Model`.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        # Input block: projection, ResNet block, attention, ResNet block.
        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling in the last block.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype=dtype,
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                # Must match the `__init__` argument name so that
                # `from_config` round-trips; `"image_shape"` would be
                # rejected by `keras.Model.__init__` as an unknown kwarg.
                "latent_shape": self.latent_shape,
            }
        )
        return config
|
||
|
||
def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    """Apply a pre-activation ResNet block: (norm -> swish -> conv) x 2.

    When the number of input channels differs from `filters`, the residual
    branch is projected with a 1x1 convolution before the addition.

    Args:
        x: the input tensor (4D feature map).
        filters: int. The number of output filters of the block.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Resolved via `standardize_data_format` when `None`.
        dtype: the dtype policy used for all layers in the block.
        name: string. The prefix used for the layer names.

    Returns:
        The output tensor of the block.
    """
    data_format = standardize_data_format(data_format)
    gn_axis = -1 if data_format == "channels_last" else 1
    # The channel axis coincides with the group-norm axis.
    input_filters = x.shape[gn_axis]

    residual = x
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm1"
    )(x)
    # Pass `dtype` to both activations so the whole block follows the same
    # dtype policy under mixed precision.
    x = layers.Activation("swish", dtype=dtype, name=f"{name}_activation1")(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv1",
    )(x)
    x = layers.GroupNormalization(
        groups=32, axis=gn_axis, epsilon=1e-6, dtype=dtype, name=f"{name}_norm2"
    )(x)
    x = layers.Activation("swish", dtype=dtype, name=f"{name}_activation2")(x)
    x = layers.Conv2D(
        filters,
        3,
        1,
        padding="same",
        data_format=data_format,
        dtype=dtype,
        name=f"{name}_conv2",
    )(x)
    if input_filters != filters:
        # Project the residual so channel counts match for the addition.
        residual = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(residual)
    x = layers.Add(dtype=dtype)([residual, x])
    return x