
Add VAEImageDecoder for StableDiffusionV3 #1796

Merged: 2 commits into keras-team:keras-hub, Aug 28, 2024

Conversation

james77777778 (Collaborator)

Numerics check:
https://colab.research.google.com/drive/1YsWvZ0NBINDgdqipsldso1Y1NKJUDkUf?usp=sharing

Future work:

  • Implement CLIPPreprocessor
  • Wrap CLIPTextEncoder and T5XXLTextEncoder for use in StableDiffusionV3
  • Implement VAEImageDecoder
  • Implement MMDiT
  • Implement StableDiffusionV3 (inference model)

@divyashreepathihalli @mattdangerw @SamanehSaadat


@divyashreepathihalli divyashreepathihalli left a comment


Thanks for the PR!! LGTM!

@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Aug 26, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Aug 26, 2024

@mattdangerw mattdangerw left a comment


LGTM! Minor nits, and a few design notes that I don't think we need to solve in this PR.

keras_nlp/src/models/stable_diffusion_v3/vae_attention.py (two review threads, outdated and resolved)
Inline review context:

```python
from keras_nlp.src.utils.keras_utils import standardize_data_format


class VAEImageDecoder(Backbone):
```
mattdangerw (Member)


Note that with pali gemma, our "backbone" contains all the weights needed from a pre-trained model. So in that case the image encoder and text decoder collectively form a single backbone class.

We should discuss the high level flows that we want as we go, but our current approach is...

  • StableDiffusionBackbone should contain all the pretrained weights for using the entire model without a specific task setup. This can come from stitching other backbones/sub-models together. No preprocessing.
  • StableDiffusion[TaskName] would wrap the backbone with a setup for a particular task, preprocessing included. Ideally allowing both fine-tuning and inference, but that would depend on the task at hand. For stable diffusion the main task is definitely text to image, though I'm not sure what we should call that. StableDiffusionImageGenerator is kinda long.
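The two-level pattern above could be sketched roughly as follows. All class names except `StableDiffusionBackbone` and `StableDiffusionTextToImage` (and the stand-in sub-models, shapes, and layer choices) are invented for illustration and are not the real keras-nlp implementation:

```python
import keras
import numpy as np
from keras import layers


# Illustrative stand-ins for the real sub-models (CLIPTextEncoder, MMDiT,
# VAEImageDecoder); vocab size, dims, and layers are invented.
def make_text_encoder(embed_dim=8):
    token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
    x = layers.Embedding(100, embed_dim)(token_ids)
    x = layers.GlobalAveragePooling1D()(x)
    return keras.Model(token_ids, x, name="text_encoder")


def make_image_decoder(embed_dim=8):
    latents = keras.Input(shape=(embed_dim,), name="latents")
    x = layers.Dense(4 * 4 * 3, activation="sigmoid")(latents)
    images = layers.Reshape((4, 4, 3))(x)
    return keras.Model(latents, images, name="image_decoder")


class StableDiffusionBackbone(keras.Model):
    """All pretrained weights, stitched from sub-models; no preprocessing."""

    def __init__(self, text_encoder, image_decoder, **kwargs):
        super().__init__(**kwargs)
        self.text_encoder = text_encoder
        self.image_decoder = image_decoder

    def call(self, token_ids):
        return self.image_decoder(self.text_encoder(token_ids))


class StableDiffusionTextToImage(keras.Model):
    """Wraps the backbone with a task-specific setup, preprocessing included."""

    def __init__(self, backbone, **kwargs):
        super().__init__(**kwargs)
        self.backbone = backbone

    def generate(self, token_ids):
        # Real tokenization/preprocessing and the diffusion sampling loop
        # would go here; this toy version just forwards to the backbone.
        return self.backbone(np.asarray(token_ids, dtype="int32"))
```

Here the task class only forwards to the backbone; the point is the split of responsibilities, not the (toy) computation.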

james77777778 (Collaborator, Author) Aug 27, 2024


I was unsure of how we wanted to assemble these encoders and the decoder, so I made them as a Backbone first.

> We should discuss the high level flows that we want as we go, but our current approach is...

Got it. I'll make the encoders and the decoder keras.Model subclasses to follow that pattern.

I think the task name ImageGenerator is a bit ambiguous. Maybe we should call it TextToImage instead?
It is also possible to use SD3 for ImageToImage and Inpaint tasks:
https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion_3
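As a toy illustration of the keras.Model direction, a latent-to-image decoder written as a plain keras.Model subclass (rather than a Backbone) might look like the sketch below; the layer sizes, names, and architecture are invented and are not SD3's actual VAE decoder:

```python
import keras
from keras import layers


class ToyImageDecoder(keras.Model):
    """Toy latent-to-image decoder as a plain keras.Model (illustrative)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv_in = layers.Conv2D(16, 3, padding="same", activation="relu")
        self.upsample = layers.UpSampling2D(size=2)
        self.conv_out = layers.Conv2D(3, 3, padding="same", activation="sigmoid")

    def call(self, latents):
        x = self.conv_in(latents)  # (batch, h, w, 16)
        x = self.upsample(x)       # (batch, 2h, 2w, 16)
        return self.conv_out(x)    # (batch, 2h, 2w, 3), values in [0, 1]
```

A sub-model like this composes naturally inside a larger backbone or task model, since in Keras any model can be called as a layer.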

mattdangerw (Member)


TextToImage sounds fine to me. Shorter.

> Got it. Will make the encoders and decoder as a keras.Model to follow that pattern.

I suspect we still have more to figure out here. For these big "composite models" with lots of sub-components, it would be good if we allowed loading sub-models individually somehow, e.g. loading just the text encoder of a T5 model, or just the image encoder of PaliGemma. That's a valid use case that fits with the flexibility we'd like to shoot for, and we don't support it today. But probably one for another PR, I think.
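One rough sketch of the "load a sub-model individually" idea using only stock Keras (this is not a keras-nlp API; all names and shapes are illustrative) would be to attach sub-models as named layers of the composite and retrieve them by name:

```python
import keras
from keras import layers

# Named sub-models, stand-ins for e.g. a text encoder and an image decoder.
encoder = keras.Sequential(
    [keras.Input(shape=(8,)), layers.Dense(4)], name="text_encoder"
)
decoder = keras.Sequential(
    [keras.Input(shape=(4,)), layers.Dense(8)], name="image_decoder"
)

# Composite model stitching the sub-models together.
inputs = keras.Input(shape=(8,))
composite = keras.Model(inputs, decoder(encoder(inputs)), name="composite")

# Pull just one sub-model back out; it is a standalone, usable model.
text_encoder_only = composite.get_layer("text_encoder")
```

Loading only the relevant sub-model's weights from a checkpoint (rather than extracting it from a fully built composite) is the harder part and is presumably what the comment above is pointing at.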

@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Aug 27, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Aug 27, 2024
@mattdangerw mattdangerw merged commit 536474a into keras-team:keras-hub Aug 28, 2024
10 checks passed
@james77777778 james77777778 deleted the add-vae-decoder branch August 29, 2024 02:26
mattdangerw pushed a commit to mattdangerw/keras-nlp that referenced this pull request Sep 10, 2024
* Add `VAEImageDecoder` for StableDiffusionV3

* Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention`
mattdangerw pushed commits with the same message that referenced this pull request on Sep 11, Sep 13, and Sep 17, 2024.

4 participants