@@ -916,16 +916,20 @@ def forward(
 
         if omni_mode:
             return self._forward_omni(
-                x, t, cap_feats, cond_latents, siglip_feats, patch_size, f_patch_size, return_dict
+                x, t, cap_feats, cond_latents, siglip_feats,
+                controlnet_block_samples, patch_size, f_patch_size, return_dict
             )
         else:
-            return self._forward_basic(x, t, cap_feats, patch_size, f_patch_size, return_dict)
+            return self._forward_basic(
+                x, t, cap_feats, controlnet_block_samples, patch_size, f_patch_size, return_dict
+            )
 
     def _forward_basic(
         self,
         x: List[torch.Tensor],
         t,
         cap_feats: List[torch.Tensor],
+        controlnet_block_samples: Optional[Dict[int, torch.Tensor]],
         patch_size: int,
         f_patch_size: int,
         return_dict: bool,
@@ -1053,6 +1057,7 @@ def _forward_omni(
         cap_feats: List[List[torch.Tensor]],
         cond_latents: List[List[torch.Tensor]],
         siglip_feats: List[List[torch.Tensor]],
+        controlnet_block_samples: Optional[Dict[int, torch.Tensor]],
         patch_size: int,
         f_patch_size: int,
         return_dict: bool,
@@ -1241,17 +1246,23 @@ def _forward_omni(
         unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]]
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for layer in self.layers:
+            for layer_idx, layer in enumerate(self.layers):
                 unified = self._gradient_checkpointing_func(
                     layer, unified, unified_attn_mask, unified_freqs_cis,
                     noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
                 )
+                if controlnet_block_samples is not None:
+                    if layer_idx in controlnet_block_samples:
+                        unified = unified + controlnet_block_samples[layer_idx]
         else:
-            for layer in self.layers:
+            for layer_idx, layer in enumerate(self.layers):
                 unified = layer(
                     unified, unified_attn_mask, unified_freqs_cis,
                     noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
                 )
+                if controlnet_block_samples is not None:
+                    if layer_idx in controlnet_block_samples:
+                        unified = unified + controlnet_block_samples[layer_idx]
 
         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
             unified, noise_mask=unified_noise_mask_tensor, c_noisy=t_noisy_x, c_clean=t_clean_x
0 commit comments