Commit 9180579

Fix controlnet bugs after merging the new feature from the main branch.
1 parent 3e60fa7

File tree

2 files changed: +21 -10 lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 15 additions & 4 deletions
@@ -916,16 +916,20 @@ def forward(
 
         if omni_mode:
             return self._forward_omni(
-                x, t, cap_feats, cond_latents, siglip_feats, patch_size, f_patch_size, return_dict
+                x, t, cap_feats, cond_latents, siglip_feats,
+                controlnet_block_samples, patch_size, f_patch_size, return_dict
             )
         else:
-            return self._forward_basic(x, t, cap_feats, patch_size, f_patch_size, return_dict)
+            return self._forward_basic(
+                x, t, cap_feats, controlnet_block_samples, patch_size, f_patch_size, return_dict
+            )
 
     def _forward_basic(
         self,
         x: List[torch.Tensor],
         t,
         cap_feats: List[torch.Tensor],
+        controlnet_block_samples: Optional[Dict[int, torch.Tensor]],
         patch_size: int,
         f_patch_size: int,
         return_dict: bool,
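Note: _forward_basic (and, below, _forward_omni) now accept controlnet_block_samples typed Optional[Dict[int, torch.Tensor]]: keys are indices into self.layers, values are residual tensors added after the corresponding layer. A hedged sketch of how a caller might build that mapping from a list of per-block controlnet outputs; the helper name and the interval-based spreading are assumptions, not shown in this commit:

```python
import torch
from typing import Dict, List, Optional

def to_block_sample_dict(
    block_samples: Optional[List[torch.Tensor]],
    interval: int = 1,
) -> Optional[Dict[int, torch.Tensor]]:
    # Hypothetical helper: spread N controlnet outputs over transformer layers
    # 0, interval, 2*interval, ... so the dict keys line up with self.layers.
    if block_samples is None:
        return None
    return {i * interval: sample for i, sample in enumerate(block_samples)}

samples = [torch.randn(1, 8, 16) for _ in range(3)]
print(to_block_sample_dict(samples, interval=2).keys())  # dict_keys([0, 2, 4])
```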
@@ -1053,6 +1057,7 @@ def _forward_omni(
         cap_feats: List[List[torch.Tensor]],
         cond_latents: List[List[torch.Tensor]],
         siglip_feats: List[List[torch.Tensor]],
+        controlnet_block_samples: Optional[Dict[int, torch.Tensor]],
         patch_size: int,
         f_patch_size: int,
         return_dict: bool,
@@ -1241,17 +1246,23 @@ def _forward_omni(
         unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]]
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for layer in self.layers:
+            for layer_idx, layer in enumerate(self.layers):
                 unified = self._gradient_checkpointing_func(
                     layer, unified, unified_attn_mask, unified_freqs_cis,
                     noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
                 )
+                if controlnet_block_samples is not None:
+                    if layer_idx in controlnet_block_samples:
+                        unified = unified + controlnet_block_samples[layer_idx]
         else:
-            for layer in self.layers:
+            for layer_idx, layer in enumerate(self.layers):
                 unified = layer(
                     unified, unified_attn_mask, unified_freqs_cis,
                     noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
                 )
+                if controlnet_block_samples is not None:
+                    if layer_idx in controlnet_block_samples:
+                        unified = unified + controlnet_block_samples[layer_idx]
 
         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
             unified, noise_mask=unified_noise_mask_tensor, c_noisy=t_noisy_x, c_clean=t_clean_x
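Note: the injection added in both branches is a per-layer residual add: after layer i runs, if controlnet_block_samples contains key i, its tensor is added to the unified hidden states. A minimal standalone sketch of that pattern; the helper name, toy layers, and shapes are illustrative, not part of this commit:

```python
import torch
from typing import Dict, Optional

def run_layers_with_residuals(
    unified: torch.Tensor,
    layers,  # iterable of callables acting on the hidden states
    controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
) -> torch.Tensor:
    """Mirror of the patched loop: run each layer, then add the matching
    controlnet residual when one is provided for that layer index."""
    for layer_idx, layer in enumerate(layers):
        unified = layer(unified)
        if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
            # The residual must broadcast against the hidden states,
            # e.g. shape (batch, seq_len, hidden_dim).
            unified = unified + controlnet_block_samples[layer_idx]
    return unified

# Toy usage: four identity "layers", residuals injected after layers 0 and 2.
hidden = torch.randn(1, 8, 16)
layers = [torch.nn.Identity() for _ in range(4)]
residuals = {0: torch.zeros(1, 8, 16), 2: torch.zeros(1, 8, 16)}
out = run_layers_with_residuals(hidden, layers, residuals)
```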

src/diffusers/pipelines/z_image/pipeline_z_image_omni.py

Lines changed: 6 additions & 6 deletions
@@ -659,12 +659,12 @@ def __call__(
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))
 
                 model_out_list = self.transformer(
-                    latent_model_input_list,
-                    timestep_model_input,
-                    prompt_embeds_model_input,
-                    condition_latents_model_input,
-                    condition_siglip_embeds_model_input,
-                    return_dict=False
+                    x=latent_model_input_list,
+                    t=timestep_model_input,
+                    cap_feats=prompt_embeds_model_input,
+                    cond_latents=condition_latents_model_input,
+                    siglip_feats=condition_siglip_embeds_model_input,
+                    return_dict=False,
                 )[0]
 
                 if apply_cfg:
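Note: switching the transformer call from positional to keyword arguments keeps the pipeline robust to the signature growth above (the newly threaded controlnet_block_samples parameter). A toy illustration of the failure mode keyword binding avoids; the stand-in function and its parameter order are assumptions for illustration, not the real transformer API:

```python
# Illustrative only: a stand-in for a forward() whose parameter list gained a
# new argument in the middle, the way an updated model signature might.
def new_forward(x, t, cap_feats, controlnet_block_samples=None,
                cond_latents=None, siglip_feats=None, return_dict=False):
    return {"siglip_feats": siglip_feats, "return_dict": return_dict}

# A positional call written against the old ordering binds cond_latents into
# controlnet_block_samples and siglip_feats into cond_latents:
broken = new_forward("x", "t", "caps", "cond", "siglip", return_dict=False)
assert broken["siglip_feats"] is None  # silently wrong

# Keyword binding, as the pipeline now uses, is immune to such reordering:
ok = new_forward(x="x", t="t", cap_feats="caps",
                 cond_latents="cond", siglip_feats="siglip", return_dict=False)
assert ok["siglip_feats"] == "siglip"
```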
