
Commit 301e722

Add tests and doc
1 parent 65b230e commit 301e722

2 files changed: +351 −0 lines changed

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 6 additions & 0 deletions
@@ -120,6 +120,12 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
   - all
   - __call__
 
+## QwenImageEditControlNetPipeline
+
+[[autodoc]] QwenImageEditControlNetPipeline
+  - all
+  - __call__
+
 ## QwenImageEditInpaintPipeline
 
 [[autodoc]] QwenImageEditInpaintPipeline
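For context, here is a minimal usage sketch of the pipeline documented above. It is not part of this commit: the repo ids and image paths are placeholders, and the call arguments simply mirror the ones exercised by the new tests (prompt, image, control_image, controlnet_conditioning_scale, true_cfg_scale).

import torch
from diffusers import QwenImageControlNetModel, QwenImageEditControlNetPipeline
from diffusers.utils import load_image

# Placeholder repo ids -- substitute real checkpoints.
controlnet = QwenImageControlNetModel.from_pretrained(
    "your-org/qwenimage-controlnet", torch_dtype=torch.bfloat16
)
pipe = QwenImageEditControlNetPipeline.from_pretrained(
    "your-org/qwenimage-edit", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")            # image to be edited
control_image = load_image("control.png")  # spatial conditioning, e.g. a depth or edge map

edited = pipe(
    prompt="replace the sky with a sunset",
    negative_prompt="low quality",
    image=image,
    control_image=control_image,
    controlnet_conditioning_scale=0.7,
    true_cfg_scale=4.0,
    num_inference_steps=30,
).images[0]
edited.save("edited.png")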
New test file: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import pytest
import torch
from PIL import Image
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

from diffusers import (
    AutoencoderKLQwenImage,
    FlowMatchEulerDiscreteScheduler,
    QwenImageControlNetModel,
    QwenImageEditControlNetPipeline,
    QwenImageMultiControlNetModel,
    QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class QwenImageEditControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = QwenImageEditControlNetPipeline
    params = (TEXT_TO_IMAGE_PARAMS | frozenset(["control_image", "controlnet_conditioning_scale"])) - {
        "cross_attention_kwargs"
    }
    batch_params = frozenset(["prompt", "image", "control_image"])
    image_params = frozenset(["image", "control_image"])
    image_latents_params = frozenset(["latents"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "control_image",
            "controlnet_conditioning_scale",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"

        torch.manual_seed(0)
        transformer = QwenImageTransformer2DModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            guidance_embeds=False,
            axes_dims_rope=(8, 4, 4),
        )

        torch.manual_seed(0)
        controlnet = QwenImageControlNetModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=3,
            joint_attention_dim=16,
            axes_dims_rope=(8, 4, 4),
        )

        torch.manual_seed(0)
        z_dim = 4
        vae = AutoencoderKLQwenImage(
            base_dim=z_dim * 6,
            z_dim=z_dim,
            dim_mult=[1, 2, 4],
            num_res_blocks=1,
            temperal_downsample=[False, True],
            latents_mean=[0.0] * z_dim,
            latents_std=[1.0] * z_dim,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )
        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
            "controlnet": controlnet,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        control_image = randn_tensor(
            (1, 3, 32, 32),
            generator=generator,
            device=torch.device(device),
            dtype=torch.float32,
        )

        inputs = {
            "prompt": "dance monkey",
            "image": Image.new("RGB", (32, 32)),
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "true_cfg_scale": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "control_image": control_image,
            "controlnet_conditioning_scale": 0.5,
            "output_type": "pt",
        }

        return inputs

    def test_qwen_edit_controlnet(self):
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # Expected slice from the generated image
        expected_slice = torch.tensor(
            [
                0.4738,
                0.5510,
                0.6261,
                0.6516,
                0.4972,
                0.4606,
                0.4713,
                0.4956,
                0.4756,
                0.4606,
                0.4410,
                0.3323,
                0.3401,
                0.4636,
                0.3892,
                0.4410,
            ]
        )

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    def test_qwen_edit_controlnet_multicondition(self):
        device = "cpu"
        components = self.get_dummy_components()

        components["controlnet"] = QwenImageMultiControlNetModel([components["controlnet"]])

        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        control_image = inputs["control_image"]
        inputs["control_image"] = [control_image, control_image]
        inputs["controlnet_conditioning_scale"] = [0.5, 0.5]

        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))
        # Expected slice from the generated image
        expected_slice = torch.tensor(
            [
                0.6240,
                0.6655,
                0.5636,
                0.6006,
                0.5228,
                0.4918,
                0.5030,
                0.5337,
                0.4529,
                0.3124,
                0.3523,
                0.5190,
                0.5085,
                0.5453,
                0.4349,
                0.5787,
            ]
        )

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        inputs["control_image"] = randn_tensor(
            (1, 3, 128, 128),
            generator=inputs["generator"],
            device=torch.device(generator_device),
            dtype=torch.float32,
        )
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        inputs["control_image"] = randn_tensor(
            (1, 3, 128, 128),
            generator=inputs["generator"],
            device=torch.device(generator_device),
            dtype=torch.float32,
        )
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
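The new test module's path is not shown in this view, but the fast tests above can be selected by class name. Assuming the usual diffusers test layout, a typical invocation from the repository root might be:

pytest -q -k "QwenImageEditControlNetPipelineFastTests"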
