 # limitations under the License.
 
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -152,6 +152,20 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     return output
 
 
+def select_per_token(
+    value_noisy: torch.Tensor,
+    value_clean: torch.Tensor,
+    noise_mask: torch.Tensor,
+    seq_len: int,
+) -> torch.Tensor:
+    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
+    return torch.where(
+        noise_mask_expanded == 1,
+        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
+        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
+    )
+
+
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int):
         super().__init__()
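The new select_per_token helper factors out the four identical torch.where blocks that appear further down in this diff. A minimal standalone sketch of what it computes (the tensor shapes are illustrative assumptions based on the comments in the diff): each sample supplies one modulation vector for noisy tokens and one for clean tokens, and the (batch, seq_len) 0/1 noise_mask picks between them token by token.

import torch

def select_per_token(value_noisy, value_clean, noise_mask, seq_len):
    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
    return torch.where(
        noise_mask_expanded == 1,
        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),  # (batch, dim) -> (batch, seq_len, dim)
        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
    )

value_noisy = torch.full((2, 4), 1.0)              # per-sample modulation for noisy tokens, (batch, dim)
value_clean = torch.full((2, 4), -1.0)             # per-sample modulation for clean tokens, (batch, dim)
noise_mask = torch.tensor([[1, 0, 1], [0, 0, 1]])  # (batch, seq_len); 1 = noisy token, 0 = clean token
out = select_per_token(value_noisy, value_clean, noise_mask, seq_len=3)
print(out.shape)     # torch.Size([2, 3, 4])
print(out[0, :, 0])  # tensor([ 1., -1.,  1.])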
@@ -220,10 +234,10 @@ def forward(
         adaln_clean: Optional[torch.Tensor] = None,
     ):
         if self.modulation:
-            if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
-                # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens
-                _, seq_len = x.shape[0], x.shape[1]
+            seq_len = x.shape[1]
 
+            if noise_mask is not None:
+                # Per-token modulation: different modulation for noisy/clean tokens
                 mod_noisy = self.adaLN_modulation(adaln_noisy)
                 mod_clean = self.adaLN_modulation(adaln_clean)
 
@@ -236,33 +250,14 @@ def forward(
                 scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
                 scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
 
-                noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
-                scale_msa = torch.where(
-                    noise_mask_expanded == 1,
-                    scale_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    scale_msa_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
-                scale_mlp = torch.where(
-                    noise_mask_expanded == 1,
-                    scale_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    scale_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
-                gate_msa = torch.where(
-                    noise_mask_expanded == 1,
-                    gate_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    gate_msa_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
-                gate_mlp = torch.where(
-                    noise_mask_expanded == 1,
-                    gate_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    gate_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
+                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
+                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
+                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
+                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
             else:
-                # Original global modulation
-                assert adaln_input is not None
-                scale_msa, gate_msa, scale_mlp, gate_mlp = (
-                    self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
-                )
+                # Global modulation: same modulation for all tokens (avoid double select)
+                mod = self.adaLN_modulation(adaln_input)
+                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
                 gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
                 scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
 
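A small shape sketch of the global branch above, assuming adaLN_modulation projects the conditioning vector to 4 * dim (that width is an assumption for illustration; only the chunk(4, dim=2) split is visible in the diff): the four chunks come out as (batch, 1, dim) and broadcast over the sequence, so no per-token selection is needed.

import torch

batch, dim, seq_len = 2, 8, 5
mod = torch.randn(batch, 4 * dim)  # stand-in for self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
print(scale_msa.shape)  # torch.Size([2, 1, 8])

x = torch.randn(batch, seq_len, dim)       # token activations, just to show the broadcast
print((x * (1.0 + scale_msa)).shape)       # torch.Size([2, 5, 8])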
@@ -297,18 +292,13 @@ def __init__(self, hidden_size, out_channels):
         )
 
     def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
-        if noise_mask is not None and c_noisy is not None and c_clean is not None:
-            # Per-token modulation based on noise_mask
-            _, seq_len = x.shape[0], x.shape[1]
+        seq_len = x.shape[1]
+
+        if noise_mask is not None:
+            # Per-token modulation
             scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
             scale_clean = 1.0 + self.adaLN_modulation(c_clean)
-
-            noise_mask_expanded = noise_mask.unsqueeze(-1)
-            scale = torch.where(
-                noise_mask_expanded == 1,
-                scale_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                scale_clean.unsqueeze(1).expand(-1, seq_len, -1),
-            )
+            scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
         else:
             # Original global modulation
             assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
@@ -900,29 +890,29 @@ def patchify_and_embed_omni(
 
     def forward(
         self,
-        x: List[torch.Tensor],
+        x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
         t,
-        cap_feats: List[torch.Tensor],
+        cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
         controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
-        cond_latents: Optional[List[List[torch.Tensor]]] = None,
         siglip_feats: Optional[List[List[torch.Tensor]]] = None,
-        patch_size=2,
-        f_patch_size=1,
+        image_noise_mask: Optional[List[List[int]]] = None,
+        patch_size: int = 2,
+        f_patch_size: int = 1,
         return_dict: bool = True,
     ):
         assert patch_size in self.all_patch_size
         assert f_patch_size in self.all_f_patch_size
 
-        # Determine mode based on cond_latents
-        omni_mode = cond_latents is not None
+        # Omni mode: x contains lists (multi-image input)
+        omni_mode = isinstance(x[0], list)
 
         if omni_mode:
             return self._forward_omni(
                 x,
                 t,
                 cap_feats,
-                cond_latents,
                 siglip_feats,
+                image_noise_mask,
                 controlnet_block_samples,
                 patch_size,
                 f_patch_size,
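The isinstance check above distinguishes the two accepted layouts of x. A toy illustration (the latent shapes are placeholders, and the 1 = noisy / 0 = clean convention follows the per-token modulation comments elsewhere in this diff):

import torch

def make_latent():
    return torch.randn(16, 1, 32, 32)  # placeholder latent shape

x_basic = [make_latent(), make_latent()]                     # basic mode: one latent tensor per sample
x_omni = [[make_latent(), make_latent()], [make_latent()]]   # omni mode: a list of image latents per sample
image_noise_mask = [[0, 1], [1]]                             # per image: 1 = noisy/target, 0 = clean condition

print(isinstance(x_basic[0], list), isinstance(x_omni[0], list))  # False True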
@@ -1061,30 +1051,23 @@ def _forward_basic(
 
     def _forward_omni(
         self,
-        x: List[torch.Tensor],
+        x: List[List[torch.Tensor]],
         t,
         cap_feats: List[List[torch.Tensor]],
-        cond_latents: List[List[torch.Tensor]],
         siglip_feats: List[List[torch.Tensor]],
+        image_noise_mask: List[List[int]],
         controlnet_block_samples: Optional[Dict[int, torch.Tensor]],
         patch_size: int,
         f_patch_size: int,
         return_dict: bool,
     ):
         """Omni mode forward pass with image conditioning."""
         bsz = len(x)
-        device = x[0].device
+        device = x[0][-1].device  # From target latent
 
         # Create dual timestep embeddings: one for noisy tokens (t), one for clean tokens (t=1)
-        t_combined = torch.cat([t, torch.ones_like(t, dtype=t.dtype, device=device)], dim=0)
-        t_combined = t_combined * self.t_scale
-        t_combined = self.t_embedder(t_combined)
-        t_noisy = t_combined[:bsz]  # Original timestep for noisy tokens
-        t_clean = t_combined[bsz:]  # t=1 for clean (condition) tokens
-
-        # Combine condition latents with target latent
-        x = [cond_latents[i] + [x[i]] for i in range(bsz)]
-        image_noise_mask = [[0] * (len(x[i]) - 1) + [1] for i in range(bsz)]
+        t_noisy = self.t_embedder(t * self.t_scale)
+        t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale)
 
         # Patchify and embed for Omni mode
         (
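The two separate t_embedder calls above replace the removed concatenate-then-split version, and the two formulations are equivalent whenever the embedder acts row-wise on the batch. A toy check (the lambda is a placeholder for the model's t_embedder, and t_scale is set to 1.0 only for the demonstration):

import torch

t_scale = 1.0
t_embedder = lambda v: v.unsqueeze(-1) * torch.ones(4)  # placeholder embedder: (bsz,) -> (bsz, 4)

t = torch.rand(3)
bsz = t.shape[0]

# Removed approach: embed noisy and clean timesteps in one batched call, then split.
t_combined = t_embedder(torch.cat([t, torch.ones_like(t)], dim=0) * t_scale)
old_noisy, old_clean = t_combined[:bsz], t_combined[bsz:]

# New approach: two calls, t for noisy tokens and t = 1 for clean condition tokens.
new_noisy = t_embedder(t * t_scale)
new_clean = t_embedder(torch.ones_like(t) * t_scale)

print(torch.allclose(old_noisy, new_noisy), torch.allclose(old_clean, new_clean))  # True True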