@@ -222,7 +222,7 @@ def forward(
         if self.modulation:
             if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
                 # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens
-                batch_size, seq_len = x.shape[0], x.shape[1]
+                _, seq_len = x.shape[0], x.shape[1]

                 mod_noisy = self.adaLN_modulation(adaln_noisy)
                 mod_clean = self.adaLN_modulation(adaln_clean)
@@ -260,7 +260,9 @@ def forward(
             else:
                 # Original global modulation
                 assert adaln_input is not None
-                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+                scale_msa, gate_msa, scale_mlp, gate_mlp = (
+                    self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+                )
                 gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
                 scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
 
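
For readers skimming the diff, a minimal, self-contained sketch of what the per-token adaLN branch amounts to. The names mirror the hunk above, but the `torch.where` blend and the toy shapes are assumptions; the hunk only shows the two `adaLN_modulation` calls, not how their outputs are applied.

```python
import torch

def per_token_scale(x, noise_mask, mod_noisy, mod_clean):
    # x: (batch, seq_len, dim); noise_mask: (batch, seq_len), 1 = noisy token, 0 = clean
    # mod_noisy / mod_clean: (batch, dim) modulation vectors from adaLN_modulation
    mask = noise_mask.unsqueeze(-1).bool()  # (batch, seq_len, 1)
    mod = torch.where(mask, mod_noisy.unsqueeze(1), mod_clean.unsqueeze(1))
    return x * (1.0 + mod)  # each token is scaled by the signal matching its mask value

x = torch.randn(2, 5, 8)
noise_mask = torch.tensor([[1, 1, 0, 0, 0], [1, 0, 0, 1, 1]])
out = per_token_scale(x, noise_mask, torch.randn(2, 8), torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 5, 8])
```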
@@ -297,7 +299,7 @@ def __init__(self, hidden_size, out_channels):
     def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
         if noise_mask is not None and c_noisy is not None and c_clean is not None:
             # Per-token modulation based on noise_mask
-            batch_size, seq_len = x.shape[0], x.shape[1]
+            _, seq_len = x.shape[0], x.shape[1]
             scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
             scale_clean = 1.0 + self.adaLN_modulation(c_clean)
 
@@ -916,8 +918,15 @@ def forward(
 
         if omni_mode:
             return self._forward_omni(
-                x, t, cap_feats, cond_latents, siglip_feats,
-                controlnet_block_samples, patch_size, f_patch_size, return_dict
+                x,
+                t,
+                cap_feats,
+                cond_latents,
+                siglip_feats,
+                controlnet_block_samples,
+                patch_size,
+                f_patch_size,
+                return_dict,
             )
         else:
             return self._forward_basic(
@@ -1130,14 +1139,23 @@ def _forward_omni(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.noise_refiner:
                 x = self._gradient_checkpointing_func(
-                    layer, x, x_attn_mask, x_freqs_cis,
-                    noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    layer,
+                    x,
+                    x_attn_mask,
+                    x_freqs_cis,
+                    noise_mask=x_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )
         else:
             for layer in self.noise_refiner:
                 x = layer(
-                    x, x_attn_mask, x_freqs_cis,
-                    noise_mask=x_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    x,
+                    x_attn_mask,
+                    x_freqs_cis,
+                    noise_mask=x_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )

         # cap embed & refine (no modulation)
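
The checkpointed branch above goes through diffusers' internal `self._gradient_checkpointing_func`. As a rough stand-in, assuming it wraps `torch.utils.checkpoint`, the refiner loop reduces to something like this (toy layers and shapes):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])  # stand-in for noise_refiner
x = torch.randn(2, 10, 16, requires_grad=True)

if torch.is_grad_enabled():
    for layer in blocks:
        # do not store this block's activations; recompute them in backward
        x = checkpoint(layer, x, use_reentrant=False)
else:
    for layer in blocks:
        x = layer(x)

x.sum().backward()  # activations are recomputed during this backward pass
```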
@@ -1208,9 +1226,7 @@ def _forward_omni(
                 x_len = x_item_seqlens[i]
                 cap_len = cap_item_seqlens[i]
                 siglip_len = siglip_item_seqlens[i]
-                unified.append(
-                    torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]])
-                )
+                unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]]))
                 unified_freqs_cis.append(
                     torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]])
                 )
@@ -1221,7 +1237,9 @@ def _forward_omni(
                         device=device,
                     )
                 )
-            unified_item_seqlens = [a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)]
+            unified_item_seqlens = [
+                a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)
+            ]
         else:
             for i in range(bsz):
                 x_len = x_item_seqlens[i]
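
A compact illustration of the packing these hunks perform: each sample's caption, latent, and SigLIP tokens are sliced to their true lengths and concatenated, while `unified_item_seqlens` tracks the real lengths. The `pad_sequence` call and the toy shapes are assumptions; only the concatenation and the seqlen arithmetic come from the diff.

```python
import torch

def pack_unified(cap_feats, x, siglip_feats, cap_lens, x_lens, siglip_lens):
    # Concatenate the valid tokens of each stream per sample, then right-pad
    # to the batch maximum so one attention mask can cover the unified sequence.
    unified = [
        torch.cat([cap_feats[i][:cap_lens[i]], x[i][:x_lens[i]], siglip_feats[i][:siglip_lens[i]]])
        for i in range(len(cap_lens))
    ]
    unified_item_seqlens = [a + b + c for a, b, c in zip(cap_lens, x_lens, siglip_lens)]
    return torch.nn.utils.rnn.pad_sequence(unified, batch_first=True), unified_item_seqlens

cap = torch.randn(2, 6, 8); lat = torch.randn(2, 4, 8); sig = torch.randn(2, 3, 8)
packed, seqlens = pack_unified(cap, lat, sig, [6, 5], [4, 4], [3, 0])
print(packed.shape, seqlens)  # torch.Size([2, 13, 8]) [13, 9]
```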
@@ -1248,17 +1266,26 @@ def _forward_omni(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer_idx, layer in enumerate(self.layers):
                 unified = self._gradient_checkpointing_func(
-                    layer, unified, unified_attn_mask, unified_freqs_cis,
-                    noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    layer,
+                    unified,
+                    unified_attn_mask,
+                    unified_freqs_cis,
+                    noise_mask=unified_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )
                 if controlnet_block_samples is not None:
                     if layer_idx in controlnet_block_samples:
                         unified = unified + controlnet_block_samples[layer_idx]
         else:
             for layer_idx, layer in enumerate(self.layers):
                 unified = layer(
-                    unified, unified_attn_mask, unified_freqs_cis,
-                    noise_mask=unified_noise_mask_tensor, adaln_noisy=t_noisy_x, adaln_clean=t_clean_x
+                    unified,
+                    unified_attn_mask,
+                    unified_freqs_cis,
+                    noise_mask=unified_noise_mask_tensor,
+                    adaln_noisy=t_noisy_x,
+                    adaln_clean=t_clean_x,
                 )
                 if controlnet_block_samples is not None:
                     if layer_idx in controlnet_block_samples:
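
The ControlNet hookup in both branches is an index-keyed residual add right after the matching layer runs. A toy sketch of that pattern, with the layers and shapes invented for illustration:

```python
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
hidden = torch.randn(2, 5, 8)
# mapping from layer index to a residual produced by the ControlNet branch
controlnet_block_samples = {1: torch.randn(2, 5, 8), 3: torch.randn(2, 5, 8)}

for layer_idx, layer in enumerate(layers):
    hidden = layer(hidden)
    if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
        hidden = hidden + controlnet_block_samples[layer_idx]
```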