From 3435ba2a9f686e8adff8dbaaf5ed2746c2bd0327 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A3=B9=E7=BB=B4?= Date: Thu, 18 Dec 2025 11:30:37 +0800 Subject: [PATCH 1/6] Add z-image-omni-base implementation --- src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_z_image_omni.py | 1005 +++++++++++++++++ src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/z_image/__init__.py | 3 +- .../z_image/pipeline_z_image_omni.py | 733 ++++++++++++ 7 files changed, 1748 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/models/transformers/transformer_z_image_omni.py create mode 100644 src/diffusers/pipelines/z_image/pipeline_z_image_omni.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 29a38b43120a..0d04e78c4b1e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -279,6 +279,7 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", + "ZImageOmniTransformer2DModel", "attention_backend", ] ) @@ -668,6 +669,7 @@ "WuerstchenPriorPipeline", "ZImageImg2ImgPipeline", "ZImagePipeline", + "ZImageOmniPipeline", ] ) @@ -1013,6 +1015,7 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageTransformer2DModel, + ZImageOmniTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1371,6 +1374,7 @@ WuerstchenPriorPipeline, ZImageImg2ImgPipeline, ZImagePipeline, + ZImageOmniPipeline, ) try: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..87d5c0ef95af 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -116,6 +116,7 @@ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] + _import_structure["transformers.transformer_z_image_omni"] = ["ZImageOmniTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716e1..71f352fc80aa 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -48,3 +48,4 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel + from .transformer_z_image_omni import ZImageOmniTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_z_image_omni.py b/src/diffusers/models/transformers/transformer_z_image_omni.py new file mode 100644 index 000000000000..65bf141f9079 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_z_image_omni.py @@ -0,0 +1,1005 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import einops +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.modeling_utils import ModelMixin +from ...models.normalization import RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +@maybe_allow_in_graph +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: 
Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, + ): + if self.modulation: + if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None: + # Per-token modulation based on noise_mask + # noise_mask: (batch, seq_len), 1 for noisy, 0 for clean + # adaln_noisy: (batch, embed_dim) for noisy tokens + # adaln_clean: (batch, embed_dim) for clean tokens + batch_size, seq_len = x.shape[0], x.shape[1] + + # Generate modulation for noisy and clean tokens separately + mod_noisy = self.adaLN_modulation(adaln_noisy) # (batch, 4*dim) + mod_clean = self.adaLN_modulation(adaln_clean) # (batch, 4*dim) + + # Split into scale and gate + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + # Apply tanh to gates + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + # Add 1 to scales + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + # Expand to (batch, seq_len, dim) and select based on noise_mask + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + scale_msa = torch.where(noise_mask_expanded == 1, + scale_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), + scale_msa_clean.unsqueeze(1).expand(-1, seq_len, -1)) + scale_mlp = torch.where(noise_mask_expanded == 1, + scale_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), + scale_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1)) + gate_msa = torch.where(noise_mask_expanded == 1, + gate_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), + gate_msa_clean.unsqueeze(1).expand(-1, seq_len, -1)) + gate_mlp = torch.where(noise_mask_expanded == 1, + gate_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), + gate_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1)) + else: + # Original global modulation + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + if noise_mask is not None and c_noisy is not None and c_clean is not None: + # Per-token modulation based on noise_mask + 
batch_size, seq_len = x.shape[0], x.shape[1] + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) # (batch, hidden_size) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) # (batch, hidden_size) + + # Select based on noise_mask + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + scale = torch.where(noise_mask_expanded == 1, + scale_noisy.unsqueeze(1).expand(-1, seq_len, -1), + scale_clean.unsqueeze(1).expand(-1, seq_len, -1)) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + x = self.norm_final(x) * scale + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageOmniTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock"] + _repeated_blocks = ["ZImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + siglip_feat_dim=1152, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size 
* in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + self.siglip_embedder = nn.Sequential(RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + def unpatchify( + self, + unified: List[torch.Tensor], + size: List[Tuple], + patch_size, + f_patch_size, + x_pos_offsets, + ) -> List[torch.Tensor]: + + pH = pW = patch_size + pF = f_patch_size + bsz = len(unified) + assert len(size) == bsz + + x = [] + for i in range(bsz): + x_item = [] + unified_x = unified[i][x_pos_offsets[i][0]:x_pos_offsets[i][1]] + cu_len = 0 + for j in range(len(size[i])): + if size[i][j] is None: + x_item.append(None) + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += (pad_len + ori_len) + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) # without padding + pad_len = (-ori_len) % SEQ_MULTI_OF + # assert ori_len + pad_len == unified_x.shape[0], f"Batch item {i}, patch {j}: ori_len {ori_len} + pad_len {pad_len} != unified_x.shape[0] {unified_x.shape[0]}" + x_item.append(einops.rearrange( + unified_x[cu_len:cu_len + ori_len].view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels), + "f h w pf ph pw c -> c (f pf) (h ph) (w pw)", + )) + cu_len += (ori_len + pad_len) + x.append(x_item[-1]) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_x, + all_cap_feats, + all_siglip_feats, + patch_size: int, + f_patch_size: int, + images_noise_mask: List[List[int]] + ): + + bsz = len(all_x) + pH = pW = patch_size + pF = f_patch_size + device = all_x[0][-1].device + dtype = 
all_x[0][-1].dtype + + all_x_padded = [] + all_x_size = [] + all_x_pos_ids = [] + all_x_pad_mask = [] + all_x_len = [] + all_x_noise_mask = [] + all_cap_padded_feats = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_len = [] + all_cap_noise_mask = [] + all_siglip_padded_feats = [] + all_siglip_pos_ids = [] + all_siglip_pad_mask = [] + all_siglip_len = [] + all_siglip_noise_mask = [] + + for i in range(bsz): + # process caption + num_images = len(all_x[i]) + cap_padded_feats = [] + cap_item_cu_len = 1 + cap_start_pos = [] + cap_end_pos = [] + cap_padded_pos_ids = [] + cap_pad_mask = [] + cap_len = [] + cap_noise_mask = [] + for j, cap_item in enumerate(all_cap_feats[i]): + cap_item_ori_len = len(cap_item) + cap_item_padding_len = (-cap_item_ori_len) % SEQ_MULTI_OF + cap_len.append(cap_item_ori_len + cap_item_padding_len) + cap_item_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + # start=(cap_item_cu_len, 0, 0), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(cap_item_padding_len, 1) + ) + cap_start_pos.append(cap_item_cu_len) + # cap_item_cu_len += 1 # for the padding tokens + cap_item_ori_pos_ids = self.create_coordinate_grid( + size=(cap_item_ori_len, 1, 1), + start=(cap_item_cu_len, 0, 0), + device=device, + ).flatten(0, 2) + cap_padded_pos_ids.append(cap_item_ori_pos_ids) + cap_padded_pos_ids.append(cap_item_padding_pos_ids) + cap_item_cu_len += cap_item_ori_len # for the caption tokens + cap_end_pos.append(cap_item_cu_len) + cap_item_cu_len += 2 # for the image vae tokens and siglip tokens + cap_pad_mask.append(torch.zeros((cap_item_ori_len,), dtype=torch.bool, device=device)) + cap_pad_mask.append(torch.ones((cap_item_padding_len,), dtype=torch.bool, device=device)) + cap_item_padded_feat = torch.cat([cap_item, cap_item[-1:].repeat(cap_item_padding_len, 1)], dim=0) + cap_padded_feats.append(cap_item_padded_feat) + if j < len(images_noise_mask[i]): + cap_noise_mask.extend([images_noise_mask[i][j]] * (cap_item_ori_len + cap_item_padding_len)) + else: + cap_noise_mask.extend([1] * (cap_item_ori_len + cap_item_padding_len)) + + all_cap_noise_mask.append(cap_noise_mask) + cap_padded_pos_ids = torch.cat(cap_padded_pos_ids, dim=0) + all_cap_pos_ids.append(cap_padded_pos_ids) + cap_pad_mask = torch.cat(cap_pad_mask, dim=0) + all_cap_pad_mask.append(cap_pad_mask) + all_cap_padded_feats.append(torch.cat(cap_padded_feats, dim=0)) + all_cap_len.append(cap_len) + # print(f"> all_cap_feats[i]: {all_cap_feats[i].shape}", flush=True) + + # process data + x_padded = [] + x_padded_pos_ids = [] + x_pad_mask = [] + x_len = [] + x_size = [] + x_noise_mask = [] + for j, x_item in enumerate(all_x[i]): + if x_item is not None: + # print(i, j, flush=True) + # print(f"x_item: {x_item.shape}", flush=True) + C, F, H, W = x_item.size() + x_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + x_item = x_item.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + x_item = einops.rearrange(x_item, "c f pf h ph w pw -> (f h w) (pf ph pw c)") + + x_item_ori_len = len(x_item) + x_item_padding_len = (-x_item_ori_len) % SEQ_MULTI_OF + x_len.append(x_item_ori_len + x_item_padding_len) + # padded_pos_ids + x_item_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + # start=(cap_start_pos[j], 0, 0), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_item_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), start=(cap_end_pos[j], 0, 0), device=device 
+ ).flatten(0, 2) + x_padded_pos_ids.append(x_item_ori_pos_ids) + x_padded_pos_ids.append(x_item_padding_pos_ids) + + x_pad_mask.append(torch.zeros((x_item_ori_len,), dtype=torch.bool, device=device)) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_item_padded_feat = torch.cat([x_item, x_item[-1:].repeat(x_item_padding_len, 1)], dim=0) + x_padded.append(x_item_padded_feat) + x_noise_mask.extend([images_noise_mask[i][j]] * (x_item_ori_len + x_item_padding_len)) + else: + x_pad_dim = 64 + x_item_ori_len = 0 + x_item_padding_len = SEQ_MULTI_OF + x_size.append(None) + x_item_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + # start=(cap_start_pos[j], 0, 0), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_len.append(x_item_ori_len + x_item_padding_len) + x_padded_pos_ids.append(x_item_padding_pos_ids) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_padded.append(torch.zeros((x_item_padding_len, x_pad_dim), dtype=dtype, device=device)) + x_noise_mask.extend([images_noise_mask[i][j]] * x_item_padding_len) + + all_x_noise_mask.append(x_noise_mask) + all_x_size.append(x_size) + x_padded_pos_ids = torch.cat(x_padded_pos_ids, dim=0) + all_x_pos_ids.append(x_padded_pos_ids) + x_pad_mask = torch.cat(x_pad_mask, dim=0) + all_x_pad_mask.append(x_pad_mask) + all_x_padded.append(torch.cat(x_padded, dim=0)) + all_x_len.append(x_len) + # print(f"> all_x[i]: {all_x[i].shape}", flush=True) + + # process siglip_feats + if all_siglip_feats[i] is None: + all_siglip_len.append([0 for j in range(num_images)]) + all_siglip_padded_feats.append(None) + else: + sig_padded_feats = [] + sig_padded_pos_ids = [] + sig_pad_mask = [] + sig_len = [] + sig_noise_mask = [] + for j, sig_item in enumerate(all_siglip_feats[i]): + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_H_tokens, sig_W_tokens, sig_F_tokens = sig_H, sig_W, 1 + + sig_item = sig_item.view(sig_C, sig_F_tokens, 1, sig_H_tokens, 1, sig_W_tokens, 1) + sig_item = einops.rearrange(sig_item, "c f pf h ph w pw -> (f h w) (pf ph pw c)") + + sig_item_ori_len = len(sig_item) + sig_item_padding_len = (-sig_item_ori_len) % SEQ_MULTI_OF + sig_len.append(sig_item_ori_len + sig_item_padding_len) + # padded_pos_ids + sig_item_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + # start=(cap_start_pos[j], 0, 0), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_item_ori_pos_ids = self.create_coordinate_grid( + size=(sig_F_tokens, sig_H_tokens, sig_W_tokens), start=(cap_end_pos[j] + 1, 0, 0), device=device + ) + sig_item_ori_pos_ids[..., 1] = sig_item_ori_pos_ids[..., 1] / (sig_H_tokens - 1) * (x_size[j][1] - 1) + sig_item_ori_pos_ids[..., 2] = sig_item_ori_pos_ids[..., 2] / (sig_W_tokens - 1) * (x_size[j][2] - 1) + sig_item_ori_pos_ids = sig_item_ori_pos_ids.flatten(0, 2) + sig_padded_pos_ids.append(sig_item_ori_pos_ids) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + + sig_pad_mask.append(torch.zeros((sig_item_ori_len,), dtype=torch.bool, device=device)) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_item_padded_feat = torch.cat([sig_item, sig_item[-1:].repeat(sig_item_padding_len, 1)], dim=0) + sig_padded_feats.append(sig_item_padded_feat) + sig_noise_mask.extend([images_noise_mask[i][j]] * (sig_item_ori_len + sig_item_padding_len)) + else: + sig_pad_dim = 1152 + 
sig_item_ori_len = 0 + sig_item_padding_len = SEQ_MULTI_OF + sig_item_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + # start=(cap_start_pos[j], 0, 0), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_padded_feats.append(torch.zeros((sig_item_padding_len, sig_pad_dim), dtype=dtype, device=device)) + sig_noise_mask.extend([images_noise_mask[i][j]] * sig_item_padding_len) + + all_siglip_noise_mask.append(sig_noise_mask) + sig_padded_pos_ids = torch.cat(sig_padded_pos_ids, dim=0) + all_siglip_pos_ids.append(sig_padded_pos_ids) + sig_pad_mask = torch.cat(sig_pad_mask, dim=0) + all_siglip_pad_mask.append(sig_pad_mask) + all_siglip_padded_feats.append(torch.cat(sig_padded_feats, dim=0)) + all_siglip_len.append(sig_len) + # print(f"> all_siglip_feats[i]: {all_siglip_feats[i].shape}", flush=True) + + all_x_pos_offsets = [] + for i in range(bsz): + start = sum(all_cap_len[i]) + end = start + sum(all_x_len[i]) + all_x_pos_offsets.append((start, end)) + assert all_x_padded[i].shape[0] + all_cap_padded_feats[i].shape[0] == sum(all_cap_len[i]) + sum(all_x_len[i]), f"Batch item {i}: x length {all_x_padded[i].shape[0]} + cap length {all_cap_padded_feats[i].shape[0]} != sum(all_cap_len[i]) + sum(all_x_len[i]) {sum(all_cap_len[i]) + sum(all_x_len[i])}" + + return ( + all_x_padded, + all_cap_padded_feats, + all_siglip_padded_feats, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_siglip_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_siglip_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_siglip_noise_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[List[torch.Tensor]], + cond_latents: List[List[torch.Tensor]], + siglip_feats: List[List[torch.Tensor]], + patch_size=2, + f_patch_size=1, + return_dict: bool = True, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = torch.cat([t, torch.ones_like(t, dtype=t.dtype, device=device)], dim=0) # (N, D) -> (2N, D) + # t = torch.cat([t, t], dim=0) # (N, D) -> (2N, D) + t = t * self.t_scale + t = self.t_embedder(t) # (2N, embed_dim): first N are original t, last N are t=1 + + # Split t into noisy and clean embeddings + t_noisy = t[:bsz] # (bsz, embed_dim) - original t + t_clean = t[bsz:] # (bsz, embed_dim) - t=1 + + x = [cond_latents[i] + [x[i]] for i in range(bsz)] + image_noise_mask = [[0] * (len(x[i]) - 1) + [1] for i in range(bsz)] + # print(len(x[0]), len(cap_feats[0]), len(siglip_feats[0]), len(image_noise_mask[0])) + # print([[_.shape, _.mean(), _.std()] for _ in x[0]]) + + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + siglip_inner_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask + ) = self.patchify_and_embed( + x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask + ) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = 
list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + # Create x_noise_mask tensor matching x shape + x_noise_mask_tensor = [] + for i in range(bsz): + x_mask = torch.tensor(x_noise_mask[i], dtype=torch.long, device=device) + x_noise_mask_tensor.append(x_mask) + x_noise_mask_tensor = pad_sequence(x_noise_mask_tensor, batch_first=True, padding_value=0) + x_noise_mask_tensor = x_noise_mask_tensor[:, : x.shape[1]] + + # Match t_embedder output dtype to x for layerwise casting compatibility + t_noisy_x = t_noisy.type_as(x) + t_clean_x = t_clean.type_as(x) + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, + noise_mask=x_noise_mask_tensor, + adaln_noisy=t_noisy_x, adaln_clean=t_clean_x) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, + noise_mask=x_noise_mask_tensor, + adaln_noisy=t_noisy_x, adaln_clean=t_clean_x) + + # cap embed & refine (no modulation, so no changes needed) + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # siglip embed & refine + if siglip_feats[0] is not None: + siglip_item_seqlens = [len(_) for _ in siglip_feats] + siglip_max_item_seqlen = max(siglip_item_seqlens) + + siglip_feats = torch.cat(siglip_feats, dim=0) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_feats[torch.cat(siglip_inner_pad_mask)] = self.siglip_pad_token + siglip_feats = list(siglip_feats.split(siglip_item_seqlens, dim=0)) + siglip_freqs_cis = list( + self.rope_embedder(torch.cat(siglip_pos_ids, dim=0)).split([len(_) for _ in siglip_pos_ids], dim=0) + ) + + siglip_feats = pad_sequence(siglip_feats, batch_first=True, padding_value=0.0) + siglip_freqs_cis = pad_sequence(siglip_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + siglip_freqs_cis = siglip_freqs_cis[:, : 
siglip_feats.shape[1]] + + siglip_attn_mask = torch.zeros((bsz, siglip_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(siglip_item_seqlens): + siglip_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.siglip_refiner: + siglip_feats = self._gradient_checkpointing_func(layer, siglip_feats, siglip_attn_mask, siglip_freqs_cis) + else: + for layer in self.siglip_refiner: + siglip_feats = layer(siglip_feats, siglip_attn_mask, siglip_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + unified_noise_mask = [] + if siglip_feats[0] is not None: + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + siglip_len = siglip_item_seqlens[i] + unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]])) + unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]])) + # Merge masks: cap_noise_mask + x_noise_mask + siglip_noise_mask + unified_noise_mask.append(torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], + dtype=torch.long, device=device + )) + unified_item_seqlens = [a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)] + else: + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len]])) + unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len]])) + # Merge masks: cap_noise_mask + x_noise_mask + unified_noise_mask.append(torch.tensor( + cap_noise_mask[i] + x_noise_mask[i], + dtype=torch.long, device=device + )) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + # Create unified_noise_mask tensor matching unified shape + unified_noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0) + unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.layers: + unified = self._gradient_checkpointing_func( + layer, unified, unified_attn_mask, unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, + adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + ) + else: + for layer in self.layers: + # print(unified.shape, unified_noise_mask_tensor.shape) + unified = layer(unified, unified_attn_mask, unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, + adaln_noisy=t_noisy_x, adaln_clean=t_clean_x) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, + noise_mask=unified_noise_mask_tensor, + c_noisy=t_noisy_x, c_clean=t_clean_x + ) + + # unified = list(unified.unbind(dim=0)) + # x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size, x_pos_offsets) + + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py 
b/src/diffusers/pipelines/__init__.py index 388551f812f8..38e7a60cff18 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -404,7 +404,7 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] - _import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline"] + _import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline", "ZImageOmniPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .z_image import ZImageImg2ImgPipeline, ZImagePipeline + from .z_image import ZImageImg2ImgPipeline, ZImagePipeline, ZImageOmniPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index f4342713e3e9..bb3240bebcb8 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -24,6 +24,7 @@ _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] _import_structure["pipeline_z_image"] = ["ZImagePipeline"] _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] + _import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -37,7 +38,7 @@ from .pipeline_output import ZImagePipelineOutput from .pipeline_z_image import ZImagePipeline from .pipeline_z_image_img2img import ZImageImg2ImgPipeline - + from .pipeline_z_image_omni import ZImageOmniPipeline else: import sys diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py new file mode 100644 index 000000000000..62e8b7de5031 --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -0,0 +1,733 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
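+
+# ZImageOmniPipeline pairs ZImageOmniTransformer2DModel with an optional set of
+# condition images. Each condition image is encoded twice: the VAE produces
+# latents that are passed to the transformer as `cond_latents`, and a SigLIP2
+# vision tower produces patch features passed as `siglip_feats`. The text prompt
+# is wrapped in a chat-style template whose <|vision_start|>/<|vision_end|>
+# markers reserve one slot per condition image plus one for the generated image.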
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL + +import torch +from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + +from ..flux2.image_processor import Flux2ImageProcessor + +from ...models.transformers.transformer_z_image_omni import ZImageOmniTransformer2DModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageOmniPipeline + + >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageOmniTransformer2DModel, + siglip: Siglip2VisionModel, + siglip_processor: Siglip2ImageProcessorFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + 
device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = self.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = (self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor) * self.vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + # image_latents = [image_latents] * batch_size + image_latents = [image_latents.copy() for _ in range(batch_size)] + + + return image_latents + + def prepare_siglip_embeds( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device) + shape = 
siglip_inputs.spatial_shapes[0] + hidden_state = self.siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, :shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. 
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+            cfg_normalization (`bool`, *optional*, defaults to False):
+                Whether to apply classifier-free guidance normalization.
+            cfg_truncation (`float`, *optional*, defaults to 1.0):
+                The truncation threshold for classifier-free guidance.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if image is not None and not isinstance(image, list): + image = [image] + num_condition_images = len(image) if image is not None else 0 + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2 + condition_images = [] + resized_images = [] + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + if height is not None and width is not None: + img = self.image_processor._resize_to_target_area(img, height * width) + else: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + if len(condition_images) > 0: + height = height or image_height + width = width or image_width + + else: + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # 4. 
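Concretely, the width/height snapping above rounds each side down to the nearest multiple of `vae_scale_factor * 2`; a tiny sketch, assuming the usual 8x VAE downsampling:

    multiple_of = 8 * 2  # vae_scale_factor * 2, assuming an 8x VAE
    image_width, image_height = 1365, 1024
    image_width = (image_width // multiple_of) * multiple_of    # 1360
    image_height = (image_height // multiple_of) * multiple_of  # 1024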
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + condition_latents = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents] + if self.do_classifier_free_guidance: + negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents] + + condition_siglip_embeds = self.prepare_siglip_embeds( + images=resized_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds] + if self.do_classifier_free_guidance: + negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds] + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds] + negative_condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
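For intuition, the dynamic shift computed above interpolates `mu` linearly in the number of image tokens; a self-contained sketch of that mapping with the same default config values (the helper name is mine):

    def approximate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
        # Linear interpolation from base_shift to max_shift as the token count
        # grows from base_seq_len to max_seq_len.
        slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        intercept = base_shift - slope * base_seq_len
        return image_seq_len * slope + intercept

    print(approximate_shift(256))   # 0.5  (small images keep the base shift)
    print(approximate_shift(4096))  # 1.15 (large images get the max shift)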
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + condition_latents_model_input = condition_latents + negative_condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + condition_latents_model_input = condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + condition_latents_model_input, + condition_siglip_embeds_model_input, + return_dict=False + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = 
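The guidance-plus-renormalization step inside the loop above can be summarized with the following self-contained sketch (function and parameter names are mine):

    import torch

    def guided_prediction(pos, neg, guidance_scale, norm_cap=None):
        # Classifier-free guidance: move the conditional prediction away from
        # the unconditional one.
        pred = pos + guidance_scale * (pos - neg)
        if norm_cap is not None:
            # Optionally cap the norm of the guided prediction relative to the
            # norm of the conditional branch.
            max_norm = torch.linalg.vector_norm(pos) * norm_cap
            pred_norm = torch.linalg.vector_norm(pred)
            if pred_norm > max_norm:
                pred = pred * (max_norm / pred_norm)
        return pred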
latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) From 3cbb38d9d3160472fc267e95457d4287ac235a6d Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Fri, 19 Dec 2025 07:33:23 +0000 Subject: [PATCH 2/6] Merged into one transformer for Z-Image. --- src/diffusers/__init__.py | 2 - src/diffusers/models/__init__.py | 1 - src/diffusers/models/transformers/__init__.py | 1 - .../transformers/transformer_z_image.py | 643 ++++++++++- .../transformers/transformer_z_image_omni.py | 1005 ----------------- .../z_image/pipeline_z_image_omni.py | 4 +- 6 files changed, 626 insertions(+), 1030 deletions(-) delete mode 100644 src/diffusers/models/transformers/transformer_z_image_omni.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0d04e78c4b1e..b87efce56876 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -279,7 +279,6 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", - "ZImageOmniTransformer2DModel", "attention_backend", ] ) @@ -1015,7 +1014,6 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageTransformer2DModel, - ZImageOmniTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 87d5c0ef95af..29d8b0b5a55d 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -116,7 +116,6 @@ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] - _import_structure["transformers.transformer_z_image_omni"] = ["ZImageOmniTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 71f352fc80aa..a42f6b2716e1 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -48,4 +48,3 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel - from .transformer_z_image_omni import ZImageOmniTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5c401b9d202b..fd893d6749f9 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -215,12 +215,54 @@ def forward( attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = 
self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None: + # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens + batch_size, seq_len = x.shape[0], x.shape[1] + + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + scale_msa = torch.where( + noise_mask_expanded == 1, + scale_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), + scale_msa_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + scale_mlp = torch.where( + noise_mask_expanded == 1, + scale_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), + scale_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + gate_msa = torch.where( + noise_mask_expanded == 1, + gate_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), + gate_msa_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + gate_mlp = torch.where( + noise_mask_expanded == 1, + gate_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), + gate_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + else: + # Original global modulation + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( @@ -252,9 +294,26 @@ def __init__(self, hidden_size, out_channels): nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) - def forward(self, x, c): - scale = 1.0 + self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + if noise_mask is not None and c_noisy is not None and c_clean is not None: + # Per-token modulation based on noise_mask + batch_size, seq_len = x.shape[0], x.shape[1] + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + + noise_mask_expanded = noise_mask.unsqueeze(-1) + scale = torch.where( + noise_mask_expanded == 1, + scale_noisy.unsqueeze(1).expand(-1, seq_len, -1), + scale_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale x = self.linear(x) return x @@ -325,6 +384,7 @@ def __init__( norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560, + siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni rope_theta=256.0, t_scale=1000.0, axes_dims=[32, 48, 48], @@ -386,6 +446,31 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) self.cap_embedder = 
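The per-token selection above amounts to broadcasting two modulation vectors over the sequence and picking one per token; a minimal sketch (names are mine):

    import torch

    def select_per_token(mod_noisy, mod_clean, noise_mask):
        # mod_noisy / mod_clean: (batch, dim); noise_mask: (batch, seq_len), 1 = noisy token.
        seq_len = noise_mask.shape[1]
        noisy = mod_noisy.unsqueeze(1).expand(-1, seq_len, -1)
        clean = mod_clean.unsqueeze(1).expand(-1, seq_len, -1)
        return torch.where(noise_mask.unsqueeze(-1) == 1, noisy, clean)

    mask = torch.tensor([[0, 0, 1, 1]])  # two clean tokens followed by two noisy tokens
    out = select_per_token(torch.ones(1, 8), torch.zeros(1, 8), mask)
    assert out.shape == (1, 4, 8) and out[0, 0].sum() == 0 and out[0, 2].sum() == 8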
nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + # Optional SigLIP components (for Omni variant) + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -402,22 +487,57 @@ def __init__( self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size, + f_patch_size, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: pH = pW = patch_size pF = f_patch_size bsz = len(x) assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x @staticmethod def create_coordinate_grid(size, start=None, device=None): @@ -531,11 +651,258 @@ def patchify_and_embed( all_cap_pad_mask, ) + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int, + f_patch_size: int, + images_noise_mask: List[List[int]], + ): + bsz = len(all_x) + pH = pW = patch_size + pF = f_patch_size + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_padded = [] + all_x_size = [] + all_x_pos_ids = [] + all_x_pad_mask = [] + all_x_len = [] + all_x_noise_mask = [] + all_cap_padded_feats = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_len = [] + all_cap_noise_mask = [] + all_siglip_padded_feats = [] + 
all_siglip_pos_ids = [] + all_siglip_pad_mask = [] + all_siglip_len = [] + all_siglip_noise_mask = [] + + for i in range(bsz): + # Process captions + num_images = len(all_x[i]) + cap_padded_feats = [] + cap_item_cu_len = 1 + cap_start_pos = [] + cap_end_pos = [] + cap_padded_pos_ids = [] + cap_pad_mask = [] + cap_len = [] + cap_noise_mask = [] + + for j, cap_item in enumerate(all_cap_feats[i]): + cap_item_ori_len = len(cap_item) + cap_item_padding_len = (-cap_item_ori_len) % SEQ_MULTI_OF + cap_len.append(cap_item_ori_len + cap_item_padding_len) + + cap_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(cap_item_padding_len, 1) + ) + cap_start_pos.append(cap_item_cu_len) + cap_item_ori_pos_ids = self.create_coordinate_grid( + size=(cap_item_ori_len, 1, 1), start=(cap_item_cu_len, 0, 0), device=device + ).flatten(0, 2) + cap_padded_pos_ids.append(cap_item_ori_pos_ids) + cap_padded_pos_ids.append(cap_item_padding_pos_ids) + cap_item_cu_len += cap_item_ori_len + cap_end_pos.append(cap_item_cu_len) + cap_item_cu_len += 2 # for image vae tokens and siglip tokens + + cap_pad_mask.append(torch.zeros((cap_item_ori_len,), dtype=torch.bool, device=device)) + cap_pad_mask.append(torch.ones((cap_item_padding_len,), dtype=torch.bool, device=device)) + cap_item_padded_feat = torch.cat([cap_item, cap_item[-1:].repeat(cap_item_padding_len, 1)], dim=0) + cap_padded_feats.append(cap_item_padded_feat) + + if j < len(images_noise_mask[i]): + cap_noise_mask.extend([images_noise_mask[i][j]] * (cap_item_ori_len + cap_item_padding_len)) + else: + cap_noise_mask.extend([1] * (cap_item_ori_len + cap_item_padding_len)) + + all_cap_noise_mask.append(cap_noise_mask) + cap_padded_pos_ids = torch.cat(cap_padded_pos_ids, dim=0) + all_cap_pos_ids.append(cap_padded_pos_ids) + cap_pad_mask = torch.cat(cap_pad_mask, dim=0) + all_cap_pad_mask.append(cap_pad_mask) + all_cap_padded_feats.append(torch.cat(cap_padded_feats, dim=0)) + all_cap_len.append(cap_len) + + # Process images (x) + x_padded = [] + x_padded_pos_ids = [] + x_pad_mask = [] + x_len = [] + x_size = [] + x_noise_mask = [] + + for j, x_item in enumerate(all_x[i]): + if x_item is not None: + C, F, H, W = x_item.size() + x_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + x_item = x_item.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + x_item = x_item.permute(1, 3, 5, 2, 4, 6, 0).reshape( + F_tokens * H_tokens * W_tokens, pF * pH * pW * C + ) + + x_item_ori_len = len(x_item) + x_item_padding_len = (-x_item_ori_len) % SEQ_MULTI_OF + x_len.append(x_item_ori_len + x_item_padding_len) + + x_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_item_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), start=(cap_end_pos[j], 0, 0), device=device + ).flatten(0, 2) + x_padded_pos_ids.append(x_item_ori_pos_ids) + x_padded_pos_ids.append(x_item_padding_pos_ids) + + x_pad_mask.append(torch.zeros((x_item_ori_len,), dtype=torch.bool, device=device)) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_item_padded_feat = torch.cat([x_item, x_item[-1:].repeat(x_item_padding_len, 1)], dim=0) + x_padded.append(x_item_padded_feat) + x_noise_mask.extend([images_noise_mask[i][j]] * (x_item_ori_len + x_item_padding_len)) + else: + x_pad_dim = 64 + x_item_padding_len = SEQ_MULTI_OF + 
x_size.append(None) + x_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_len.append(x_item_padding_len) + x_padded_pos_ids.append(x_item_padding_pos_ids) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_padded.append(torch.zeros((x_item_padding_len, x_pad_dim), dtype=dtype, device=device)) + x_noise_mask.extend([images_noise_mask[i][j]] * x_item_padding_len) + + all_x_noise_mask.append(x_noise_mask) + all_x_size.append(x_size) + x_padded_pos_ids = torch.cat(x_padded_pos_ids, dim=0) + all_x_pos_ids.append(x_padded_pos_ids) + x_pad_mask = torch.cat(x_pad_mask, dim=0) + all_x_pad_mask.append(x_pad_mask) + all_x_padded.append(torch.cat(x_padded, dim=0)) + all_x_len.append(x_len) + + # Process siglip_feats + if all_siglip_feats[i] is None: + all_siglip_len.append([0 for _ in range(num_images)]) + all_siglip_padded_feats.append(None) + else: + sig_padded_feats = [] + sig_padded_pos_ids = [] + sig_pad_mask = [] + sig_len = [] + sig_noise_mask = [] + + for j, sig_item in enumerate(all_siglip_feats[i]): + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_H_tokens, sig_W_tokens, sig_F_tokens = sig_H, sig_W, 1 + + sig_item = sig_item.view(sig_C, sig_F_tokens, 1, sig_H_tokens, 1, sig_W_tokens, 1) + sig_item = sig_item.permute(1, 3, 5, 2, 4, 6, 0).reshape( + sig_F_tokens * sig_H_tokens * sig_W_tokens, sig_C + ) + + sig_item_ori_len = len(sig_item) + sig_item_padding_len = (-sig_item_ori_len) % SEQ_MULTI_OF + sig_len.append(sig_item_ori_len + sig_item_padding_len) + + sig_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_item_ori_pos_ids = self.create_coordinate_grid( + size=(sig_F_tokens, sig_H_tokens, sig_W_tokens), + start=(cap_end_pos[j] + 1, 0, 0), + device=device, + ) + # Scale position IDs to match x resolution + sig_item_ori_pos_ids[..., 1] = ( + sig_item_ori_pos_ids[..., 1] / (sig_H_tokens - 1) * (x_size[j][1] - 1) + ) + sig_item_ori_pos_ids[..., 2] = ( + sig_item_ori_pos_ids[..., 2] / (sig_W_tokens - 1) * (x_size[j][2] - 1) + ) + sig_item_ori_pos_ids = sig_item_ori_pos_ids.flatten(0, 2) + sig_padded_pos_ids.append(sig_item_ori_pos_ids) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + + sig_pad_mask.append(torch.zeros((sig_item_ori_len,), dtype=torch.bool, device=device)) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_item_padded_feat = torch.cat( + [sig_item, sig_item[-1:].repeat(sig_item_padding_len, 1)], dim=0 + ) + sig_padded_feats.append(sig_item_padded_feat) + sig_noise_mask.extend([images_noise_mask[i][j]] * (sig_item_ori_len + sig_item_padding_len)) + else: + sig_pad_dim = self.config.siglip_feat_dim or 1152 + sig_item_padding_len = SEQ_MULTI_OF + sig_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_padded_feats.append( + torch.zeros((sig_item_padding_len, sig_pad_dim), dtype=dtype, device=device) + ) + sig_noise_mask.extend([images_noise_mask[i][j]] * sig_item_padding_len) + + all_siglip_noise_mask.append(sig_noise_mask) + sig_padded_pos_ids = torch.cat(sig_padded_pos_ids, dim=0) + 
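Each caption, latent, and SigLIP segment above is padded to a multiple of `SEQ_MULTI_OF` by repeating its last token; a compact sketch of that scheme (the helper name is mine):

    import torch

    SEQ_MULTI_OF = 32

    def pad_to_multiple(feats):
        # feats: (n_tokens, dim). Repeat the last token up to the next multiple
        # of SEQ_MULTI_OF and return a mask that is True on padded positions.
        pad_len = (-feats.shape[0]) % SEQ_MULTI_OF
        padded = torch.cat([feats, feats[-1:].repeat(pad_len, 1)], dim=0)
        pad_mask = torch.cat(
            [torch.zeros(feats.shape[0], dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)]
        )
        return padded, pad_mask

    padded, pad_mask = pad_to_multiple(torch.randn(70, 8))
    assert padded.shape[0] == 96 and pad_mask.sum() == 26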
all_siglip_pos_ids.append(sig_padded_pos_ids) + sig_pad_mask = torch.cat(sig_pad_mask, dim=0) + all_siglip_pad_mask.append(sig_pad_mask) + all_siglip_padded_feats.append(torch.cat(sig_padded_feats, dim=0)) + all_siglip_len.append(sig_len) + + # Compute x position offsets + all_x_pos_offsets = [] + for i in range(bsz): + start = sum(all_cap_len[i]) + end = start + sum(all_x_len[i]) + all_x_pos_offsets.append((start, end)) + + return ( + all_x_padded, + all_cap_padded_feats, + all_siglip_padded_feats, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_siglip_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_siglip_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_siglip_noise_mask, + ) + def forward( self, x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], + cond_latents: Optional[List[List[torch.Tensor]]] = None, + siglip_feats: Optional[List[List[torch.Tensor]]] = None, patch_size=2, f_patch_size=1, return_dict: bool = True, @@ -543,6 +910,26 @@ def forward( assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size + # Determine mode based on cond_latents + omni_mode = cond_latents is not None + + if omni_mode: + return self._forward_omni( + x, t, cap_feats, cond_latents, siglip_feats, patch_size, f_patch_size, return_dict + ) + else: + return self._forward_basic(x, t, cap_feats, patch_size, f_patch_size, return_dict) + + def _forward_basic( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + return_dict: bool, + ): + """Original text-to-image forward pass.""" bsz = len(x) device = x[0].device t = t * self.t_scale @@ -651,3 +1038,221 @@ def forward( return (x,) return Transformer2DModelOutput(sample=x) + + def _forward_omni( + self, + x: List[torch.Tensor], + t, + cap_feats: List[List[torch.Tensor]], + cond_latents: List[List[torch.Tensor]], + siglip_feats: List[List[torch.Tensor]], + patch_size: int, + f_patch_size: int, + return_dict: bool, + ): + """Omni mode forward pass with image conditioning.""" + bsz = len(x) + device = x[0].device + + # Create dual timestep embeddings: one for noisy tokens (t), one for clean tokens (t=1) + t_combined = torch.cat([t, torch.ones_like(t, dtype=t.dtype, device=device)], dim=0) + t_combined = t_combined * self.t_scale + t_combined = self.t_embedder(t_combined) + t_noisy = t_combined[:bsz] # Original timestep for noisy tokens + t_clean = t_combined[bsz:] # t=1 for clean (condition) tokens + + # Combine condition latents with target latent + x = [cond_latents[i] + [x[i]] for i in range(bsz)] + image_noise_mask = [[0] * (len(x[i]) - 1) + [1] for i in range(bsz)] + + # Patchify and embed for Omni mode + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + siglip_inner_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, 
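The dual-timestep trick in `_forward_omni` embeds condition images as if they were fully denoised while the target keeps its sampled timestep; a short sketch of the split performed above (values are illustrative):

    import torch

    bsz = 2
    t = torch.tensor([0.3, 0.7])                      # sampled timestep per sample
    t_combined = torch.cat([t, torch.ones_like(t)])   # append t = 1 for the clean (condition) tokens
    t_noisy, t_clean = t_combined[:bsz], t_combined[bsz:]
    assert torch.equal(t_clean, torch.ones(bsz))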
batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + # Create x_noise_mask tensor + x_noise_mask_tensor = [] + for i in range(bsz): + x_mask = torch.tensor(x_noise_mask[i], dtype=torch.long, device=device) + x_noise_mask_tensor.append(x_mask) + x_noise_mask_tensor = pad_sequence(x_noise_mask_tensor, batch_first=True, padding_value=0) + x_noise_mask_tensor = x_noise_mask_tensor[:, : x.shape[1]] + + # Match t_embedder output dtype to x + t_noisy_x = t_noisy.type_as(x) + t_clean_x = t_clean.type_as(x) + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func( + layer, x, x_attn_mask, x_freqs_cis, + noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + ) + else: + for layer in self.noise_refiner: + x = layer( + x, x_attn_mask, x_freqs_cis, + noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + ) + + # cap embed & refine (no modulation) + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # siglip embed & refine (if available) + siglip_item_seqlens = None + if siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_item_seqlens = [len(_) for _ in siglip_feats] + siglip_max_item_seqlen = max(siglip_item_seqlens) + + siglip_feats = torch.cat(siglip_feats, dim=0) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_feats[torch.cat(siglip_inner_pad_mask)] = self.siglip_pad_token + siglip_feats = list(siglip_feats.split(siglip_item_seqlens, dim=0)) + siglip_freqs_cis = list( + self.rope_embedder(torch.cat(siglip_pos_ids, dim=0)).split([len(_) for _ in siglip_pos_ids], dim=0) + ) + + siglip_feats = pad_sequence(siglip_feats, batch_first=True, padding_value=0.0) + siglip_freqs_cis = pad_sequence(siglip_freqs_cis, batch_first=True, padding_value=0.0) + siglip_freqs_cis = siglip_freqs_cis[:, : siglip_feats.shape[1]] + + siglip_attn_mask = torch.zeros((bsz, siglip_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(siglip_item_seqlens): + siglip_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.siglip_refiner: + siglip_feats = 
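The same batching pattern recurs for each stream above: right-pad the per-sample sequences with `pad_sequence` and record the valid positions in a boolean attention mask. A self-contained sketch:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    seqs = [torch.randn(96, 8), torch.randn(64, 8)]   # per-sample token runs; dim 8 is arbitrary
    seqlens = [s.shape[0] for s in seqs]

    batch = pad_sequence(seqs, batch_first=True, padding_value=0.0)  # (2, 96, 8)
    attn_mask = torch.zeros(batch.shape[:2], dtype=torch.bool)
    for i, n in enumerate(seqlens):
        attn_mask[i, :n] = True  # True on real tokens, False on batch padding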
self._gradient_checkpointing_func( + layer, siglip_feats, siglip_attn_mask, siglip_freqs_cis + ) + else: + for layer in self.siglip_refiner: + siglip_feats = layer(siglip_feats, siglip_attn_mask, siglip_freqs_cis) + + # Build unified sequence + unified = [] + unified_freqs_cis = [] + unified_noise_mask = [] + + if siglip_item_seqlens is not None: + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + siglip_len = siglip_item_seqlens[i] + unified.append( + torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]]) + ) + unified_freqs_cis.append( + torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], + dtype=torch.long, + device=device, + ) + ) + unified_item_seqlens = [a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)] + else: + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len]])) + unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + # Create unified_noise_mask tensor + unified_noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0) + unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.layers: + unified = self._gradient_checkpointing_func( + layer, unified, unified_attn_mask, unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + ) + else: + for layer in self.layers: + unified = layer( + unified, unified_attn_mask, unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + ) + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_mask_tensor, c_noisy=t_noisy_x, c_clean=t_clean_x + ) + + x = self.unpatchify(unified, x_size, patch_size, f_patch_size, x_pos_offsets) + + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/models/transformers/transformer_z_image_omni.py b/src/diffusers/models/transformers/transformer_z_image_omni.py deleted file mode 100644 index 65bf141f9079..000000000000 --- a/src/diffusers/models/transformers/transformer_z_image_omni.py +++ /dev/null @@ -1,1005 +0,0 @@ -# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import einops -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...models.attention_processor import Attention -from ...models.modeling_utils import ModelMixin -from ...models.normalization import RMSNorm -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_dispatch import dispatch_attention_fn -from ..modeling_outputs import Transformer2DModelOutput - - -ADALN_EMBED_DIM = 256 -SEQ_MULTI_OF = 32 - - -class TimestepEmbedder(nn.Module): - def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): - super().__init__() - if mid_size is None: - mid_size = out_size - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, mid_size, bias=True), - nn.SiLU(), - nn.Linear(mid_size, out_size, bias=True), - ) - - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - with torch.amp.autocast("cuda", enabled=False): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - weight_dtype = self.mlp[0].weight.dtype - compute_dtype = getattr(self.mlp[0], "compute_dtype", None) - if weight_dtype.is_floating_point: - t_freq = t_freq.to(weight_dtype) - elif compute_dtype is not None: - t_freq = t_freq.to(compute_dtype) - t_emb = self.mlp(t_freq) - return t_emb - - -class ZSingleStreamAttnProcessor: - """ - Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the - original Z-ImageAttention module. - """ - - _attention_backend = None - _parallel_config = None - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
- ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - freqs_cis: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) - - # Apply Norms - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE - def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - with torch.amp.autocast("cuda", enabled=False): - x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x * freqs_cis).flatten(3) - return x_out.type_as(x_in) # todo - - if freqs_cis is not None: - query = apply_rotary_emb(query, freqs_cis) - key = apply_rotary_emb(key, freqs_cis) - - # Cast to correct dtype - dtype = query.dtype - query, key = query.to(dtype), key.to(dtype) - - # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] - if attention_mask is not None and attention_mask.ndim == 2: - attention_mask = attention_mask[:, None, None, :] - - # Compute joint attention - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - - # Reshape back - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(dtype) - - output = attn.to_out[0](hidden_states) - if len(attn.to_out) > 1: # dropout - output = attn.to_out[1](output) - - return output - - -class FeedForward(nn.Module): - def __init__(self, dim: int, hidden_dim: int): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 - - def forward(self, x): - return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) - - -@maybe_allow_in_graph -class ZImageTransformerBlock(nn.Module): - def __init__( - self, - layer_id: int, - dim: int, - n_heads: int, - n_kv_heads: int, - norm_eps: float, - qk_norm: bool, - modulation=True, - ): - super().__init__() - self.dim = dim - self.head_dim = dim // n_heads - - # Refactored to use diffusers Attention with custom processor - # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm - self.attention = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=dim // n_heads, - heads=n_heads, - qk_norm="rms_norm" if qk_norm else None, - eps=1e-5, - bias=False, - out_bias=False, - processor=ZSingleStreamAttnProcessor(), - ) - - self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) - self.layer_id = layer_id - - self.attention_norm1 = RMSNorm(dim, eps=norm_eps) - self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) - - self.attention_norm2 = RMSNorm(dim, eps=norm_eps) - self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) - - self.modulation = modulation - if modulation: - self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) - - def forward( - self, - x: torch.Tensor, - attn_mask: torch.Tensor, - freqs_cis: torch.Tensor, - adaln_input: 
Optional[torch.Tensor] = None, - noise_mask: Optional[torch.Tensor] = None, - adaln_noisy: Optional[torch.Tensor] = None, - adaln_clean: Optional[torch.Tensor] = None, - ): - if self.modulation: - if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None: - # Per-token modulation based on noise_mask - # noise_mask: (batch, seq_len), 1 for noisy, 0 for clean - # adaln_noisy: (batch, embed_dim) for noisy tokens - # adaln_clean: (batch, embed_dim) for clean tokens - batch_size, seq_len = x.shape[0], x.shape[1] - - # Generate modulation for noisy and clean tokens separately - mod_noisy = self.adaLN_modulation(adaln_noisy) # (batch, 4*dim) - mod_clean = self.adaLN_modulation(adaln_clean) # (batch, 4*dim) - - # Split into scale and gate - scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) - scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) - - # Apply tanh to gates - gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() - gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() - - # Add 1 to scales - scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy - scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean - - # Expand to (batch, seq_len, dim) and select based on noise_mask - noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) - scale_msa = torch.where(noise_mask_expanded == 1, - scale_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), - scale_msa_clean.unsqueeze(1).expand(-1, seq_len, -1)) - scale_mlp = torch.where(noise_mask_expanded == 1, - scale_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), - scale_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1)) - gate_msa = torch.where(noise_mask_expanded == 1, - gate_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), - gate_msa_clean.unsqueeze(1).expand(-1, seq_len, -1)) - gate_mlp = torch.where(noise_mask_expanded == 1, - gate_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), - gate_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1)) - else: - # Original global modulation - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp - - # Attention block - attn_out = self.attention( - self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis - ) - x = x + gate_msa * self.attention_norm2(attn_out) - - # FFN block - x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) - else: - # Attention block - attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) - x = x + self.attention_norm2(attn_out) - - # FFN block - x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) - - return x - - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), - ) - - def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): - if noise_mask is not None and c_noisy is not None and c_clean is not None: - # Per-token modulation based on noise_mask - 
batch_size, seq_len = x.shape[0], x.shape[1] - scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) # (batch, hidden_size) - scale_clean = 1.0 + self.adaLN_modulation(c_clean) # (batch, hidden_size) - - # Select based on noise_mask - noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) - scale = torch.where(noise_mask_expanded == 1, - scale_noisy.unsqueeze(1).expand(-1, seq_len, -1), - scale_clean.unsqueeze(1).expand(-1, seq_len, -1)) - else: - # Original global modulation - assert c is not None, "Either c or (c_noisy, c_clean) must be provided" - scale = 1.0 + self.adaLN_modulation(c) - scale = scale.unsqueeze(1) - x = self.norm_final(x) * scale - x = self.linear(x) - return x - - -class RopeEmbedder: - def __init__( - self, - theta: float = 256.0, - axes_dims: List[int] = (16, 56, 56), - axes_lens: List[int] = (64, 128, 128), - ): - self.theta = theta - self.axes_dims = axes_dims - self.axes_lens = axes_lens - assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" - self.freqs_cis = None - - @staticmethod - def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): - with torch.device("cpu"): - freqs_cis = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 - freqs_cis.append(freqs_cis_i) - - return freqs_cis - - def __call__(self, ids: torch.Tensor): - assert ids.ndim == 2 - assert ids.shape[-1] == len(self.axes_dims) - device = ids.device - - if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) - return torch.cat(result, dim=-1) - - -class ZImageOmniTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): - _supports_gradient_checkpointing = True - _no_split_modules = ["ZImageTransformerBlock"] - _repeated_blocks = ["ZImageTransformerBlock"] - _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers - - @register_to_config - def __init__( - self, - all_patch_size=(2,), - all_f_patch_size=(1,), - in_channels=16, - dim=3840, - n_layers=30, - n_refiner_layers=2, - n_heads=30, - n_kv_heads=30, - norm_eps=1e-5, - qk_norm=True, - cap_feat_dim=2560, - siglip_feat_dim=1152, - rope_theta=256.0, - t_scale=1000.0, - axes_dims=[32, 48, 48], - axes_lens=[1024, 512, 512], - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels - self.all_patch_size = all_patch_size - self.all_f_patch_size = all_f_patch_size - self.dim = dim - self.n_heads = n_heads - - self.rope_theta = rope_theta - self.t_scale = t_scale - self.gradient_checkpointing = False - - assert len(all_patch_size) == len(all_f_patch_size) - - all_x_embedder = {} - all_final_layer = {} - for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size 
* in_channels, dim, bias=True) - all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - - final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) - all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer - - self.all_x_embedder = nn.ModuleDict(all_x_embedder) - self.all_final_layer = nn.ModuleDict(all_final_layer) - self.noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.context_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.siglip_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 2000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) - self.siglip_embedder = nn.Sequential(RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)) - - self.x_pad_token = nn.Parameter(torch.empty((1, dim))) - self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) - self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) - - self.layers = nn.ModuleList( - [ - ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) - for layer_id in range(n_layers) - ] - ) - head_dim = dim // n_heads - assert head_dim == sum(axes_dims) - self.axes_dims = axes_dims - self.axes_lens = axes_lens - - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - - def unpatchify( - self, - unified: List[torch.Tensor], - size: List[Tuple], - patch_size, - f_patch_size, - x_pos_offsets, - ) -> List[torch.Tensor]: - - pH = pW = patch_size - pF = f_patch_size - bsz = len(unified) - assert len(size) == bsz - - x = [] - for i in range(bsz): - x_item = [] - unified_x = unified[i][x_pos_offsets[i][0]:x_pos_offsets[i][1]] - cu_len = 0 - for j in range(len(size[i])): - if size[i][j] is None: - x_item.append(None) - ori_len = 0 - pad_len = SEQ_MULTI_OF - cu_len += (pad_len + ori_len) - else: - F, H, W = size[i][j] - ori_len = (F // pF) * (H // pH) * (W // pW) # without padding - pad_len = (-ori_len) % SEQ_MULTI_OF - # assert ori_len + pad_len == unified_x.shape[0], f"Batch item {i}, patch {j}: ori_len {ori_len} + pad_len {pad_len} != unified_x.shape[0] {unified_x.shape[0]}" - x_item.append(einops.rearrange( - unified_x[cu_len:cu_len + ori_len].view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels), - "f h w pf ph pw c -> c (f pf) (h ph) (w pw)", - )) - cu_len += (ori_len + pad_len) - x.append(x_item[-1]) - return x - - @staticmethod - def create_coordinate_grid(size, start=None, device=None): - if start is None: - start = (0 for _ in size) - - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] - grids = torch.meshgrid(axes, indexing="ij") - return torch.stack(grids, dim=-1) - - def patchify_and_embed( - self, - all_x, - all_cap_feats, - all_siglip_feats, - patch_size: int, - f_patch_size: int, - images_noise_mask: List[List[int]] - ): - - bsz = len(all_x) - pH = pW = patch_size - pF = f_patch_size - device = all_x[0][-1].device - dtype = 
all_x[0][-1].dtype - - all_x_padded = [] - all_x_size = [] - all_x_pos_ids = [] - all_x_pad_mask = [] - all_x_len = [] - all_x_noise_mask = [] - all_cap_padded_feats = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_len = [] - all_cap_noise_mask = [] - all_siglip_padded_feats = [] - all_siglip_pos_ids = [] - all_siglip_pad_mask = [] - all_siglip_len = [] - all_siglip_noise_mask = [] - - for i in range(bsz): - # process caption - num_images = len(all_x[i]) - cap_padded_feats = [] - cap_item_cu_len = 1 - cap_start_pos = [] - cap_end_pos = [] - cap_padded_pos_ids = [] - cap_pad_mask = [] - cap_len = [] - cap_noise_mask = [] - for j, cap_item in enumerate(all_cap_feats[i]): - cap_item_ori_len = len(cap_item) - cap_item_padding_len = (-cap_item_ori_len) % SEQ_MULTI_OF - cap_len.append(cap_item_ori_len + cap_item_padding_len) - cap_item_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - # start=(cap_item_cu_len, 0, 0), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(cap_item_padding_len, 1) - ) - cap_start_pos.append(cap_item_cu_len) - # cap_item_cu_len += 1 # for the padding tokens - cap_item_ori_pos_ids = self.create_coordinate_grid( - size=(cap_item_ori_len, 1, 1), - start=(cap_item_cu_len, 0, 0), - device=device, - ).flatten(0, 2) - cap_padded_pos_ids.append(cap_item_ori_pos_ids) - cap_padded_pos_ids.append(cap_item_padding_pos_ids) - cap_item_cu_len += cap_item_ori_len # for the caption tokens - cap_end_pos.append(cap_item_cu_len) - cap_item_cu_len += 2 # for the image vae tokens and siglip tokens - cap_pad_mask.append(torch.zeros((cap_item_ori_len,), dtype=torch.bool, device=device)) - cap_pad_mask.append(torch.ones((cap_item_padding_len,), dtype=torch.bool, device=device)) - cap_item_padded_feat = torch.cat([cap_item, cap_item[-1:].repeat(cap_item_padding_len, 1)], dim=0) - cap_padded_feats.append(cap_item_padded_feat) - if j < len(images_noise_mask[i]): - cap_noise_mask.extend([images_noise_mask[i][j]] * (cap_item_ori_len + cap_item_padding_len)) - else: - cap_noise_mask.extend([1] * (cap_item_ori_len + cap_item_padding_len)) - - all_cap_noise_mask.append(cap_noise_mask) - cap_padded_pos_ids = torch.cat(cap_padded_pos_ids, dim=0) - all_cap_pos_ids.append(cap_padded_pos_ids) - cap_pad_mask = torch.cat(cap_pad_mask, dim=0) - all_cap_pad_mask.append(cap_pad_mask) - all_cap_padded_feats.append(torch.cat(cap_padded_feats, dim=0)) - all_cap_len.append(cap_len) - # print(f"> all_cap_feats[i]: {all_cap_feats[i].shape}", flush=True) - - # process data - x_padded = [] - x_padded_pos_ids = [] - x_pad_mask = [] - x_len = [] - x_size = [] - x_noise_mask = [] - for j, x_item in enumerate(all_x[i]): - if x_item is not None: - # print(i, j, flush=True) - # print(f"x_item: {x_item.shape}", flush=True) - C, F, H, W = x_item.size() - x_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - x_item = x_item.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - x_item = einops.rearrange(x_item, "c f pf h ph w pw -> (f h w) (pf ph pw c)") - - x_item_ori_len = len(x_item) - x_item_padding_len = (-x_item_ori_len) % SEQ_MULTI_OF - x_len.append(x_item_ori_len + x_item_padding_len) - # padded_pos_ids - x_item_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - # start=(cap_start_pos[j], 0, 0), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(x_item_padding_len, 1) - ) - x_item_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), start=(cap_end_pos[j], 0, 0), device=device 
- ).flatten(0, 2) - x_padded_pos_ids.append(x_item_ori_pos_ids) - x_padded_pos_ids.append(x_item_padding_pos_ids) - - x_pad_mask.append(torch.zeros((x_item_ori_len,), dtype=torch.bool, device=device)) - x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) - x_item_padded_feat = torch.cat([x_item, x_item[-1:].repeat(x_item_padding_len, 1)], dim=0) - x_padded.append(x_item_padded_feat) - x_noise_mask.extend([images_noise_mask[i][j]] * (x_item_ori_len + x_item_padding_len)) - else: - x_pad_dim = 64 - x_item_ori_len = 0 - x_item_padding_len = SEQ_MULTI_OF - x_size.append(None) - x_item_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - # start=(cap_start_pos[j], 0, 0), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(x_item_padding_len, 1) - ) - x_len.append(x_item_ori_len + x_item_padding_len) - x_padded_pos_ids.append(x_item_padding_pos_ids) - x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) - x_padded.append(torch.zeros((x_item_padding_len, x_pad_dim), dtype=dtype, device=device)) - x_noise_mask.extend([images_noise_mask[i][j]] * x_item_padding_len) - - all_x_noise_mask.append(x_noise_mask) - all_x_size.append(x_size) - x_padded_pos_ids = torch.cat(x_padded_pos_ids, dim=0) - all_x_pos_ids.append(x_padded_pos_ids) - x_pad_mask = torch.cat(x_pad_mask, dim=0) - all_x_pad_mask.append(x_pad_mask) - all_x_padded.append(torch.cat(x_padded, dim=0)) - all_x_len.append(x_len) - # print(f"> all_x[i]: {all_x[i].shape}", flush=True) - - # process siglip_feats - if all_siglip_feats[i] is None: - all_siglip_len.append([0 for j in range(num_images)]) - all_siglip_padded_feats.append(None) - else: - sig_padded_feats = [] - sig_padded_pos_ids = [] - sig_pad_mask = [] - sig_len = [] - sig_noise_mask = [] - for j, sig_item in enumerate(all_siglip_feats[i]): - if sig_item is not None: - sig_H, sig_W, sig_C = sig_item.size() - sig_H_tokens, sig_W_tokens, sig_F_tokens = sig_H, sig_W, 1 - - sig_item = sig_item.view(sig_C, sig_F_tokens, 1, sig_H_tokens, 1, sig_W_tokens, 1) - sig_item = einops.rearrange(sig_item, "c f pf h ph w pw -> (f h w) (pf ph pw c)") - - sig_item_ori_len = len(sig_item) - sig_item_padding_len = (-sig_item_ori_len) % SEQ_MULTI_OF - sig_len.append(sig_item_ori_len + sig_item_padding_len) - # padded_pos_ids - sig_item_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - # start=(cap_start_pos[j], 0, 0), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(sig_item_padding_len, 1) - ) - sig_item_ori_pos_ids = self.create_coordinate_grid( - size=(sig_F_tokens, sig_H_tokens, sig_W_tokens), start=(cap_end_pos[j] + 1, 0, 0), device=device - ) - sig_item_ori_pos_ids[..., 1] = sig_item_ori_pos_ids[..., 1] / (sig_H_tokens - 1) * (x_size[j][1] - 1) - sig_item_ori_pos_ids[..., 2] = sig_item_ori_pos_ids[..., 2] / (sig_W_tokens - 1) * (x_size[j][2] - 1) - sig_item_ori_pos_ids = sig_item_ori_pos_ids.flatten(0, 2) - sig_padded_pos_ids.append(sig_item_ori_pos_ids) - sig_padded_pos_ids.append(sig_item_padding_pos_ids) - - sig_pad_mask.append(torch.zeros((sig_item_ori_len,), dtype=torch.bool, device=device)) - sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) - sig_item_padded_feat = torch.cat([sig_item, sig_item[-1:].repeat(sig_item_padding_len, 1)], dim=0) - sig_padded_feats.append(sig_item_padded_feat) - sig_noise_mask.extend([images_noise_mask[i][j]] * (sig_item_ori_len + sig_item_padding_len)) - else: - sig_pad_dim = 1152 - 
sig_item_ori_len = 0 - sig_item_padding_len = SEQ_MULTI_OF - sig_item_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - # start=(cap_start_pos[j], 0, 0), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(sig_item_padding_len, 1) - ) - sig_padded_pos_ids.append(sig_item_padding_pos_ids) - sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) - sig_padded_feats.append(torch.zeros((sig_item_padding_len, sig_pad_dim), dtype=dtype, device=device)) - sig_noise_mask.extend([images_noise_mask[i][j]] * sig_item_padding_len) - - all_siglip_noise_mask.append(sig_noise_mask) - sig_padded_pos_ids = torch.cat(sig_padded_pos_ids, dim=0) - all_siglip_pos_ids.append(sig_padded_pos_ids) - sig_pad_mask = torch.cat(sig_pad_mask, dim=0) - all_siglip_pad_mask.append(sig_pad_mask) - all_siglip_padded_feats.append(torch.cat(sig_padded_feats, dim=0)) - all_siglip_len.append(sig_len) - # print(f"> all_siglip_feats[i]: {all_siglip_feats[i].shape}", flush=True) - - all_x_pos_offsets = [] - for i in range(bsz): - start = sum(all_cap_len[i]) - end = start + sum(all_x_len[i]) - all_x_pos_offsets.append((start, end)) - assert all_x_padded[i].shape[0] + all_cap_padded_feats[i].shape[0] == sum(all_cap_len[i]) + sum(all_x_len[i]), f"Batch item {i}: x length {all_x_padded[i].shape[0]} + cap length {all_cap_padded_feats[i].shape[0]} != sum(all_cap_len[i]) + sum(all_x_len[i]) {sum(all_cap_len[i]) + sum(all_x_len[i])}" - - return ( - all_x_padded, - all_cap_padded_feats, - all_siglip_padded_feats, - all_x_size, - all_x_pos_ids, - all_cap_pos_ids, - all_siglip_pos_ids, - all_x_pad_mask, - all_cap_pad_mask, - all_siglip_pad_mask, - all_x_pos_offsets, - all_x_noise_mask, - all_cap_noise_mask, - all_siglip_noise_mask, - ) - - def forward( - self, - x: List[torch.Tensor], - t, - cap_feats: List[List[torch.Tensor]], - cond_latents: List[List[torch.Tensor]], - siglip_feats: List[List[torch.Tensor]], - patch_size=2, - f_patch_size=1, - return_dict: bool = True, - ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size - - bsz = len(x) - device = x[0].device - t = torch.cat([t, torch.ones_like(t, dtype=t.dtype, device=device)], dim=0) # (N, D) -> (2N, D) - # t = torch.cat([t, t], dim=0) # (N, D) -> (2N, D) - t = t * self.t_scale - t = self.t_embedder(t) # (2N, embed_dim): first N are original t, last N are t=1 - - # Split t into noisy and clean embeddings - t_noisy = t[:bsz] # (bsz, embed_dim) - original t - t_clean = t[bsz:] # (bsz, embed_dim) - t=1 - - x = [cond_latents[i] + [x[i]] for i in range(bsz)] - image_noise_mask = [[0] * (len(x[i]) - 1) + [1] for i in range(bsz)] - # print(len(x[0]), len(cap_feats[0]), len(siglip_feats[0]), len(image_noise_mask[0])) - # print([[_.shape, _.mean(), _.std()] for _ in x[0]]) - - ( - x, - cap_feats, - siglip_feats, - x_size, - x_pos_ids, - cap_pos_ids, - siglip_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - siglip_inner_pad_mask, - x_pos_offsets, - x_noise_mask, - cap_noise_mask, - siglip_noise_mask - ) = self.patchify_and_embed( - x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask - ) - - # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = 
list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) - - x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_freqs_cis = x_freqs_cis[:, : x.shape[1]] - - # Create x_noise_mask tensor matching x shape - x_noise_mask_tensor = [] - for i in range(bsz): - x_mask = torch.tensor(x_noise_mask[i], dtype=torch.long, device=device) - x_noise_mask_tensor.append(x_mask) - x_noise_mask_tensor = pad_sequence(x_noise_mask_tensor, batch_first=True, padding_value=0) - x_noise_mask_tensor = x_noise_mask_tensor[:, : x.shape[1]] - - # Match t_embedder output dtype to x for layerwise casting compatibility - t_noisy_x = t_noisy.type_as(x) - t_clean_x = t_clean.type_as(x) - - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: - x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, - noise_mask=x_noise_mask_tensor, - adaln_noisy=t_noisy_x, adaln_clean=t_clean_x) - else: - for layer in self.noise_refiner: - x = layer(x, x_attn_mask, x_freqs_cis, - noise_mask=x_noise_mask_tensor, - adaln_noisy=t_noisy_x, adaln_clean=t_clean_x) - - # cap embed & refine (no modulation, so no changes needed) - cap_item_seqlens = [len(_) for _ in cap_feats] - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list( - self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) - ) - - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] - - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: - cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) - else: - for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) - - # siglip embed & refine - if siglip_feats[0] is not None: - siglip_item_seqlens = [len(_) for _ in siglip_feats] - siglip_max_item_seqlen = max(siglip_item_seqlens) - - siglip_feats = torch.cat(siglip_feats, dim=0) - siglip_feats = self.siglip_embedder(siglip_feats) - siglip_feats[torch.cat(siglip_inner_pad_mask)] = self.siglip_pad_token - siglip_feats = list(siglip_feats.split(siglip_item_seqlens, dim=0)) - siglip_freqs_cis = list( - self.rope_embedder(torch.cat(siglip_pos_ids, dim=0)).split([len(_) for _ in siglip_pos_ids], dim=0) - ) - - siglip_feats = pad_sequence(siglip_feats, batch_first=True, padding_value=0.0) - siglip_freqs_cis = pad_sequence(siglip_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - siglip_freqs_cis = siglip_freqs_cis[:, : 
siglip_feats.shape[1]] - - siglip_attn_mask = torch.zeros((bsz, siglip_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(siglip_item_seqlens): - siglip_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.siglip_refiner: - siglip_feats = self._gradient_checkpointing_func(layer, siglip_feats, siglip_attn_mask, siglip_freqs_cis) - else: - for layer in self.siglip_refiner: - siglip_feats = layer(siglip_feats, siglip_attn_mask, siglip_freqs_cis) - - # unified - unified = [] - unified_freqs_cis = [] - unified_noise_mask = [] - if siglip_feats[0] is not None: - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - siglip_len = siglip_item_seqlens[i] - unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]])) - unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]])) - # Merge masks: cap_noise_mask + x_noise_mask + siglip_noise_mask - unified_noise_mask.append(torch.tensor( - cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], - dtype=torch.long, device=device - )) - unified_item_seqlens = [a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)] - else: - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len]])) - unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len]])) - # Merge masks: cap_noise_mask + x_noise_mask - unified_noise_mask.append(torch.tensor( - cap_noise_mask[i] + x_noise_mask[i], - dtype=torch.long, device=device - )) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) - - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 - - # Create unified_noise_mask tensor matching unified shape - unified_noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0) - unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]] - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: - unified = self._gradient_checkpointing_func( - layer, unified, unified_attn_mask, unified_freqs_cis, - noise_mask=unified_noise_mask_tensor, - adaln_noisy=t_noisy_x, adaln_clean=t_clean_x - ) - else: - for layer in self.layers: - # print(unified.shape, unified_noise_mask_tensor.shape) - unified = layer(unified, unified_attn_mask, unified_freqs_cis, - noise_mask=unified_noise_mask_tensor, - adaln_noisy=t_noisy_x, adaln_clean=t_clean_x) - - unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"]( - unified, - noise_mask=unified_noise_mask_tensor, - c_noisy=t_noisy_x, c_clean=t_clean_x - ) - - # unified = list(unified.unbind(dim=0)) - # x = self.unpatchify(unified, x_size, patch_size, f_patch_size) - x = self.unpatchify(unified, x_size, patch_size, f_patch_size, x_pos_offsets) - - if not return_dict: - return (x,) - - return Transformer2DModelOutput(sample=x) diff --git 
a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py index 62e8b7de5031..9db6226f6f31 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -31,7 +31,7 @@ from ..flux2.image_processor import Flux2ImageProcessor -from ...models.transformers.transformer_z_image_omni import ZImageOmniTransformer2DModel +from ...models.transformers import ZImageTransformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -149,7 +149,7 @@ def __init__( vae: AutoencoderKL, text_encoder: PreTrainedModel, tokenizer: AutoTokenizer, - transformer: ZImageOmniTransformer2DModel, + transformer: ZImageTransformer2DModel, siglip: Siglip2VisionModel, siglip_processor: Siglip2ImageProcessorFast, ): From 9180579b484c2f4320770b071d52cdcb22f36cdc Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Fri, 19 Dec 2025 07:58:12 +0000 Subject: [PATCH 3/6] Fix bugs for controlnet after merging the main branch new feature. --- .../transformers/transformer_z_image.py | 19 +++++++++++++++---- .../z_image/pipeline_z_image_omni.py | 12 ++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index b06a537f31ee..fea1024c353a 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -916,16 +916,20 @@ def forward( if omni_mode: return self._forward_omni( - x, t, cap_feats, cond_latents, siglip_feats, patch_size, f_patch_size, return_dict + x, t, cap_feats, cond_latents, siglip_feats, + controlnet_block_samples, patch_size, f_patch_size, return_dict ) else: - return self._forward_basic(x, t, cap_feats, patch_size, f_patch_size, return_dict) + return self._forward_basic( + x, t, cap_feats, controlnet_block_samples, patch_size, f_patch_size, return_dict + ) def _forward_basic( self, x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], + controlnet_block_samples: Optional[Dict[int, torch.Tensor]], patch_size: int, f_patch_size: int, return_dict: bool, @@ -1053,6 +1057,7 @@ def _forward_omni( cap_feats: List[List[torch.Tensor]], cond_latents: List[List[torch.Tensor]], siglip_feats: List[List[torch.Tensor]], + controlnet_block_samples: Optional[Dict[int, torch.Tensor]], patch_size: int, f_patch_size: int, return_dict: bool, @@ -1241,17 +1246,23 @@ def _forward_omni( unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]] if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] else: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = layer( unified, unified_attn_mask, unified_freqs_cis, noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"]( unified, 
noise_mask=unified_noise_mask_tensor, c_noisy=t_noisy_x, c_clean=t_clean_x diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py index 9db6226f6f31..05b27216d6ee 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -659,12 +659,12 @@ def __call__( latent_model_input_list = list(latent_model_input.unbind(dim=0)) model_out_list = self.transformer( - latent_model_input_list, - timestep_model_input, - prompt_embeds_model_input, - condition_latents_model_input, - condition_siglip_embeds_model_input, - return_dict=False + x=latent_model_input_list, + t=timestep_model_input, + cap_feats=prompt_embeds_model_input, + cond_latents=condition_latents_model_input, + siglip_feats=condition_siglip_embeds_model_input, + return_dict=False, )[0] if apply_cfg: From 4c14cf3db442977b2bfe2e8f5f1e242424e024f6 Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Fri, 19 Dec 2025 08:12:12 +0000 Subject: [PATCH 4/6] Fix for auto_pipeline, Add Styling. --- src/diffusers/__init__.py | 4 +- .../transformers/transformer_z_image.py | 61 +++++++++++++------ src/diffusers/pipelines/auto_pipeline.py | 11 +++- .../z_image/pipeline_z_image_omni.py | 23 ++++--- 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 95a1e1492798..aa11a741af38 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -675,8 +675,8 @@ "ZImageControlNetInpaintPipeline", "ZImageControlNetPipeline", "ZImageImg2ImgPipeline", - "ZImagePipeline", "ZImageOmniPipeline", + "ZImagePipeline", ] ) @@ -1387,8 +1387,8 @@ ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, - ZImagePipeline, ZImageOmniPipeline, + ZImagePipeline, ) try: diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index fea1024c353a..a3e356f1e681 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -222,7 +222,7 @@ def forward( if self.modulation: if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None: # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens - batch_size, seq_len = x.shape[0], x.shape[1] + _, seq_len = x.shape[0], x.shape[1] mod_noisy = self.adaLN_modulation(adaln_noisy) mod_clean = self.adaLN_modulation(adaln_clean) @@ -260,7 +260,9 @@ def forward( else: # Original global modulation assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + ) gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp @@ -297,7 +299,7 @@ def __init__(self, hidden_size, out_channels): def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): if noise_mask is not None and c_noisy is not None and c_clean is not None: # Per-token modulation based on noise_mask - batch_size, seq_len = x.shape[0], x.shape[1] + _, seq_len = x.shape[0], x.shape[1] scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) scale_clean = 1.0 + self.adaLN_modulation(c_clean) @@ -916,8 +918,15 @@ def forward( if omni_mode: return self._forward_omni( - x, t, cap_feats, 
cond_latents, siglip_feats, - controlnet_block_samples, patch_size, f_patch_size, return_dict + x, + t, + cap_feats, + cond_latents, + siglip_feats, + controlnet_block_samples, + patch_size, + f_patch_size, + return_dict, ) else: return self._forward_basic( @@ -1130,14 +1139,23 @@ def _forward_omni( if torch.is_grad_enabled() and self.gradient_checkpointing: for layer in self.noise_refiner: x = self._gradient_checkpointing_func( - layer, x, x_attn_mask, x_freqs_cis, - noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + layer, + x, + x_attn_mask, + x_freqs_cis, + noise_mask=x_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, ) else: for layer in self.noise_refiner: x = layer( - x, x_attn_mask, x_freqs_cis, - noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + x, + x_attn_mask, + x_freqs_cis, + noise_mask=x_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, ) # cap embed & refine (no modulation) @@ -1208,9 +1226,7 @@ def _forward_omni( x_len = x_item_seqlens[i] cap_len = cap_item_seqlens[i] siglip_len = siglip_item_seqlens[i] - unified.append( - torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]]) - ) + unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]])) unified_freqs_cis.append( torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]]) ) @@ -1221,7 +1237,9 @@ def _forward_omni( device=device, ) ) - unified_item_seqlens = [a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)] + unified_item_seqlens = [ + a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens) + ] else: for i in range(bsz): x_len = x_item_seqlens[i] @@ -1248,8 +1266,13 @@ def _forward_omni( if torch.is_grad_enabled() and self.gradient_checkpointing: for layer_idx, layer in enumerate(self.layers): unified = self._gradient_checkpointing_func( - layer, unified, unified_attn_mask, unified_freqs_cis, - noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + layer, + unified, + unified_attn_mask, + unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, ) if controlnet_block_samples is not None: if layer_idx in controlnet_block_samples: @@ -1257,8 +1280,12 @@ def _forward_omni( else: for layer_idx, layer in enumerate(self.layers): unified = layer( - unified, unified_attn_mask, unified_freqs_cis, - noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x + unified, + unified_attn_mask, + unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, ) if controlnet_block_samples is not None: if layer_idx in controlnet_block_samples: diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index db0268a2a73d..7a2e58c3196e 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -119,7 +119,13 @@ ) from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline -from .z_image import ZImageImg2ImgPipeline, ZImagePipeline +from .z_image import ( + ZImageControlNetInpaintPipeline, + ZImageControlNetPipeline, + ZImageImg2ImgPipeline, + ZImageOmniPipeline, + ZImagePipeline, +) AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -164,6 +170,9 @@ 
("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), ("z-image", ZImagePipeline), + ("z-image-controlnet", ZImageControlNetPipeline), + ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline), + ("z-image-omni", ZImageOmniPipeline), ] ) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py index 05b27216d6ee..9d64821fe549 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -16,22 +16,19 @@ from typing import Any, Callable, Dict, List, Optional, Union import PIL - import torch from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel -from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from .pipeline_output import ZImagePipelineOutput - from ..flux2.image_processor import Flux2ImageProcessor +from .pipeline_output import ZImagePipelineOutput -from ...models.transformers import ZImageTransformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -224,7 +221,6 @@ def _encode_prompt( prompt = [prompt] for i, prompt_item in enumerate(prompt): - if num_condition_images == 0: prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] elif num_condition_images > 0: @@ -236,7 +232,7 @@ def _encode_prompt( flattened_prompt = [] prompt_list_lengths = [] - + for i in range(len(prompt)): prompt_list_lengths.append(len(prompt[i])) flattened_prompt.extend(prompt[i]) @@ -304,14 +300,15 @@ def prepare_image_latents( image_latents = [] for image in images: image = image.to(device=device, dtype=dtype) - image_latent = (self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor) * self.vae.config.scaling_factor + image_latent = ( + self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor image_latent = image_latent.unsqueeze(1).to(dtype) image_latents.append(image_latent) # (16, 128, 128) # image_latents = [image_latents] * batch_size image_latents = [image_latents.copy() for _ in range(batch_size)] - return image_latents def prepare_siglip_embeds( @@ -327,7 +324,7 @@ def prepare_siglip_embeds( shape = siglip_inputs.spatial_shapes[0] hidden_state = self.siglip(**siglip_inputs).last_hidden_state B, N, C = hidden_state.shape - hidden_state = hidden_state[:, :shape[0] * shape[1]] + hidden_state = hidden_state[:, : shape[0] * shape[1]] hidden_state = hidden_state.view(shape[0], shape[1], C) siglip_embeds.append(hidden_state.to(dtype)) @@ -529,7 +526,7 @@ def __call__( image_height = (image_height // multiple_of) * multiple_of img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") condition_images.append(img) - + if len(condition_images) > 0: height = height or image_height width = width or image_width @@ -591,7 +588,9 @@ def __call__( negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] condition_siglip_embeds = [None if sels == [] else 
sels + [None] for sels in condition_siglip_embeds] - negative_condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds] + negative_condition_siglip_embeds = [ + None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds + ] actual_batch_size = batch_size * num_images_per_prompt image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) From 5bc676c0ab576d8b566d2b23e67bf61b28ad836a Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Fri, 19 Dec 2025 14:56:23 +0000 Subject: [PATCH 5/6] Refactor noise handling and modulation - Add select_per_token function for per-token value selection - Separate adaptive modulation logic - Cleanify t_noisy/clean variable naming - Move image_noise_mask handler from forward to pipeline --- .../transformers/transformer_z_image.py | 103 ++++++++---------- .../z_image/pipeline_z_image_omni.py | 16 ++- 2 files changed, 57 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index a3e356f1e681..883b16e27525 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -152,6 +152,20 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso return output +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int): super().__init__() @@ -220,10 +234,10 @@ def forward( adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None: - # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens - _, seq_len = x.shape[0], x.shape[1] + seq_len = x.shape[1] + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens mod_noisy = self.adaLN_modulation(adaln_noisy) mod_clean = self.adaLN_modulation(adaln_clean) @@ -236,33 +250,14 @@ def forward( scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean - noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) - scale_msa = torch.where( - noise_mask_expanded == 1, - scale_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), - scale_msa_clean.unsqueeze(1).expand(-1, seq_len, -1), - ) - scale_mlp = torch.where( - noise_mask_expanded == 1, - scale_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), - scale_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1), - ) - gate_msa = torch.where( - noise_mask_expanded == 1, - gate_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1), - gate_msa_clean.unsqueeze(1).expand(-1, seq_len, -1), - ) - gate_mlp = torch.where( - noise_mask_expanded == 1, - gate_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1), - gate_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1), - ) + 
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) else: - # Original global modulation - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = ( - self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - ) + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp @@ -297,18 +292,13 @@ def __init__(self, hidden_size, out_channels): ) def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): - if noise_mask is not None and c_noisy is not None and c_clean is not None: - # Per-token modulation based on noise_mask - _, seq_len = x.shape[0], x.shape[1] + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) scale_clean = 1.0 + self.adaLN_modulation(c_clean) - - noise_mask_expanded = noise_mask.unsqueeze(-1) - scale = torch.where( - noise_mask_expanded == 1, - scale_noisy.unsqueeze(1).expand(-1, seq_len, -1), - scale_clean.unsqueeze(1).expand(-1, seq_len, -1), - ) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) else: # Original global modulation assert c is not None, "Either c or (c_noisy, c_clean) must be provided" @@ -900,29 +890,29 @@ def patchify_and_embed_omni( def forward( self, - x: List[torch.Tensor], + x: Union[List[torch.Tensor], List[List[torch.Tensor]]], t, - cap_feats: List[torch.Tensor], + cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]], controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, - cond_latents: Optional[List[List[torch.Tensor]]] = None, siglip_feats: Optional[List[List[torch.Tensor]]] = None, - patch_size=2, - f_patch_size=1, + image_noise_mask: Optional[List[List[int]]] = None, + patch_size: int = 2, + f_patch_size: int = 1, return_dict: bool = True, ): assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size - # Determine mode based on cond_latents - omni_mode = cond_latents is not None + # Omni mode: x contains lists (multi-image input) + omni_mode = isinstance(x[0], list) if omni_mode: return self._forward_omni( x, t, cap_feats, - cond_latents, siglip_feats, + image_noise_mask, controlnet_block_samples, patch_size, f_patch_size, @@ -1061,11 +1051,11 @@ def _forward_basic( def _forward_omni( self, - x: List[torch.Tensor], + x: List[List[torch.Tensor]], t, cap_feats: List[List[torch.Tensor]], - cond_latents: List[List[torch.Tensor]], siglip_feats: List[List[torch.Tensor]], + image_noise_mask: List[List[int]], controlnet_block_samples: Optional[Dict[int, torch.Tensor]], patch_size: int, f_patch_size: int, @@ -1073,18 +1063,11 @@ def _forward_omni( ): """Omni mode forward pass with image conditioning.""" bsz = len(x) - device = x[0].device + device = x[0][-1].device # From target latent # Create dual timestep embeddings: one for noisy tokens (t), one for clean tokens (t=1) - t_combined = torch.cat([t, torch.ones_like(t, dtype=t.dtype, device=device)], dim=0) - t_combined = t_combined * self.t_scale - t_combined = self.t_embedder(t_combined) - 
t_noisy = t_combined[:bsz] # Original timestep for noisy tokens - t_clean = t_combined[bsz:] # t=1 for clean (condition) tokens - - # Combine condition latents with target latent - x = [cond_latents[i] + [x[i]] for i in range(bsz)] - image_noise_mask = [[0] * (len(x[i]) - 1) + [1] for i in range(bsz)] + t_noisy = self.t_embedder(t * self.t_scale) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale) # Patchify and embed for Omni mode ( diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py index 9d64821fe549..51aa1c5b8260 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -657,12 +657,24 @@ def __call__( latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) + # Combine condition latents with target latent + current_batch_size = len(latent_model_input_list) + x_combined = [ + condition_latents_model_input[i] + [latent_model_input_list[i]] + for i in range(current_batch_size) + ] + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents_model_input[i]) + [1] + for i in range(current_batch_size) + ] + model_out_list = self.transformer( - x=latent_model_input_list, + x=x_combined, t=timestep_model_input, cap_feats=prompt_embeds_model_input, - cond_latents=condition_latents_model_input, siglip_feats=condition_siglip_embeds_model_input, + image_noise_mask=image_noise_mask, return_dict=False, )[0] From 732c5275520cb37de11941ed929d4a9c54cbb651 Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Fri, 19 Dec 2025 15:03:31 +0000 Subject: [PATCH 6/6] Styling & Formatting. --- src/diffusers/pipelines/z_image/pipeline_z_image_omni.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py index 51aa1c5b8260..26848bea0a9e 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -660,13 +660,11 @@ def __call__( # Combine condition latents with target latent current_batch_size = len(latent_model_input_list) x_combined = [ - condition_latents_model_input[i] + [latent_model_input_list[i]] - for i in range(current_batch_size) + condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size) ] # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) image_noise_mask = [ - [0] * len(condition_latents_model_input[i]) + [1] - for i in range(current_batch_size) + [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size) ] model_out_list = self.transformer(