Commit 4c14cf3

Fix for auto_pipeline, Add Styling.
1 parent 9180579 commit 4c14cf3

4 files changed: +67 -32 lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -675,8 +675,8 @@
             "ZImageControlNetInpaintPipeline",
             "ZImageControlNetPipeline",
             "ZImageImg2ImgPipeline",
-            "ZImagePipeline",
             "ZImageOmniPipeline",
+            "ZImagePipeline",
         ]
     )
 
@@ -1387,8 +1387,8 @@
         ZImageControlNetInpaintPipeline,
         ZImageControlNetPipeline,
         ZImageImg2ImgPipeline,
-        ZImagePipeline,
         ZImageOmniPipeline,
+        ZImagePipeline,
     )
 
     try:

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 44 additions & 17 deletions
@@ -222,7 +222,7 @@ def forward(
         if self.modulation:
             if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
                 # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens
-                batch_size, seq_len = x.shape[0], x.shape[1]
+                _, seq_len = x.shape[0], x.shape[1]
 
                 mod_noisy = self.adaLN_modulation(adaln_noisy)
                 mod_clean = self.adaLN_modulation(adaln_clean)
@@ -260,7 +260,9 @@ def forward(
             else:
                 # Original global modulation
                 assert adaln_input is not None
-                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+                scale_msa, gate_msa, scale_mlp, gate_mlp = (
+                    self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+                )
                 gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
                 scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
 
@@ -297,7 +299,7 @@ def __init__(self, hidden_size, out_channels):
     def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
         if noise_mask is not None and c_noisy is not None and c_clean is not None:
             # Per-token modulation based on noise_mask
-            batch_size, seq_len = x.shape[0], x.shape[1]
+            _, seq_len = x.shape[0], x.shape[1]
             scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
             scale_clean = 1.0 + self.adaLN_modulation(c_clean)
 
@@ -916,8 +918,15 @@ def forward(
 
         if omni_mode:
             return self._forward_omni(
-                x, t, cap_feats, cond_latents, siglip_feats,
-                controlnet_block_samples, patch_size, f_patch_size, return_dict
+                x,
+                t,
+                cap_feats,
+                cond_latents,
+                siglip_feats,
+                controlnet_block_samples,
+                patch_size,
+                f_patch_size,
+                return_dict,
             )
         else:
             return self._forward_basic(
@@ -1130,14 +1139,23 @@ def _forward_omni(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.noise_refiner:
                 x = self._gradient_checkpointing_func(
-                    layer, x, x_attn_mask, x_freqs_cis,
-                    noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    layer,
+                    x,
+                    x_attn_mask,
+                    x_freqs_cis,
+                    noise_mask=x_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )
         else:
             for layer in self.noise_refiner:
                 x = layer(
-                    x, x_attn_mask, x_freqs_cis,
-                    noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    x,
+                    x_attn_mask,
+                    x_freqs_cis,
+                    noise_mask=x_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )
 
         # cap embed & refine (no modulation)
@@ -1208,9 +1226,7 @@ def _forward_omni(
                 x_len = x_item_seqlens[i]
                 cap_len = cap_item_seqlens[i]
                 siglip_len = siglip_item_seqlens[i]
-                unified.append(
-                    torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]])
-                )
+                unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]]))
                 unified_freqs_cis.append(
                     torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]])
                 )
@@ -1221,7 +1237,9 @@ def _forward_omni(
                         device=device,
                     )
                 )
-            unified_item_seqlens = [a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)]
+            unified_item_seqlens = [
+                a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)
+            ]
         else:
             for i in range(bsz):
                 x_len = x_item_seqlens[i]
@@ -1248,17 +1266,26 @@ def _forward_omni(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer_idx, layer in enumerate(self.layers):
                 unified = self._gradient_checkpointing_func(
-                    layer, unified, unified_attn_mask, unified_freqs_cis,
-                    noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    layer,
+                    unified,
+                    unified_attn_mask,
+                    unified_freqs_cis,
+                    noise_mask=unified_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )
                 if controlnet_block_samples is not None:
                     if layer_idx in controlnet_block_samples:
                         unified = unified + controlnet_block_samples[layer_idx]
         else:
             for layer_idx, layer in enumerate(self.layers):
                 unified = layer(
-                    unified, unified_attn_mask, unified_freqs_cis,
-                    noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    unified,
+                    unified_attn_mask,
+                    unified_freqs_cis,
+                    noise_mask=unified_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                )
                 if controlnet_block_samples is not None:
                     if layer_idx in controlnet_block_samples:
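Note: the reformatted blocks above all route through the same per-token modulation path, where each token picks between an AdaLN signal derived from the noisy-branch embedding and one from the clean-branch embedding, selected by noise_mask. A minimal sketch of that selection, using a hypothetical standalone helper with simplified shapes (the real blocks additionally split the modulation into scale/gate pairs for attention and MLP):

import torch

def per_token_adaln(x, mod_noisy, mod_clean, noise_mask):
    # x:          (batch, seq_len, dim) token features
    # mod_noisy:  (batch, dim) modulation from the noisy-timestep embedding
    # mod_clean:  (batch, dim) modulation from the clean-timestep embedding
    # noise_mask: (batch, seq_len), 1 for noisy tokens, 0 for clean tokens
    mask = noise_mask.unsqueeze(-1).to(x.dtype)  # (batch, seq_len, 1)
    mod = mask * mod_noisy.unsqueeze(1) + (1.0 - mask) * mod_clean.unsqueeze(1)
    return x * (1.0 + mod)  # each token is scaled by the signal its mask selects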

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 10 additions & 1 deletion
@@ -119,7 +119,13 @@
 )
 from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
 from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
-from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
+from .z_image import (
+    ZImageControlNetInpaintPipeline,
+    ZImageControlNetPipeline,
+    ZImageImg2ImgPipeline,
+    ZImageOmniPipeline,
+    ZImagePipeline,
+)
 
 
 AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -164,6 +170,9 @@
         ("qwenimage", QwenImagePipeline),
         ("qwenimage-controlnet", QwenImageControlNetPipeline),
         ("z-image", ZImagePipeline),
+        ("z-image-controlnet", ZImageControlNetPipeline),
+        ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
+        ("z-image-omni", ZImageOmniPipeline),
     ]
 )
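This is the auto_pipeline fix named in the commit message: with the new entries registered, the auto pipeline can resolve the Z-Image ControlNet and Omni variants instead of failing the mapping lookup. A hedged usage sketch, assuming a hypothetical checkpoint id (AutoPipelineForText2Image picks the concrete class by matching the checkpoint's model_index.json class name against this mapping):

from diffusers import AutoPipelineForText2Image

# Hypothetical repo id; any checkpoint whose model_index.json names
# ZImageOmniPipeline (or one of the ControlNet variants) now resolves
# through AUTO_TEXT2IMAGE_PIPELINES_MAPPING.
pipe = AutoPipelineForText2Image.from_pretrained("some-org/z-image-omni-checkpoint")
image = pipe("a watercolor fox at dawn").images[0]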

src/diffusers/pipelines/z_image/pipeline_z_image_omni.py

Lines changed: 11 additions & 12 deletions
@@ -16,22 +16,19 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import PIL
-
 import torch
 from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel
 
-from ...image_processor import VaeImageProcessor
 from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import ZImageTransformer2DModel
 from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
-from .pipeline_output import ZImagePipelineOutput
-
 from ..flux2.image_processor import Flux2ImageProcessor
+from .pipeline_output import ZImagePipelineOutput
 
-from ...models.transformers import ZImageTransformer2DModel
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -224,7 +221,6 @@ def _encode_prompt(
             prompt = [prompt]
 
         for i, prompt_item in enumerate(prompt):
-
             if num_condition_images == 0:
                 prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
             elif num_condition_images > 0:
@@ -236,7 +232,7 @@
 
         flattened_prompt = []
         prompt_list_lengths = []
-
+
         for i in range(len(prompt)):
             prompt_list_lengths.append(len(prompt[i]))
             flattened_prompt.extend(prompt[i])
@@ -304,14 +300,15 @@ def prepare_image_latents(
         image_latents = []
         for image in images:
            image = image.to(device=device, dtype=dtype)
-            image_latent = (self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+            image_latent = (
+                self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor
+            ) * self.vae.config.scaling_factor
             image_latent = image_latent.unsqueeze(1).to(dtype)
             image_latents.append(image_latent)  # (16, 128, 128)
 
         # image_latents = [image_latents] * batch_size
         image_latents = [image_latents.copy() for _ in range(batch_size)]
 
-
         return image_latents
 
     def prepare_siglip_embeds(
@@ -327,7 +324,7 @@ def prepare_siglip_embeds(
             shape = siglip_inputs.spatial_shapes[0]
             hidden_state = self.siglip(**siglip_inputs).last_hidden_state
             B, N, C = hidden_state.shape
-            hidden_state = hidden_state[:, :shape[0] * shape[1]]
+            hidden_state = hidden_state[:, : shape[0] * shape[1]]
             hidden_state = hidden_state.view(shape[0], shape[1], C)
             siglip_embeds.append(hidden_state.to(dtype))
 
@@ -529,7 +526,7 @@ def __call__(
                 image_height = (image_height // multiple_of) * multiple_of
                 img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
                 condition_images.append(img)
-
+
         if len(condition_images) > 0:
             height = height or image_height
             width = width or image_width
@@ -591,7 +588,9 @@ def __call__(
             negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
 
         condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
-        negative_condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds]
+        negative_condition_siglip_embeds = [
+            None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
+        ]
 
         actual_batch_size = batch_size * num_images_per_prompt
         image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)