
Commit 5bc676c

Refactor noise handling and modulation
- Add select_per_token function for per-token value selection
- Separate adaptive modulation logic
- Clean up t_noisy/t_clean variable naming
- Move image_noise_mask handling from forward to the pipeline
1 parent 4c14cf3 commit 5bc676c

2 files changed: 57 additions & 62 deletions

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 43 additions & 60 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -152,6 +152,20 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     return output
 
 
+def select_per_token(
+    value_noisy: torch.Tensor,
+    value_clean: torch.Tensor,
+    noise_mask: torch.Tensor,
+    seq_len: int,
+) -> torch.Tensor:
+    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
+    return torch.where(
+        noise_mask_expanded == 1,
+        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
+        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
+    )
+
+
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int):
         super().__init__()
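
For reference, a minimal standalone sketch (not part of the commit) of what the new helper computes: value_noisy and value_clean each hold one modulation vector per sample, and the mask picks which vector every token receives. The shapes and values below are illustrative only.

import torch

def select_per_token(value_noisy, value_clean, noise_mask, seq_len):
    # Same logic as the helper added above, reproduced so this sketch runs standalone.
    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
    return torch.where(
        noise_mask_expanded == 1,
        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
    )

batch, seq_len, dim = 2, 4, 8
value_noisy = torch.zeros(batch, dim)      # per-sample vector used for noisy tokens
value_clean = torch.ones(batch, dim)       # per-sample vector used for clean tokens
noise_mask = torch.tensor([[1, 1, 0, 0],
                           [0, 0, 0, 1]])  # 1 = noisy token, 0 = clean token

out = select_per_token(value_noisy, value_clean, noise_mask, seq_len)
print(out.shape)                                  # torch.Size([2, 4, 8])
print(out[0, 0, 0].item(), out[0, 2, 0].item())   # 0.0 (noisy token), 1.0 (clean token)
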
@@ -220,10 +234,10 @@ def forward(
         adaln_clean: Optional[torch.Tensor] = None,
     ):
         if self.modulation:
-            if noise_mask is not None and adaln_noisy is not None and adaln_clean is not None:
-                # Per-token modulation based on noise_mask, (batch, seq_len), 1 for noisy tokens, 0 for clean tokens
-                _, seq_len = x.shape[0], x.shape[1]
+            seq_len = x.shape[1]
 
+            if noise_mask is not None:
+                # Per-token modulation: different modulation for noisy/clean tokens
                 mod_noisy = self.adaLN_modulation(adaln_noisy)
                 mod_clean = self.adaLN_modulation(adaln_clean)
 
@@ -236,33 +250,14 @@ def forward(
                 scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
                 scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
 
-                noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
-                scale_msa = torch.where(
-                    noise_mask_expanded == 1,
-                    scale_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    scale_msa_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
-                scale_mlp = torch.where(
-                    noise_mask_expanded == 1,
-                    scale_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    scale_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
-                gate_msa = torch.where(
-                    noise_mask_expanded == 1,
-                    gate_msa_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    gate_msa_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
-                gate_mlp = torch.where(
-                    noise_mask_expanded == 1,
-                    gate_mlp_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                    gate_mlp_clean.unsqueeze(1).expand(-1, seq_len, -1),
-                )
+                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
+                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
+                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
+                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
             else:
-                # Original global modulation
-                assert adaln_input is not None
-                scale_msa, gate_msa, scale_mlp, gate_mlp = (
-                    self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
-                )
+                # Global modulation: same modulation for all tokens (avoid double select)
+                mod = self.adaLN_modulation(adaln_input)
+                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
             gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
             scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
 
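In the rewritten global branch, a single adaLN projection is reshaped once and split into the four modulation tensors. A rough standalone sketch of that reshaping (not part of the commit), with a random tensor standing in for self.adaLN_modulation(adaln_input):

import torch

batch, dim = 2, 16
mod = torch.randn(batch, 4 * dim)  # stand-in for self.adaLN_modulation(adaln_input)

# unsqueeze(1) gives (batch, 1, 4*dim); chunk(4, dim=2) yields four (batch, 1, dim) tensors
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
print(scale_msa.shape)  # torch.Size([2, 1, 16])

gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp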

@@ -297,18 +292,13 @@ def __init__(self, hidden_size, out_channels):
         )
 
     def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
-        if noise_mask is not None and c_noisy is not None and c_clean is not None:
-            # Per-token modulation based on noise_mask
-            _, seq_len = x.shape[0], x.shape[1]
+        seq_len = x.shape[1]
+
+        if noise_mask is not None:
+            # Per-token modulation
             scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
             scale_clean = 1.0 + self.adaLN_modulation(c_clean)
-
-            noise_mask_expanded = noise_mask.unsqueeze(-1)
-            scale = torch.where(
-                noise_mask_expanded == 1,
-                scale_noisy.unsqueeze(1).expand(-1, seq_len, -1),
-                scale_clean.unsqueeze(1).expand(-1, seq_len, -1),
-            )
+            scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
         else:
             # Original global modulation
             assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
@@ -900,29 +890,29 @@ def patchify_and_embed_omni(
 
     def forward(
         self,
-        x: List[torch.Tensor],
+        x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
         t,
-        cap_feats: List[torch.Tensor],
+        cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
         controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
-        cond_latents: Optional[List[List[torch.Tensor]]] = None,
         siglip_feats: Optional[List[List[torch.Tensor]]] = None,
-        patch_size=2,
-        f_patch_size=1,
+        image_noise_mask: Optional[List[List[int]]] = None,
+        patch_size: int = 2,
+        f_patch_size: int = 1,
         return_dict: bool = True,
     ):
         assert patch_size in self.all_patch_size
         assert f_patch_size in self.all_f_patch_size
 
-        # Determine mode based on cond_latents
-        omni_mode = cond_latents is not None
+        # Omni mode: x contains lists (multi-image input)
+        omni_mode = isinstance(x[0], list)
 
         if omni_mode:
             return self._forward_omni(
                 x,
                 t,
                 cap_feats,
-                cond_latents,
                 siglip_feats,
+                image_noise_mask,
                 controlnet_block_samples,
                 patch_size,
                 f_patch_size,
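
The mode switch no longer depends on a cond_latents argument; it is inferred from the nesting of x. A small illustration of the distinction (not part of the commit; the latent shapes are placeholders, not the model's actual layout):

import torch

# Basic mode: one latent tensor per sample
basic_x = [torch.randn(16, 1, 64, 64) for _ in range(2)]
# Omni mode: per sample, a list of condition latents followed by the target latent
omni_x = [[torch.randn(16, 1, 64, 64), torch.randn(16, 1, 64, 64)] for _ in range(2)]

assert not isinstance(basic_x[0], list)  # routed to _forward_basic
assert isinstance(omni_x[0], list)       # routed to _forward_omni
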
@@ -1061,30 +1051,23 @@ def _forward_basic(
 
     def _forward_omni(
         self,
-        x: List[torch.Tensor],
+        x: List[List[torch.Tensor]],
         t,
         cap_feats: List[List[torch.Tensor]],
-        cond_latents: List[List[torch.Tensor]],
         siglip_feats: List[List[torch.Tensor]],
+        image_noise_mask: List[List[int]],
         controlnet_block_samples: Optional[Dict[int, torch.Tensor]],
         patch_size: int,
         f_patch_size: int,
         return_dict: bool,
     ):
         """Omni mode forward pass with image conditioning."""
         bsz = len(x)
-        device = x[0].device
+        device = x[0][-1].device  # From target latent
 
         # Create dual timestep embeddings: one for noisy tokens (t), one for clean tokens (t=1)
-        t_combined = torch.cat([t, torch.ones_like(t, dtype=t.dtype, device=device)], dim=0)
-        t_combined = t_combined * self.t_scale
-        t_combined = self.t_embedder(t_combined)
-        t_noisy = t_combined[:bsz]  # Original timestep for noisy tokens
-        t_clean = t_combined[bsz:]  # t=1 for clean (condition) tokens
-
-        # Combine condition latents with target latent
-        x = [cond_latents[i] + [x[i]] for i in range(bsz)]
-        image_noise_mask = [[0] * (len(x[i]) - 1) + [1] for i in range(bsz)]
+        t_noisy = self.t_embedder(t * self.t_scale)
+        t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale)
 
         # Patchify and embed for Omni mode
         (
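
The dual timestep embeddings are now two direct embedder calls instead of concatenating, embedding once, and splitting. A standalone sketch (not part of the commit) of why the two formulations agree, with an nn.Linear standing in for self.t_embedder and a made-up t_scale:

import torch
import torch.nn as nn

t_embedder = nn.Linear(1, 8)  # stand-in for self.t_embedder (assumption, not the real module)
t_scale = 1000.0              # stand-in for self.t_scale
t = torch.rand(2, 1)          # toy timesteps
bsz = t.shape[0]

# Old formulation: batch noisy and clean timesteps, embed once, then split
t_combined = torch.cat([t, torch.ones_like(t)], dim=0) * t_scale
t_combined = t_embedder(t_combined)
old_noisy, old_clean = t_combined[:bsz], t_combined[bsz:]

# New formulation: two direct calls
new_noisy = t_embedder(t * t_scale)
new_clean = t_embedder(torch.ones_like(t) * t_scale)

assert torch.allclose(old_noisy, new_noisy, atol=1e-6)
assert torch.allclose(old_clean, new_clean, atol=1e-6)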

src/diffusers/pipelines/z_image/pipeline_z_image_omni.py

Lines changed: 14 additions & 2 deletions
@@ -657,12 +657,24 @@ def __call__(
                 latent_model_input = latent_model_input.unsqueeze(2)
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))
 
+                # Combine condition latents with target latent
+                current_batch_size = len(latent_model_input_list)
+                x_combined = [
+                    condition_latents_model_input[i] + [latent_model_input_list[i]]
+                    for i in range(current_batch_size)
+                ]
+                # Create noise mask: 0 for condition images (clean), 1 for target image (noisy)
+                image_noise_mask = [
+                    [0] * len(condition_latents_model_input[i]) + [1]
+                    for i in range(current_batch_size)
+                ]
+
                 model_out_list = self.transformer(
-                    x=latent_model_input_list,
+                    x=x_combined,
                     t=timestep_model_input,
                     cap_feats=prompt_embeds_model_input,
-                    cond_latents=condition_latents_model_input,
                     siglip_feats=condition_siglip_embeds_model_input,
+                    image_noise_mask=image_noise_mask,
                     return_dict=False,
                 )[0]
 
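
A toy run of the list construction that moved into the pipeline (not part of the commit): each sample's condition latents are followed by its target latent, and the mask marks only the target as noisy. Tensor shapes are placeholders.

import torch

condition_latents_model_input = [
    [torch.randn(16, 1, 32, 32)],                               # sample 0: one condition image
    [torch.randn(16, 1, 32, 32), torch.randn(16, 1, 32, 32)],   # sample 1: two condition images
]
latent_model_input_list = [torch.randn(16, 1, 32, 32) for _ in range(2)]

current_batch_size = len(latent_model_input_list)
x_combined = [
    condition_latents_model_input[i] + [latent_model_input_list[i]]
    for i in range(current_batch_size)
]
image_noise_mask = [
    [0] * len(condition_latents_model_input[i]) + [1]
    for i in range(current_batch_size)
]

print([len(sample) for sample in x_combined])  # [2, 3]
print(image_noise_mask)                        # [[0, 1], [0, 0, 1]]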
