[do not merge] testing rotary embedding + torch.compile #9321

Draft · wants to merge 4 commits into main
Changes from 3 commits
7 changes: 5 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -550,8 +550,11 @@ def get_1d_rotary_pos_embed(
         pos = torch.from_numpy(pos)  # type: ignore  # [S]
 
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
-    freqs = freqs.to(pos.device)
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore   # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
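For context on the embeddings change: building `freqs` on the default (CPU) device and then moving it with `.to(pos.device)` adds an extra host-to-device copy inside the rotary-embedding path, while passing `device=pos.device` to `torch.arange` creates the tensor where it is needed in the first place. A minimal standalone sketch of the two patterns (illustrative only, not diffusers code):

```python
import torch

dim, theta, linear_factor = 128, 10000.0, 1.0
device = "cuda" if torch.cuda.is_available() else "cpu"
pos = torch.arange(4096, device=device)  # stand-in for the flattened position ids

# pattern removed by the diff above: allocate on CPU, then copy to the device `pos` lives on
freqs_cpu = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim)) / linear_factor
freqs_old = freqs_cpu.to(pos.device)  # extra host-to-device transfer when pos is on GPU

# pattern added by the diff above: allocate directly on pos.device, no separate copy
freqs_new = (
    1.0
    / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: dim // 2] / dim))
    / linear_factor
)

assert torch.allclose(freqs_old, freqs_new)  # same values, one fewer transfer
```

Keeping the whole computation on one device also means there is no extra data-movement op for `torch.compile` to trace around, which appears to be why this is being tested together with compilation.
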
192 changes: 99 additions & 93 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -38,6 +38,7 @@
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+from torch.profiler import record_function
 
 
 @maybe_allow_in_graph
@@ -439,109 +440,114 @@ def forward(
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
-        hidden_states = self.x_embedder(hidden_states)
-
-        timestep = timestep.to(hidden_states.dtype) * 1000
-        if guidance is not None:
-            guidance = guidance.to(hidden_states.dtype) * 1000
-        else:
-            guidance = None
-        temb = (
-            self.time_text_embed(timestep, pooled_projections)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, pooled_projections)
-        )
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
-        if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            txt_ids = txt_ids[0]
-        if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            img_ids = img_ids[0]
-        ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
-
-        for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                )
-
-            # controlnet residual
-            if controlnet_block_samples is not None:
-                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
-                interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+        with record_function(" x_embedder"):
+            hidden_states = self.x_embedder(hidden_states)
+
+        with record_function(" time_text_embed"):
+            timestep = timestep.to(hidden_states.dtype) * 1000
+            if guidance is not None:
+                guidance = guidance.to(hidden_states.dtype) * 1000
+            else:
+                guidance = None
+            temb = (
+                self.time_text_embed(timestep, pooled_projections)
+                if guidance is None
+                else self.time_text_embed(timestep, guidance, pooled_projections)
+            )
+
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        with record_function(" pos_embeds (rotary)"):
+            if txt_ids.ndim == 3:
+                logger.warning(
+                    "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                    "Please remove the batch dimension and pass it as a 2d torch Tensor"
+                )
+                txt_ids = txt_ids[0]
+            if img_ids.ndim == 3:
+                logger.warning(
+                    "Passing `img_ids` 3d torch.Tensor is deprecated."
+                    "Please remove the batch dimension and pass it as a 2d torch Tensor"
+                )
+                img_ids = img_ids[0]
+            ids = torch.cat((txt_ids, img_ids), dim=0)
+            image_rotary_emb = self.pos_embed(ids)
+
+        with record_function(" blocks"):
+            for index_block, block in enumerate(self.transformer_blocks):
+                if self.training and self.gradient_checkpointing:
+
+                    def create_custom_forward(module, return_dict=None):
+                        def custom_forward(*inputs):
+                            if return_dict is not None:
+                                return module(*inputs, return_dict=return_dict)
+                            else:
+                                return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        **ckpt_kwargs,
+                    )
+
+                else:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                    )
+
+                # controlnet residual
+                if controlnet_block_samples is not None:
+                    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                    interval_control = int(np.ceil(interval_control))
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
-        for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-
-            else:
-                hidden_states = block(
-                    hidden_states=hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                )
-
-            # controlnet residual
-            if controlnet_single_block_samples is not None:
-                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
-                interval_control = int(np.ceil(interval_control))
-                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                    + controlnet_single_block_samples[index_block // interval_control]
-                )
+        with record_function(" single blocks"):
+            for index_block, block in enumerate(self.single_transformer_blocks):
+                if self.training and self.gradient_checkpointing:
+
+                    def create_custom_forward(module, return_dict=None):
+                        def custom_forward(*inputs):
+                            if return_dict is not None:
+                                return module(*inputs, return_dict=return_dict)
+                            else:
+                                return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        **ckpt_kwargs,
+                    )
+
+                else:
+                    hidden_states = block(
+                        hidden_states=hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                    )
+
+                # controlnet residual
+                if controlnet_single_block_samples is not None:
+                    interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                    interval_control = int(np.ceil(interval_control))
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                        hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                        + controlnet_single_block_samples[index_block // interval_control]
+                    )
 
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
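The `record_function` labels added above only do something when the forward pass runs under the PyTorch profiler; they show up as named ranges in the trace, so the x_embedder, rotary-embedding, and block sections can be timed separately. A small sketch of how such a trace might be collected, using a placeholder module rather than the Flux transformer:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
x = torch.randn(8, 64)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("transformer"):      # outer range, like the pipeline-level label
        with record_function(" x_embedder"):  # inner ranges, like the labels in this file
            h = model[0](x)
        with record_function(" blocks"):
            h = model[1](h)
            h = model[2](h)

# the named ranges appear alongside the op-level events
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
# prof.export_chrome_trace("trace.json")  # viewable in chrome://tracing or Perfetto
```

The leading space in labels such as `" x_embedder"` is presumably only cosmetic (it makes the nested ranges easier to spot in the table); `record_function` treats the label as an opaque string.
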
38 changes: 22 additions & 16 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -36,6 +36,8 @@
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import FluxPipelineOutput
 
+from torch.profiler import record_function
+
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -716,21 +718,24 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

Review comment (Member): This would be another candidate for syncs because we're iterating over a tensor here. See this internal link for more details.


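To unpack the comment above: when `timesteps` is a CUDA tensor, `for i, t in enumerate(timesteps)` yields 0-d tensors that still live on the GPU, and anything that needs their value on the host (an `.item()`, a `float(t)`, a Python comparison) can force a device synchronization on every step. A minimal sketch of the pattern and one possible workaround (illustrative only, not the pipeline code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
timesteps = torch.linspace(1.0, 0.0, 28, device=device)  # stand-in for the sampling schedule

# pattern the comment refers to: each iteration yields a 0-d tensor on the device,
# and reading its value on the host syncs with the GPU every step
for i, t in enumerate(timesteps):
    t_host = float(t)  # implicit device-to-host copy (and sync) per iteration

# one possible alternative: move the schedule to the host once, iterate plain floats
for i, t in enumerate(timesteps.tolist()):
    t_host = t  # plain Python float, no per-step sync
```
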
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with record_function("transformer"):
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                with record_function("scheduler step"):
+                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():

@@ -757,10 +762,11 @@
             image = latents
 
         else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents, return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
+            with record_function("decode latent"):
+                latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+                latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+                image = self.vae.decode(latents, return_dict=False)[0]
+                image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
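Given the PR title, one way the hooks above might be exercised is to compile the Flux transformer, run the pipeline, and inspect where the time goes. A rough sketch of that workflow; the checkpoint id, dtype, prompt, and step count are assumptions for illustration, not taken from this PR:

```python
import torch
from diffusers import FluxPipeline
from torch.profiler import ProfilerActivity, profile

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # assumed checkpoint and dtype
).to("cuda")

# compile only the transformer sub-module; the rest of the pipeline stays eager
pipe.transformer = torch.compile(pipe.transformer)

# warmup call so compilation cost does not pollute the trace
pipe("a photo of a cat", num_inference_steps=4, guidance_scale=3.5)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    pipe("a photo of a cat", num_inference_steps=4, guidance_scale=3.5)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
```
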
20 changes: 11 additions & 9 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -22,7 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
 from .scheduling_utils import SchedulerMixin
-
+from torch.profiler import record_function
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 

@@ -284,20 +284,22 @@ def step(
                     " one of the `scheduler.timesteps` as a timestep."
                 ),
             )
-
-        if self.step_index is None:
-            self._init_step_index(timestep)
+        with record_function(" _init_step_index"):
+            if self.step_index is None:
+                self._init_step_index(timestep)
 
         # Upcast to avoid precision issues when computing prev_sample
         sample = sample.to(torch.float32)
 
-        sigma = self.sigmas[self.step_index]
-        sigma_next = self.sigmas[self.step_index + 1]
-
-        prev_sample = sample + (sigma_next - sigma) * model_output
+        with record_function(" get sigma and sigma_next"):
+            sigma = self.sigmas[self.step_index]
+            sigma_next = self.sigmas[self.step_index + 1]
+
+        with record_function(" get prev_sample"):
+            prev_sample = sample + (sigma_next - sigma) * model_output
 
         # Cast sample back to model compatible dtype
         prev_sample = prev_sample.to(model_output.dtype)
 
         # upon completion increase step index by one
         self._step_index += 1
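The scheduler change only wraps existing logic in profiling ranges; the update itself is the plain Euler step for flow matching, `prev_sample = sample + (sigma_next - sigma) * model_output`. A tiny standalone sketch of that update with toy values:

```python
import torch

sample = torch.randn(1, 16, dtype=torch.float32)        # current latent x_t (upcast, as in the scheduler)
model_output = torch.randn(1, 16, dtype=torch.float32)  # predicted velocity for x_t
sigma, sigma_next = 1.0, 0.95                           # two adjacent sigma-schedule entries (toy values)

# Euler step from sigma to sigma_next along the predicted velocity
prev_sample = sample + (sigma_next - sigma) * model_output
```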