Commit aff0d86

Move controlnet sample initialization from transformer to pipeline
1 parent: a99d99f

File tree

2 files changed: +11, -4 lines

src/diffusers/models/transformers/transformer_qwenimage.py (1 addition, 3 deletions)
@@ -639,9 +639,7 @@ def forward(
             if controlnet_block_samples is not None:
                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                 interval_control = int(np.ceil(interval_control))
-                sample = controlnet_block_samples[index_block // interval_control]
-                sample_size = min(sample.size(1), hidden_states.size(1))
-                hidden_states[:, :sample_size] = hidden_states[:, :sample_size] + sample[:, :sample_size]
+                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
         # Use only the image part (hidden_states) from the dual-stream blocks
         hidden_states = self.norm_out(hidden_states, temb)
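
Why the size clamping can go away: the edit pipeline now zero-pads every ControlNet residual up to the full sequence length before the transformer runs (see the pipeline diff below), so the residual and hidden_states always agree along dimension 1. A minimal sketch of the old and new addition paths, using hypothetical tensor shapes:

import torch
import torch.nn.functional as F

# Hypothetical shapes: the ControlNet residual covers only the image tokens,
# while hidden_states also carries the appended image-latent condition tokens.
hidden_states = torch.randn(1, 1024 + 512, 64)
sample = torch.randn(1, 1024, 64)

# Old path (clamped inside the transformer):
sample_size = min(sample.size(1), hidden_states.size(1))
old = hidden_states.clone()
old[:, :sample_size] += sample[:, :sample_size]

# New path: zero-pad once in the pipeline, then add directly in the transformer.
padded = F.pad(sample, (0, 0, 0, hidden_states.size(1) - sample.size(1)))
new = hidden_states + padded

assert torch.allclose(old, new)  # both paths produce the same tensor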

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_controlnet.py (10 additions, 1 deletion)
@@ -18,14 +18,15 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import QwenImageLoraLoaderMixin
 from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
 from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
@@ -970,6 +971,14 @@ def __call__(
                     return_dict=False,
                 )
 
+                if image_latents is not None:
+                    padding_size = image_latents.shape[1]
+                    for i, sample in enumerate(controlnet_block_samples):
+                        # Pad right with padding_size zeros at dimension 1 of each sample
+                        pad_tuple = [0] * (2 * sample.dim())
+                        pad_tuple[-3] = padding_size
+                        controlnet_block_samples[i] = F.pad(sample, pad_tuple, mode="constant", value=0)
+
                 with self.transformer.cache_context("cond"):
                     noise_pred = self.transformer(
                         hidden_states=latent_model_input,
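
The pad tuple built here is easy to misread: torch.nn.functional.pad consumes (left, right) pairs starting from the LAST dimension, so for a 3-D (batch, sequence, channels) sample the slot at index -3 is the right-hand padding of the sequence dimension. A standalone sketch of the same construction, with hypothetical shapes:

import torch
import torch.nn.functional as F

sample = torch.randn(2, 1024, 64)   # hypothetical (batch, seq, channels) residual
padding_size = 512                  # stands in for image_latents.shape[1]

# Pairs are ordered (last-dim left, last-dim right, 2nd-last left, 2nd-last right, ...)
pad_tuple = [0] * (2 * sample.dim())   # -> [0, 0, 0, 0, 0, 0]
pad_tuple[-3] = padding_size           # right side of dimension 1 (the sequence axis)
padded = F.pad(sample, pad_tuple, mode="constant", value=0)

assert padded.shape == (2, 1024 + 512, 64)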
