
Conversation

@RuoyiDu (Contributor) commented Dec 18, 2025

What does this PR do?

This PR adds support for the Z-Image-Omni-Base model. Z-Image-Omni-Base is a foundation model designed for easy fine-tuning, unifying core capabilities in both image generation and editing to empower the community to explore custom development and innovative applications.
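
A rough usage sketch for context; the checkpoint id and call arguments below are placeholders for illustration, not taken from this PR:

import torch
from diffusers import ZImageOmniPipeline

# The model id here is an assumption for illustration only.
pipe = ZImageOmniPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Omni-Base", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    prompt="A watercolor painting of a lighthouse at dawn",
    num_inference_steps=28,
).images[0]
image.save("lighthouse.png")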

Before submitting

Who can review?

@yiyixuxu @apolinario @JerryWu-code

@yiyixuxu (Collaborator) left a comment

thanks a lot for the PR! I left some comments, mainly I'm just trying to simplify the code in the transformer as much as possible by removing unused code paths, etc.
let me know what you think :)

SEQ_MULTI_OF = 32


class TimestepEmbedder(nn.Module):
Collaborator

can we add a # Copied from?
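
i.e. a comment right above the class, something like this (the source path here is just an example, not necessarily the right one for this block):

import torch.nn as nn

# Copied from diffusers.models.embeddings.TimestepEmbedding
class TimestepEmbedder(nn.Module):
    ...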

Contributor

Fixed by merging everything into a single transformer_z_image.

return t_emb


class ZSingleStreamAttnProcessor:
Collaborator

can we add a # Copied from?

Contributor

Same as before.



@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
Collaborator

Suggested change
-class ZImageTransformerBlock(nn.Module):
+class ZOmniImageTransformerBlock(nn.Module):

Contributor

Not applicable after merging into a single transformer.

adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
Collaborator

Suggested change
-if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:

Contributor

This is still needed with the current codebase (4c14cf3), but it could be optimized with a redesign in the next PR.
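
Roughly, this branch lets each token pick its modulation from either the noisy or the clean conditioning; a minimal sketch (shapes and the exact signature are assumptions, not the actual implementation):

import torch

def select_per_token(noise_mask, adaln_noisy, adaln_clean):
    # noise_mask: (batch, seq_len) bool, True where a token belongs to the noisy stream.
    # adaln_noisy / adaln_clean: (batch, dim) modulation vectors for the two streams.
    mask = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
    return torch.where(mask, adaln_noisy.unsqueeze(1), adaln_clean.unsqueeze(1))

noise_mask = torch.tensor([[True, True, False, False]])
adaln = select_per_token(noise_mask, torch.randn(1, 8), torch.randn(1, 8))  # (1, 4, 8)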

Comment on lines 261 to 266
else:
# Original global modulation
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
Collaborator

Suggested change
-else:
-    # Original global modulation
-    assert adaln_input is not None
-    scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
-    gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
-    scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

can we remove this code path if it is not used?

Contributor

It is still needed after merging everything into a single transformer.
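
For context, the global branch above is the usual adaLN pattern where one conditioning vector modulates every token identically; a small sketch (how the scales/gates are applied afterwards is assumed from the common DiT-style block, not copied from this file):

import torch
import torch.nn as nn

batch, seq_len, dim = 2, 16, 32
x = torch.randn(batch, seq_len, dim)
adaln_input = torch.randn(batch, dim)

adaLN_modulation = nn.Linear(dim, 4 * dim)          # stand-in for self.adaLN_modulation
norm = nn.LayerNorm(dim, elementwise_affine=False)

scale_msa, gate_msa, scale_mlp, gate_mlp = (
    adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

# Typical application (illustrative): scale the normed tokens, gate the residual.
attn_out = norm(x) * scale_msa                      # placeholder for the attention sub-block
x = x + gate_msa * attn_out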

Comment on lines 794 to 795
patch_size=2,
f_patch_size=1,
Collaborator

Suggested change
-patch_size=2,
-f_patch_size=1,

I don't think these two arguments are used in the pipeline, can we remove them? I think it could simplify the code a lot, and it would also let us remove the ModuleDict pattern.
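
For illustration, the kind of simplification this enables (attribute names and the dict-key format are assumptions, not quotes from the current code):

import torch.nn as nn

in_dim, hidden = 64, 256

# Before: one embedder per (patch_size, f_patch_size) combination.
x_embedder = nn.ModuleDict({"2-1": nn.Linear(in_dim, hidden)})
# forward then needs something like: x = x_embedder[f"{patch_size}-{f_patch_size}"](x)

# After fixing patch_size=2, f_patch_size=1 at a single value: one plain embedder.
x_embedder = nn.Linear(in_dim, hidden)
# forward simplifies to: x = x_embedder(x)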


Comment on lines 798 to 799
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
Collaborator

Suggested change
-assert patch_size in self.all_patch_size
-assert f_patch_size in self.all_f_patch_size


cap_noise_mask,
siglip_noise_mask
) = self.patchify_and_embed(
x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask
Collaborator

Suggested change
-x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask
+x, cap_feats, siglip_feats, image_noise_mask

grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)

def patchify_and_embed(
Collaborator

this method is really hard to follow here, do you think it's possible to break it into 3?

like

for x, cap_feat, siglip_feat in zip(all_x, all_cap_feats, all_siglip_feats):
    cap_item_cu_len = 1
    
    cap_padded, ..., cap_item_cu_len = self.patchify_and_embed_cap(...)
    all_cap_padded.append(cap_padded)
    
    x_padded, ..., cap_item_cu_len = self.patchify_and_embed_x(..., cap_item_cu_len)
    all_x_padded.append(x_padded)
    ...
    
    siglip_padded, ..., cap_item_cu_len = self.patchify_and_embed_siglip(..., cap_item_cu_len)
    all_siglip_padded.append(siglip_padded)
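
e.g. one of those helpers could look roughly like this; padding up to SEQ_MULTI_OF is the only behaviour taken from the surrounding code, the signature and returns are guesses:

import torch
import torch.nn.functional as F

SEQ_MULTI_OF = 32

def patchify_and_embed_cap(cap_feat: torch.Tensor, cap_item_cu_len: int):
    # Pad one caption's tokens to a multiple of SEQ_MULTI_OF and advance the
    # running cumulative length; everything here is illustrative only.
    pad = (-cap_feat.shape[0]) % SEQ_MULTI_OF
    cap_padded = F.pad(cap_feat, (0, 0, 0, pad))
    cap_item_cu_len += cap_padded.shape[0]
    return cap_padded, cap_item_cu_len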

assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)

x = torch.cat(x, dim=0)
Collaborator

hopefully we can simplify to x = self.x_embedder(x) here


@Beinsezii (Contributor)

this gets forgotten all the time lol

diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index db0268a2a..2c36ce36b 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -119,7 +119,7 @@ from .stable_diffusion_xl import (
 )
 from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
 from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
-from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
+from .z_image import ZImageImg2ImgPipeline, ZImageOmniPipeline, ZImagePipeline
 
 
 AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -164,6 +164,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("qwenimage", QwenImagePipeline),
         ("qwenimage-controlnet", QwenImageControlNetPipeline),
         ("z-image", ZImagePipeline),
+        ("z-image-omni", ZImageOmniPipeline),
     ]
 )
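
Registering the mapping is what lets the auto classes resolve the new pipeline, roughly (the model id below is a placeholder):

import torch
from diffusers import AutoPipelineForText2Image

# With the "z-image-omni" entry above, a Z-Image-Omni checkpoint resolves to ZImageOmniPipeline.
pipe = AutoPipelineForText2Image.from_pretrained(
    "Tongyi-MAI/Z-Image-Omni-Base", torch_dtype=torch.bfloat16
)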
 

@JerryWu-code (Contributor)

> thanks a lot for the PR! I left some comments, mainly I'm just trying to simplify the code in the transformer as much as possible by removing unused code paths, etc.

Thanks for the useful comments, yiyi. I'll review these and fix them today ~ 😊

@JerryWu-code (Contributor) commented Dec 19, 2025

Hi @yiyixuxu, this branch is ready to merge 😊. It addresses most of your earlier concerns (including the Copied from comments, cond_latents, auto_pipeline, and styling) by merging everything into one transformer model and incorporating the new features from the main branch on top of the starting point. More feature updates and code cleanup will come in another PR; you can review the current status and leave comments, and I'll update further asap ~ Thanks!!!

@JerryWu-code (Contributor)

> this gets forgotten all the time lol
>
> (auto_pipeline.py diff quoted from the comment above)

Thanks!! Fixed in 4c14cf3 ~

- Add select_per_token function for per-token value selection
- Separate adaptive modulation logic
- Clean up t_noisy/clean variable naming
- Move image_noise_mask handler from forward to pipeline
@JerryWu-code (Contributor) commented Dec 19, 2025

Ready, let's merge it at 732c527 ~ 😊
