diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6aac3feffd0e..aa11a741af38 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -675,6 +675,7 @@ "ZImageControlNetInpaintPipeline", "ZImageControlNetPipeline", "ZImageImg2ImgPipeline", + "ZImageOmniPipeline", "ZImagePipeline", ] ) @@ -1386,6 +1387,7 @@ ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, + ZImageOmniPipeline, ZImagePipeline, ) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 17197db3a441..883b16e27525 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -152,6 +152,20 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso return output +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int): super().__init__() @@ -215,12 +229,37 @@ def forward( attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( @@ 
-252,9 +291,21 @@ def __init__(self, hidden_size, out_channels): nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) - def forward(self, x, c): - scale = 1.0 + self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale x = self.linear(x) return x @@ -325,6 +376,7 @@ def __init__( norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560, + siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni rope_theta=256.0, t_scale=1000.0, axes_dims=[32, 48, 48], @@ -386,6 +438,31 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + # Optional SigLIP components (for Omni variant) + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -402,22 +479,57 @@ def __init__( self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size, + f_patch_size, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: pH = pW = patch_size pF = f_patch_size bsz = len(x) assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x + + if x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last 
(target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x @staticmethod def create_coordinate_grid(size, start=None, device=None): @@ -531,19 +643,297 @@ def patchify_and_embed( all_cap_pad_mask, ) + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int, + f_patch_size: int, + images_noise_mask: List[List[int]], + ): + bsz = len(all_x) + pH = pW = patch_size + pF = f_patch_size + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_padded = [] + all_x_size = [] + all_x_pos_ids = [] + all_x_pad_mask = [] + all_x_len = [] + all_x_noise_mask = [] + all_cap_padded_feats = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_len = [] + all_cap_noise_mask = [] + all_siglip_padded_feats = [] + all_siglip_pos_ids = [] + all_siglip_pad_mask = [] + all_siglip_len = [] + all_siglip_noise_mask = [] + + for i in range(bsz): + # Process captions + num_images = len(all_x[i]) + cap_padded_feats = [] + cap_item_cu_len = 1 + cap_start_pos = [] + cap_end_pos = [] + cap_padded_pos_ids = [] + cap_pad_mask = [] + cap_len = [] + cap_noise_mask = [] + + for j, cap_item in enumerate(all_cap_feats[i]): + cap_item_ori_len = len(cap_item) + cap_item_padding_len = (-cap_item_ori_len) % SEQ_MULTI_OF + cap_len.append(cap_item_ori_len + cap_item_padding_len) + + cap_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(cap_item_padding_len, 1) + ) + cap_start_pos.append(cap_item_cu_len) + cap_item_ori_pos_ids = self.create_coordinate_grid( + size=(cap_item_ori_len, 1, 1), start=(cap_item_cu_len, 0, 0), device=device + ).flatten(0, 2) + cap_padded_pos_ids.append(cap_item_ori_pos_ids) + cap_padded_pos_ids.append(cap_item_padding_pos_ids) + cap_item_cu_len += cap_item_ori_len + cap_end_pos.append(cap_item_cu_len) + cap_item_cu_len += 2 # for image vae tokens and siglip tokens + + cap_pad_mask.append(torch.zeros((cap_item_ori_len,), dtype=torch.bool, device=device)) + cap_pad_mask.append(torch.ones((cap_item_padding_len,), dtype=torch.bool, device=device)) + cap_item_padded_feat = torch.cat([cap_item, cap_item[-1:].repeat(cap_item_padding_len, 1)], dim=0) + cap_padded_feats.append(cap_item_padded_feat) + + if j < len(images_noise_mask[i]): + cap_noise_mask.extend([images_noise_mask[i][j]] * (cap_item_ori_len + cap_item_padding_len)) + else: + cap_noise_mask.extend([1] * (cap_item_ori_len + cap_item_padding_len)) + + all_cap_noise_mask.append(cap_noise_mask) + cap_padded_pos_ids = torch.cat(cap_padded_pos_ids, dim=0) + all_cap_pos_ids.append(cap_padded_pos_ids) + cap_pad_mask = torch.cat(cap_pad_mask, dim=0) + all_cap_pad_mask.append(cap_pad_mask) + all_cap_padded_feats.append(torch.cat(cap_padded_feats, dim=0)) + all_cap_len.append(cap_len) + + # Process images (x) + x_padded = [] + x_padded_pos_ids = [] + x_pad_mask = [] + x_len = [] + x_size = [] + x_noise_mask = [] + + for j, x_item in enumerate(all_x[i]): + if x_item is not None: + C, F, H, W = x_item.size() + x_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + x_item = x_item.view(C, 
F_tokens, pF, H_tokens, pH, W_tokens, pW) + x_item = x_item.permute(1, 3, 5, 2, 4, 6, 0).reshape( + F_tokens * H_tokens * W_tokens, pF * pH * pW * C + ) + + x_item_ori_len = len(x_item) + x_item_padding_len = (-x_item_ori_len) % SEQ_MULTI_OF + x_len.append(x_item_ori_len + x_item_padding_len) + + x_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_item_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), start=(cap_end_pos[j], 0, 0), device=device + ).flatten(0, 2) + x_padded_pos_ids.append(x_item_ori_pos_ids) + x_padded_pos_ids.append(x_item_padding_pos_ids) + + x_pad_mask.append(torch.zeros((x_item_ori_len,), dtype=torch.bool, device=device)) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_item_padded_feat = torch.cat([x_item, x_item[-1:].repeat(x_item_padding_len, 1)], dim=0) + x_padded.append(x_item_padded_feat) + x_noise_mask.extend([images_noise_mask[i][j]] * (x_item_ori_len + x_item_padding_len)) + else: + x_pad_dim = 64 + x_item_padding_len = SEQ_MULTI_OF + x_size.append(None) + x_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_len.append(x_item_padding_len) + x_padded_pos_ids.append(x_item_padding_pos_ids) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_padded.append(torch.zeros((x_item_padding_len, x_pad_dim), dtype=dtype, device=device)) + x_noise_mask.extend([images_noise_mask[i][j]] * x_item_padding_len) + + all_x_noise_mask.append(x_noise_mask) + all_x_size.append(x_size) + x_padded_pos_ids = torch.cat(x_padded_pos_ids, dim=0) + all_x_pos_ids.append(x_padded_pos_ids) + x_pad_mask = torch.cat(x_pad_mask, dim=0) + all_x_pad_mask.append(x_pad_mask) + all_x_padded.append(torch.cat(x_padded, dim=0)) + all_x_len.append(x_len) + + # Process siglip_feats + if all_siglip_feats[i] is None: + all_siglip_len.append([0 for _ in range(num_images)]) + all_siglip_padded_feats.append(None) + else: + sig_padded_feats = [] + sig_padded_pos_ids = [] + sig_pad_mask = [] + sig_len = [] + sig_noise_mask = [] + + for j, sig_item in enumerate(all_siglip_feats[i]): + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_H_tokens, sig_W_tokens, sig_F_tokens = sig_H, sig_W, 1 + + sig_item = sig_item.view(sig_C, sig_F_tokens, 1, sig_H_tokens, 1, sig_W_tokens, 1) + sig_item = sig_item.permute(1, 3, 5, 2, 4, 6, 0).reshape( + sig_F_tokens * sig_H_tokens * sig_W_tokens, sig_C + ) + + sig_item_ori_len = len(sig_item) + sig_item_padding_len = (-sig_item_ori_len) % SEQ_MULTI_OF + sig_len.append(sig_item_ori_len + sig_item_padding_len) + + sig_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_item_ori_pos_ids = self.create_coordinate_grid( + size=(sig_F_tokens, sig_H_tokens, sig_W_tokens), + start=(cap_end_pos[j] + 1, 0, 0), + device=device, + ) + # Scale position IDs to match x resolution + sig_item_ori_pos_ids[..., 1] = ( + sig_item_ori_pos_ids[..., 1] / (sig_H_tokens - 1) * (x_size[j][1] - 1) + ) + sig_item_ori_pos_ids[..., 2] = ( + sig_item_ori_pos_ids[..., 2] / (sig_W_tokens - 1) * (x_size[j][2] - 1) + ) + sig_item_ori_pos_ids = sig_item_ori_pos_ids.flatten(0, 2) + sig_padded_pos_ids.append(sig_item_ori_pos_ids) + 
sig_padded_pos_ids.append(sig_item_padding_pos_ids) + + sig_pad_mask.append(torch.zeros((sig_item_ori_len,), dtype=torch.bool, device=device)) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_item_padded_feat = torch.cat( + [sig_item, sig_item[-1:].repeat(sig_item_padding_len, 1)], dim=0 + ) + sig_padded_feats.append(sig_item_padded_feat) + sig_noise_mask.extend([images_noise_mask[i][j]] * (sig_item_ori_len + sig_item_padding_len)) + else: + sig_pad_dim = self.config.siglip_feat_dim or 1152 + sig_item_padding_len = SEQ_MULTI_OF + sig_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_padded_feats.append( + torch.zeros((sig_item_padding_len, sig_pad_dim), dtype=dtype, device=device) + ) + sig_noise_mask.extend([images_noise_mask[i][j]] * sig_item_padding_len) + + all_siglip_noise_mask.append(sig_noise_mask) + sig_padded_pos_ids = torch.cat(sig_padded_pos_ids, dim=0) + all_siglip_pos_ids.append(sig_padded_pos_ids) + sig_pad_mask = torch.cat(sig_pad_mask, dim=0) + all_siglip_pad_mask.append(sig_pad_mask) + all_siglip_padded_feats.append(torch.cat(sig_padded_feats, dim=0)) + all_siglip_len.append(sig_len) + + # Compute x position offsets + all_x_pos_offsets = [] + for i in range(bsz): + start = sum(all_cap_len[i]) + end = start + sum(all_x_len[i]) + all_x_pos_offsets.append((start, end)) + + return ( + all_x_padded, + all_cap_padded_feats, + all_siglip_padded_feats, + all_x_size, + all_x_pos_ids, + all_cap_pos_ids, + all_siglip_pos_ids, + all_x_pad_mask, + all_cap_pad_mask, + all_siglip_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_siglip_noise_mask, + ) + def forward( self, - x: List[torch.Tensor], + x: Union[List[torch.Tensor], List[List[torch.Tensor]]], t, - cap_feats: List[torch.Tensor], + cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]], controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, - patch_size=2, - f_patch_size=1, + siglip_feats: Optional[List[List[torch.Tensor]]] = None, + image_noise_mask: Optional[List[List[int]]] = None, + patch_size: int = 2, + f_patch_size: int = 1, return_dict: bool = True, ): assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size + # Omni mode: x contains lists (multi-image input) + omni_mode = isinstance(x[0], list) + + if omni_mode: + return self._forward_omni( + x, + t, + cap_feats, + siglip_feats, + image_noise_mask, + controlnet_block_samples, + patch_size, + f_patch_size, + return_dict, + ) + else: + return self._forward_basic( + x, t, cap_feats, controlnet_block_samples, patch_size, f_patch_size, return_dict + ) + + def _forward_basic( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + controlnet_block_samples: Optional[Dict[int, torch.Tensor]], + patch_size: int, + f_patch_size: int, + return_dict: bool, + ): + """Original text-to-image forward pass.""" bsz = len(x) device = x[0].device t = t * self.t_scale @@ -658,3 +1048,239 @@ def forward( return (x,) return Transformer2DModelOutput(sample=x) + + def _forward_omni( + self, + x: List[List[torch.Tensor]], + t, + cap_feats: List[List[torch.Tensor]], + siglip_feats: List[List[torch.Tensor]], + image_noise_mask: List[List[int]], + controlnet_block_samples: Optional[Dict[int, 
torch.Tensor]], + patch_size: int, + f_patch_size: int, + return_dict: bool, + ): + """Omni mode forward pass with image conditioning.""" + bsz = len(x) + device = x[0][-1].device # From target latent + + # Create dual timestep embeddings: one for noisy tokens (t), one for clean tokens (t=1) + t_noisy = self.t_embedder(t * self.t_scale) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale) + + # Patchify and embed for Omni mode + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + siglip_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + siglip_inner_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + # Create x_noise_mask tensor + x_noise_mask_tensor = [] + for i in range(bsz): + x_mask = torch.tensor(x_noise_mask[i], dtype=torch.long, device=device) + x_noise_mask_tensor.append(x_mask) + x_noise_mask_tensor = pad_sequence(x_noise_mask_tensor, batch_first=True, padding_value=0) + x_noise_mask_tensor = x_noise_mask_tensor[:, : x.shape[1]] + + # Match t_embedder output dtype to x + t_noisy_x = t_noisy.type_as(x) + t_clean_x = t_clean.type_as(x) + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func( + layer, + x, + x_attn_mask, + x_freqs_cis, + noise_mask=x_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, + ) + else: + for layer in self.noise_refiner: + x = layer( + x, + x_attn_mask, + x_freqs_cis, + noise_mask=x_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, + ) + + # cap embed & refine (no modulation) + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, 
cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # siglip embed & refine (if available) + siglip_item_seqlens = None + if siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_item_seqlens = [len(_) for _ in siglip_feats] + siglip_max_item_seqlen = max(siglip_item_seqlens) + + siglip_feats = torch.cat(siglip_feats, dim=0) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_feats[torch.cat(siglip_inner_pad_mask)] = self.siglip_pad_token + siglip_feats = list(siglip_feats.split(siglip_item_seqlens, dim=0)) + siglip_freqs_cis = list( + self.rope_embedder(torch.cat(siglip_pos_ids, dim=0)).split([len(_) for _ in siglip_pos_ids], dim=0) + ) + + siglip_feats = pad_sequence(siglip_feats, batch_first=True, padding_value=0.0) + siglip_freqs_cis = pad_sequence(siglip_freqs_cis, batch_first=True, padding_value=0.0) + siglip_freqs_cis = siglip_freqs_cis[:, : siglip_feats.shape[1]] + + siglip_attn_mask = torch.zeros((bsz, siglip_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(siglip_item_seqlens): + siglip_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.siglip_refiner: + siglip_feats = self._gradient_checkpointing_func( + layer, siglip_feats, siglip_attn_mask, siglip_freqs_cis + ) + else: + for layer in self.siglip_refiner: + siglip_feats = layer(siglip_feats, siglip_attn_mask, siglip_freqs_cis) + + # Build unified sequence + unified = [] + unified_freqs_cis = [] + unified_noise_mask = [] + + if siglip_item_seqlens is not None: + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + siglip_len = siglip_item_seqlens[i] + unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]])) + unified_freqs_cis.append( + torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], + dtype=torch.long, + device=device, + ) + ) + unified_item_seqlens = [ + a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens) + ] + else: + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len]])) + unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + # Create unified_noise_mask tensor + unified_noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0) + unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer_idx, layer in enumerate(self.layers): + unified = self._gradient_checkpointing_func( + layer, + unified, + unified_attn_mask, + 
unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, + ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + else: + for layer_idx, layer in enumerate(self.layers): + unified = layer( + unified, + unified_attn_mask, + unified_freqs_cis, + noise_mask=unified_noise_mask_tensor, + adaln_noisy=t_noisy_x, + adaln_clean=t_clean_x, + ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_mask_tensor, c_noisy=t_noisy_x, c_clean=t_clean_x + ) + + x = self.unpatchify(unified, x_size, patch_size, f_patch_size, x_pos_offsets) + + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e8faf868e741..28925f2d5352 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -411,6 +411,7 @@ "ZImagePipeline", "ZImageControlNetPipeline", "ZImageControlNetInpaintPipeline", + "ZImageOmniPipeline", ] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index db0268a2a73d..7a2e58c3196e 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -119,7 +119,13 @@ ) from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline -from .z_image import ZImageImg2ImgPipeline, ZImagePipeline +from .z_image import ( + ZImageControlNetInpaintPipeline, + ZImageControlNetPipeline, + ZImageImg2ImgPipeline, + ZImageOmniPipeline, + ZImagePipeline, +) AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -164,6 +170,9 @@ ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), ("z-image", ZImagePipeline), + ("z-image-controlnet", ZImageControlNetPipeline), + ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline), + ("z-image-omni", ZImageOmniPipeline), ] ) diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index 7b3cfbceea2c..78bd3bfacbec 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] _import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"] _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] + _import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -41,7 +42,7 @@ from .pipeline_z_image_controlnet import ZImageControlNetPipeline from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline from .pipeline_z_image_img2img import ZImageImg2ImgPipeline - + from .pipeline_z_image_omni import ZImageOmniPipeline else: import sys diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py new file mode 100644 index 000000000000..26848bea0a9e --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -0,0 +1,742 @@ +# Copyright 
2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel + +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..flux2.image_processor import Flux2ImageProcessor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageOmniPipeline + + >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + siglip: Siglip2VisionModel, + siglip_processor: Siglip2ImageProcessorFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + 
prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = self.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = ( + self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + 
image_latents.append(image_latent) # (16, 128, 128) + + # image_latents = [image_latents] * batch_size + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + def prepare_siglip_embeds( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = self.siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. 
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if image is not None and not isinstance(image, list): + image = [image] + num_condition_images = len(image) if image is not None else 0 + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2 + condition_images = [] + resized_images = [] + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + if height is not None and width is not None: + img = self.image_processor._resize_to_target_area(img, height * width) + else: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + if len(condition_images) > 0: + height = height or image_height + width = width or image_width + + else: + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # 4. 
Prepare latent variables
+        num_channels_latents = self.transformer.in_channels
+
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            torch.float32,
+            device,
+            generator,
+            latents,
+        )
+
+        condition_latents = self.prepare_image_latents(
+            images=condition_images,
+            batch_size=batch_size * num_images_per_prompt,
+            device=device,
+            dtype=torch.float32,
+        )
+        condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents]
+        if self.do_classifier_free_guidance:
+            negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents]
+
+        condition_siglip_embeds = self.prepare_siglip_embeds(
+            images=resized_images,
+            batch_size=batch_size * num_images_per_prompt,
+            device=device,
+            dtype=torch.float32,
+        )
+        condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds]
+        if self.do_classifier_free_guidance:
+            negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds]
+
+        # Repeat prompt_embeds for num_images_per_prompt
+        if num_images_per_prompt > 1:
+            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+            if self.do_classifier_free_guidance and negative_prompt_embeds:
+                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+        condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
+        # The negative SigLIP embeds only exist when classifier-free guidance is enabled, so guard the
+        # same post-processing to avoid a NameError when guidance_scale <= 1.
+        if self.do_classifier_free_guidance:
+            negative_condition_siglip_embeds = [
+                None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
+            ]
+
+        actual_batch_size = batch_size * num_images_per_prompt
+        image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
+
+        # 5. Prepare timesteps
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        self.scheduler.sigma_min = 0.0
+        scheduler_kwargs = {"mu": mu}
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            **scheduler_kwargs,
+        )
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + condition_latents_model_input = condition_latents + negative_condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + condition_latents_model_input = condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + # Combine condition latents with target latent + current_batch_size = len(latent_model_input_list) + x_combined = [ + condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size) + ] + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size) + ] + + model_out_list = self.transformer( + x=x_combined, + t=timestep_model_input, + cap_feats=prompt_embeds_model_input, + siglip_feats=condition_siglip_embeds_model_input, + image_noise_mask=image_noise_mask, + return_dict=False, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = 
callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image)
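
A toy illustration of the per-token modulation selection that the new `select_per_token` helper implements: two per-batch modulation vectors (one computed from the sampled timestep, one from t = 1 for clean tokens) are broadcast across the sequence, and each token picks one of them according to `noise_mask`. The shapes and values below are made up for illustration only.

```python
import torch

batch, seq_len, dim = 1, 4, 3
value_noisy = torch.full((batch, dim), 2.0)  # modulation derived from the sampled timestep
value_clean = torch.full((batch, dim), 7.0)  # modulation derived from t = 1 (clean condition tokens)
noise_mask = torch.tensor([[1, 1, 0, 0]])  # 1 = noisy (target) token, 0 = clean (condition) token

# Same selection logic as select_per_token in transformer_z_image.py
selected = torch.where(
    noise_mask.unsqueeze(-1) == 1,
    value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
    value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
print(selected[0, :, 0])  # tensor([2., 2., 7., 7.])
```

A minimal usage sketch for the new `ZImageOmniPipeline`, assuming a checkpoint that ships the extra `siglip` and `siglip_processor` components; the repository id and input file name are placeholders, not values defined by this diff. Condition images enter the unified sequence as clean tokens (noise mask 0), while the target latent is the only noisy stream (noise mask 1).

```python
import torch
from diffusers import ZImageOmniPipeline
from diffusers.utils import load_image

# Placeholder repo id; use a checkpoint that includes the siglip and siglip_processor components.
pipe = ZImageOmniPipeline.from_pretrained("<z-image-omni-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

condition = load_image("reference.png")  # placeholder condition image

image = pipe(
    image=[condition],  # list of condition images (encoded by the VAE and SigLIP)
    prompt="Turn the reference photo into a watercolor painting.",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("zimage_omni.png")
```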